@@ -3,7 +3,15 @@ module ReactantNNlibExt
33using NNlib
44using GPUArraysCore: @allowscalar
55using Reactant:
6- Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
6+ Reactant,
7+ Ops,
8+ TracedRArray,
9+ AnyTracedRArray,
10+ materialize_traced_array,
11+ MLIR,
12+ TracedRNumber,
13+ get_mlir_data,
14+ set_mlir_data!
715using ReactantCore: @trace
816using LinearAlgebra: LinearAlgebra, triu
917
@@ -12,14 +20,7 @@ for (jlop, hloop) in (
1220 (:(NNlib. sigmoid_fast), :logistic ),
1321 (:(NNlib. sigmoid), :logistic ),
1422)
15- @eval function $ (jlop)(x:: TracedRNumber{T} ) where {T}
16- return TracedRNumber {T} (
17- (),
18- Reactant. MLIR. IR. result (
19- Reactant. MLIR. Dialects. stablehlo.$ (hloop)(x. mlir_data), 1
20- ),
21- )
22- end
23+ @eval $ (jlop)(x:: TracedRNumber ) = Ops.$ (hloop)(x)
2324end
2425
2526function NNlib. softmax! (out:: TracedRArray{T,N} , x:: AbstractArray ; dims= 1 ) where {T,N}
@@ -82,13 +83,6 @@ function NNlib.conv!(
8283 kernel_input_dim = N - 1
8384 kernel_output_dim = N
8485
85- output_spatial_shapes = map (input_spatial_dims) do i
86- K = kernel_size[i]
87- pl, pr = padding[2 i - 1 ], padding[2 i]
88- d = dilation[i]
89- s = stride[i]
90- return (size (x, i) + pl + pr - d * (K - 1 ) - 1 ) ÷ s + 1
91- end
9286 output_batch_dim = input_batch_dim
9387 output_feature_dim = input_feature_dim
9488 output_spatial_dims = input_spatial_dims
@@ -119,8 +113,8 @@ function NNlib.conv!(
119113 end
120114
121115 conv = Reactant. MLIR. Dialects. stablehlo. convolution (
122- x . mlir_data ,
123- weight. mlir_data ;
116+ get_mlir_data (x) ,
117+ get_mlir_data ( weight) ;
124118 result_0= result_type,
125119 window_strides= collect (stride),
126120 padding,
@@ -130,7 +124,7 @@ function NNlib.conv!(
130124 feature_group_count,
131125 batch_group_count= 1 ,
132126 )
133- y . mlir_data = Reactant. MLIR. IR. result (conv)
127+ set_mlir_data! (y, Reactant. MLIR. IR. result (conv) )
134128 return y
135129end
136130
@@ -165,7 +159,9 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
165159 output_shape = (output_spatial_shapes... , size (x, N - 1 ), size (x, N))
166160 result_type = Reactant. MLIR. IR. TensorType (output_shape, Reactant. MLIR. IR. Type (T))
167161
168- unranked = Reactant. MLIR. IR. TensorType ((), eltype (Reactant. MLIR. IR. type (x. mlir_data)))
162+ unranked = Reactant. MLIR. IR. TensorType (
163+ (), eltype (Reactant. MLIR. IR. type (get_mlir_data (x)))
164+ )
169165 body =
170166 let body = Reactant. MLIR. IR. Region (),
171167 loc = Reactant. MLIR. IR. Location (),
@@ -189,7 +185,7 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
189185 Reactant. MLIR. Dialects. stablehlo. constant (; value= attr)
190186 )
191187 reduction = Reactant. MLIR. Dialects. stablehlo. reduce_window (
192- [x . mlir_data ],
188+ [get_mlir_data (x) ],
193189 [init_value];
194190 result_0= [result_type],
195191 window_dimensions,
@@ -205,24 +201,24 @@ end
# Max pooling for traced arrays.
#
# `x` is converted element-wise to `T`, reduced with `reduce_window` using the
# StableHLO `maximum` op (seeded with `typemin(T)`), and the resulting MLIR value
# is stored into `y`. Returns `y`.
function NNlib.maxpool!(
    y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims
) where {T}
    max_op = Reactant.MLIR.Dialects.stablehlo.maximum
    pooled = reduce_window(max_op, T.(x), pdims; init=typemin(T))
    # Only the underlying MLIR value of `y` is replaced; `y` itself is returned.
    set_mlir_data!(y, get_mlir_data(pooled))
    return y
end
214210
# Mean pooling for traced arrays.
#
# Each window of `T.(x)` is summed with `reduce_window` using the StableHLO `add`
# op (seeded with `zero(T)`), then divided by the number of elements in the
# pooling kernel, `prod(NNlib.kernel_size(pdims))`. The result's MLIR value is
# stored into `y`. Returns `y`.
function NNlib.meanpool!(
    y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims
) where {T}
    window_sums = reduce_window(
        Reactant.MLIR.Dialects.stablehlo.add, T.(x), pdims; init=zero(T)
    )
    # The divisor is converted to `T` so the broadcasted division keeps eltype `T`.
    window_count = T(prod(NNlib.kernel_size(pdims)))
    set_mlir_data!(y, get_mlir_data(window_sums ./ window_count))
    return y
end
222218
# Swap the first two dimensions of a 3-D traced array, leaving the third
# dimension in place. The result is a lazy `PermutedDimsArray` wrapper around
# `x`; no data is copied, so the wrapper shares its contents with `x`.
function NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T}
    return PermutedDimsArray(x, (2, 1, 3))
end
# Conjugate transpose of the first two dimensions of a 3-D traced array.
#
# The conjugation is done with a non-mutating broadcast on purpose:
# `NNlib.batched_transpose` may return a lazy view (e.g. a `PermutedDimsArray`)
# that shares its data with `x`, and calling `conj!` on such a view would write
# through it and conjugate the caller's input as a side effect. Broadcasting
# `conj` always produces a fresh array and leaves `x` untouched.
function NNlib.batched_adjoint(x::AnyTracedRArray{T,3}) where {T}
    return conj.(NNlib.batched_transpose(x))
end
@@ -238,64 +234,47 @@ function NNlib.batched_mul!(
238234 ),
239235 )
240236 end
237+
238+ if size (x, 3 ) != size (y, 3 )
239+ B = max (size (x, 3 ), size (y, 3 ))
240+ if size (x, 3 ) == 1
241+ x = Reactant. broadcast_to_size (x, (size (x, 1 ), size (x, 2 ), B))
242+ elseif size (y, 3 ) == 1
243+ y = Reactant. broadcast_to_size (y, (size (y, 1 ), size (y, 2 ), B))
244+ end
245+ end
246+
241247 x = permutedims (x, (3 , 1 , 2 ))
242248 y = permutedims (y, (3 , 1 , 2 ))
243249
244- B = max (size (x, 1 ), size (y, 1 ))
245- out_shape = (B, size (x, 2 ), size (y, 3 ))
246- resty = MLIR. IR. TensorType (out_shape, eltype (MLIR. IR. type (res. mlir_data)))
247-
248250 if size (x, 1 ) != size (y, 1 )
251+ B = max (size (x, 1 ), size (y, 1 ))
249252 if size (x, 1 ) == 1
250253 x = Reactant. broadcast_to_size (x, (B, size (x, 2 ), size (x, 3 )))
251254 elseif size (y, 1 ) == 1
252255 y = Reactant. broadcast_to_size (y, (B, size (y, 2 ), size (y, 3 )))
253256 end
254257 end
255258
256- dot_dimension_numbers = MLIR. API. stablehloDotDimensionNumbersGet (
257- MLIR. IR. context (), 1 , [0 ], 1 , [0 ], 1 , [2 ], 1 , [1 ]
259+ tmp = Ops. dot_general (
260+ T1 .(materialize_traced_array (x)),
261+ T1 .(materialize_traced_array (y));
262+ contracting_dimensions= ([3 ], [2 ]),
263+ batching_dimensions= ([1 ], [1 ]),
258264 )
265+ set_mlir_data! (res, get_mlir_data (permutedims (tmp, (2 , 3 , 1 ))))
259266
260- prec = MLIR. IR. Attribute (
261- MLIR. API. stablehloPrecisionAttrGet (MLIR. IR. context (), " DEFAULT" )
262- )
263- tmp = TracedRArray {T1,3} (
264- (),
265- MLIR. IR. result (
266- MLIR. Dialects. stablehlo. dot_general (
267- x. mlir_data,
268- y. mlir_data;
269- result_0= resty,
270- dot_dimension_numbers= dot_dimension_numbers,
271- precision_config= prec,
272- ),
273- 1 ,
274- ),
275- size (resty),
276- )
277- res. mlir_data = permutedims (tmp, (2 , 3 , 1 )). mlir_data
278267 return res
279268end
280269
# Constant padding for traced arrays.
#
# `pad` holds one `(low, high)` pair per dimension of `x`. `value` is promoted
# to a `TracedRNumber{T}` and used as the fill value. No interior padding is
# applied (all interior amounts are zero). The result is whatever `Ops.pad`
# returns for the materialized input.
function NNlib.pad_constant(
    x::AnyTracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value
) where {T,N}
    fill_value = Reactant.promote_to(TracedRNumber{T}, value)
    lows = [first(amounts) for amounts in pad]
    highs = [last(amounts) for amounts in pad]
    gaps = zeros(Int, N)
    return Ops.pad(
        materialize_traced_array(x), fill_value; low=lows, high=highs, interior=gaps
    )
end
300279
301280# XXX : reevaluate this manual optimization once
@@ -305,7 +284,7 @@ function NNlib.gather!(
305284 src:: AnyTracedRArray{T2,2} ,
306285 idxs:: Union{AbstractUnitRange{<:Number}} ,
307286) where {T1,T2}
308- dst. mlir_data = src[:, idxs]. mlir_data
287+ set_mlir_data! ( dst, get_mlir_data ( src[:, idxs]))
309288 return dst
310289end
311290
@@ -314,8 +293,8 @@ function NNlib.gather!(
314293) where {T1,T2}
315294 dims = NNlib. scatter_dims (src, dst, idxs)
316295 @assert dims == 1 # scatter_dims lets us do some size checks so we call that function
317- idxs = (Reactant. promote_to (TracedRArray{Int,1 }, idxs) .- 1 ). mlir_data
318- slice_sizes = Reactant. promote_to (TracedRArray{Int,1 }, [size (src, 1 ), 1 ]). mlir_data
296+ idxs = get_mlir_data (Reactant. promote_to (TracedRArray{Int,1 }, idxs) .- 1 )
297+ slice_sizes = get_mlir_data ( Reactant. promote_to (TracedRArray{Int,1 }, [size (src, 1 ), 1 ]))
319298
320299 # ! format: off
321300 dimension_numbers = MLIR. API. stablehloGatherDimensionNumbersGet (
@@ -331,11 +310,11 @@ function NNlib.gather!(
331310
332311 res = MLIR. IR. result (
333312 Reactant. MLIR. Dialects. stablehlo. dynamic_gather (
334- src. mlir_data , idxs, slice_sizes; dimension_numbers
313+ get_mlir_data ( src) , idxs, slice_sizes; dimension_numbers
335314 ),
336315 1 ,
337316 )
338- dst. mlir_data = res
317+ set_mlir_data! ( dst, res)
339318 return dst
340319end
341320
@@ -354,7 +333,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
354333 return reshape (res, start_sizes... , :)
355334 end
356335 res = reshape (cat (results... ; dims= (dims + 1 )), size (dst))
357- dst. mlir_data = res. mlir_data
336+ set_mlir_data! ( dst, get_mlir_data ( res))
358337 return dst
359338end
360339
@@ -363,7 +342,7 @@ dilate_shape(s, d) = max(0, 1 + d * (s - 1))
363342# see lax._conv_general_dilated_transpose_rhs
364343# https://github.com/jax-ml/jax/blob/a1dfdc1d6164ad49afb337da9effd269d430d68b/jax/_src/lax/convolution.py#L495
365344function NNlib. ∇conv_filter! (
366- dw:: Reactant. TracedRArray{T,N} ,
345+ dw:: TracedRArray{T,N} ,
367346 x:: AnyTracedRArray ,
368347 dy:: AnyTracedRArray ,
369348 cdims:: NNlib.DenseConvDims ,
@@ -437,8 +416,8 @@ function NNlib.∇conv_filter!(
437416
438417 result_type = Reactant. MLIR. IR. TensorType (size (dw), Reactant. MLIR. IR. Type (T))
439418 conv = MLIR. Dialects. stablehlo. convolution (
440- x . mlir_data ,
441- dy . mlir_data ;
419+ get_mlir_data (x) ,
420+ get_mlir_data (dy) ;
442421 result_0= result_type,
443422 window_strides= collect (dilation),
444423 padding,
@@ -447,11 +426,12 @@ function NNlib.∇conv_filter!(
447426 feature_group_count,
448427 batch_group_count,
449428 )
450-
451- dw. mlir_data = MLIR. IR. result (conv)
429+ set_mlir_data! (dw, MLIR. IR. result (conv))
452430
453431 if ! NNlib. flipkernel (cdims)
454- dw. mlir_data = Reactant. Ops. reverse (dw; dimensions= output_spatial_dims). mlir_data
432+ set_mlir_data! (
433+ dw, get_mlir_data (Reactant. Ops. reverse (dw; dimensions= output_spatial_dims))
434+ )
455435 end
456436
457437 return dw
@@ -553,8 +533,8 @@ function NNlib.∇conv_data!(
553533 end
554534
555535 conv = MLIR. Dialects. stablehlo. convolution (
556- dy . mlir_data ,
557- w . mlir_data ;
536+ get_mlir_data (dy) ,
537+ get_mlir_data (w) ;
558538 result_0= result_type,
559539 window_strides= 1 ,
560540 padding,
@@ -564,8 +544,7 @@ function NNlib.∇conv_data!(
564544 feature_group_count,
565545 batch_group_count= 1 ,
566546 )
567-
568- dx. mlir_data = MLIR. IR. result (conv)
547+ set_mlir_data! (dx, MLIR. IR. result (conv))
569548
570549 return dx
571550end
0 commit comments