|
1 | 1 | module ReactantNNlibExt
|
2 | 2 |
|
3 | 3 | using NNlib
|
4 |
| -using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array |
| 4 | +using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR |
5 | 5 |
|
6 | 6 | for (jlop, hloop) in (
|
7 | 7 | (:(NNlib.tanh_fast), :tanh),
|
@@ -214,13 +214,43 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe
|
214 | 214 | ),
|
215 | 215 | )
|
216 | 216 | end
|
217 |
| - if size(x, 3) == size(y, 3) |
218 |
| - return cat([x[:, :, i] * y[:, :, i] for i in axes(x, 3)]...; dims=Val(3)) |
219 |
| - elseif size(x, 3) == 1 |
220 |
| - return cat([x[:, :, i] * y[:, :, 1] for i in axes(x, 3)]...; dims=Val(3)) |
221 |
| - elseif size(y, 3) == 1 |
222 |
| - return cat([x[:, :, 1] * y[:, :, i] for i in axes(y, 3)]...; dims=Val(3)) |
| 217 | + x = permutedims(x, (3, 1, 2)) |
| 218 | + y = permutedims(y, (3, 1, 2)) |
| 219 | + |
| 220 | + B = max(size(x, 1), size(y, 1)) |
| 221 | + out_shape = (B, size(x, 2), size(y, 3)) |
| 222 | + resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(x.mlir_data))) |
| 223 | + |
| 224 | + if size(x, 1) != size(y, 1) |
| 225 | + if size(x, 1) == 1 |
| 226 | + x = Reactant.broadcast_to_size(x, (B, size(x, 2), size(x, 3))) |
| 227 | + elseif size(y, 1) == 1 |
| 228 | + y = Reactant.broadcast_to_size(y, (B, size(y, 2), size(y, 3))) |
| 229 | + end |
223 | 230 | end
|
| 231 | + |
| 232 | + dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet( |
| 233 | + MLIR.IR.context(), 1, [0], 1, [0], 1, [2], 1, [1] |
| 234 | + ) |
| 235 | + |
| 236 | + prec = MLIR.IR.Attribute( |
| 237 | + MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT") |
| 238 | + ) |
| 239 | + res = TracedRArray{T,3}( |
| 240 | + (), |
| 241 | + MLIR.IR.result( |
| 242 | + MLIR.Dialects.stablehlo.dot_general( |
| 243 | + x.mlir_data, |
| 244 | + y.mlir_data; |
| 245 | + result_0=resty, |
| 246 | + dot_dimension_numbers=dot_dimension_numbers, |
| 247 | + precision_config=prec, |
| 248 | + ), |
| 249 | + 1, |
| 250 | + ), |
| 251 | + size(resty), |
| 252 | + ) |
| 253 | + return permutedims(res, (2, 3, 1)) |
224 | 254 | end
|
225 | 255 |
|
226 | 256 | end # module ReactantNNlibExt
|
0 commit comments