Skip to content

Commit f1547db

Browse files
committed
perf: use dot_general to implement batched matrix multiply
1 parent 24b63fd commit f1547db

File tree

2 files changed

+69
-7
lines changed

2 files changed

+69
-7
lines changed

ext/ReactantNNlibExt.jl

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module ReactantNNlibExt
22

33
using NNlib
4-
using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array
4+
using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR
55

66
for (jlop, hloop) in (
77
(:(NNlib.tanh_fast), :tanh),
@@ -214,13 +214,43 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe
214214
),
215215
)
216216
end
217-
if size(x, 3) == size(y, 3)
218-
return cat([x[:, :, i] * y[:, :, i] for i in axes(x, 3)]...; dims=Val(3))
219-
elseif size(x, 3) == 1
220-
return cat([x[:, :, i] * y[:, :, 1] for i in axes(x, 3)]...; dims=Val(3))
221-
elseif size(y, 3) == 1
222-
return cat([x[:, :, 1] * y[:, :, i] for i in axes(y, 3)]...; dims=Val(3))
217+
x = permutedims(x, (3, 1, 2))
218+
y = permutedims(y, (3, 1, 2))
219+
220+
B = max(size(x, 1), size(y, 1))
221+
out_shape = (B, size(x, 2), size(y, 3))
222+
resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(x.mlir_data)))
223+
224+
if size(x, 1) != size(y, 1)
225+
if size(x, 1) == 1
226+
x = Reactant.broadcast_to_size(x, (B, size(x, 2), size(x, 3)))
227+
elseif size(y, 1) == 1
228+
y = Reactant.broadcast_to_size(y, (B, size(y, 2), size(y, 3)))
229+
end
223230
end
231+
232+
dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet(
233+
MLIR.IR.context(), 1, [0], 1, [0], 1, [2], 1, [1]
234+
)
235+
236+
prec = MLIR.IR.Attribute(
237+
MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT")
238+
)
239+
res = TracedRArray{T,3}(
240+
(),
241+
MLIR.IR.result(
242+
MLIR.Dialects.stablehlo.dot_general(
243+
x.mlir_data,
244+
y.mlir_data;
245+
result_0=resty,
246+
dot_dimension_numbers=dot_dimension_numbers,
247+
precision_config=prec,
248+
),
249+
1,
250+
),
251+
size(resty),
252+
)
253+
return permutedims(res, (2, 3, 1))
224254
end
225255

226256
end # module ReactantNNlibExt

test/nn/nnlib.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,35 @@ end
9191
[3*0+2*1+1*2; 3*1+2*2+1*3; 3*2+2*3+1*0;;;]
9292
end
9393
end
94+
95+
@testset "Batched Matrix Multiplication" begin
96+
x = rand(Float32, 4, 3, 5)
97+
y = rand(Float32, 3, 2, 5)
98+
99+
x_ra = Reactant.ConcreteRArray(x)
100+
y_ra = Reactant.ConcreteRArray(y)
101+
102+
bmm_compiled = @compile batched_mul(x_ra, y_ra)
103+
104+
@test bmm_compiled(x_ra, y_ra) ≈ batched_mul(x, y)
105+
106+
x = rand(Float32, 4, 3, 1)
107+
y = rand(Float32, 3, 2, 5)
108+
109+
x_ra = Reactant.ConcreteRArray(x)
110+
y_ra = Reactant.ConcreteRArray(y)
111+
112+
bmm_compiled = @compile batched_mul(x_ra, y_ra)
113+
114+
@test bmm_compiled(x_ra, y_ra) ≈ batched_mul(x, y)
115+
116+
x = rand(Float32, 4, 3, 5)
117+
y = rand(Float32, 3, 2, 1)
118+
119+
x_ra = Reactant.ConcreteRArray(x)
120+
y_ra = Reactant.ConcreteRArray(y)
121+
122+
bmm_compiled = @compile batched_mul(x_ra, y_ra)
123+
124+
@test bmm_compiled(x_ra, y_ra) ≈ batched_mul(x, y)
125+
end

0 commit comments

Comments
 (0)