@@ -113,18 +113,14 @@ function NNlib.conv!(
113113 )
114114 result_type = Reactant. MLIR. IR. TensorType (size (y), Reactant. MLIR. IR. Type (T))
115115
116- weight = W. mlir_data
116+ weight = W
117117 if ! flipkernel
118- weight = Reactant. MLIR. IR. result (
119- Reactant. MLIR. Dialects. stablehlo. reverse (
120- weight; dimensions= collect (kernel_spatial_dims .- 1 )
121- ),
122- )
118+ weight = Reactant. Ops. reverse (weight; dimensions= kernel_spatial_dims)
123119 end
124120
125121 conv = Reactant. MLIR. Dialects. stablehlo. convolution (
126122 x. mlir_data,
127- weight;
123+ weight. mlir_data ;
128124 result_0= result_type,
129125 window_strides= collect (stride),
130126 padding,
@@ -377,4 +373,216 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
377373 return dst
378374end
379375
376+ dilate_shape (s, d) = max (0 , 1 + d * (s - 1 ))
377+
378+ # see lax._conv_general_dilated_transpose_rhs
379+ # https://github.com/jax-ml/jax/blob/a1dfdc1d6164ad49afb337da9effd269d430d68b/jax/_src/lax/convolution.py#L495
380+ function NNlib. ∇conv_filter! (
381+ dw:: Reactant.TracedRArray{T,N} ,
382+ x:: AnyTracedRArray ,
383+ dy:: AnyTracedRArray ,
384+ cdims:: NNlib.DenseConvDims ,
385+ ) where {T,N}
386+ # (w, h, cin, b)
387+ # (w, h, cout, b)
388+ # -> (w, h, cin, cout)
389+
390+ x = T .(materialize_traced_array (x))
391+ dy = T .(materialize_traced_array (dy))
392+
393+ num_spatial_dims = N - 2
394+ input_batch_dim = N - 1
395+ input_feature_dim = N
396+
397+ kernel_input_dim = N
398+ kernel_output_dim = N - 1
399+
400+ output_batch_dim = N - 1
401+ output_feature_dim = N
402+
403+ output_spatial_dims = kernel_spatial_dims = input_spatial_dims = 1 : num_spatial_dims
404+
405+ padding = reshape (collect (NNlib. padding (cdims)), (2 , num_spatial_dims))
406+ stride = NNlib. stride (cdims)
407+ dilation = NNlib. dilation (cdims)
408+ feature_group_count = NNlib. groupcount (cdims)
409+
410+ padding =
411+ let lhs_shape = first (size (x), num_spatial_dims),
412+ rhs_shape = dilate_shape .(first (size (dw), num_spatial_dims), dilation),
413+ out_shape = dilate_shape .(first (size (dy), num_spatial_dims), stride),
414+
415+ padding = reduce (
416+ hcat,
417+ (
418+ let pad_before = padding[1 , i],
419+ pad_after = (
420+ out_shape[i] - lhs_shape[i] + rhs_shape[i] - pad_before - 1
421+ )
422+
423+ [pad_before, pad_after]
424+ end for i in 1 : num_spatial_dims
425+ ),
426+ )
427+
428+ Reactant. MLIR. IR. DenseElementsAttribute (padding' )
429+ end
430+
431+ batch_group_count = 1
432+ if feature_group_count > 1
433+ batch_group_count = feature_group_count
434+ feature_group_count = 1
435+ end
436+
437+ dimension_numbers = MLIR. API. stablehloConvDimensionNumbersGet (
438+ MLIR. IR. context (),
439+ Int64 (input_batch_dim - 1 ),
440+ Int64 (input_feature_dim - 1 ),
441+ length (input_spatial_dims),
442+ Int64[i - 1 for i in input_spatial_dims],
443+ Int64 (kernel_input_dim - 1 ),
444+ Int64 (kernel_output_dim - 1 ),
445+ length (kernel_spatial_dims),
446+ Int64[i - 1 for i in kernel_spatial_dims],
447+ Int64 (output_batch_dim - 1 ),
448+ Int64 (output_feature_dim - 1 ),
449+ length (output_spatial_dims),
450+ Int64[i - 1 for i in output_spatial_dims],
451+ )
452+
453+ result_type = Reactant. MLIR. IR. TensorType (size (dw), Reactant. MLIR. IR. Type (T))
454+ conv = MLIR. Dialects. stablehlo. convolution (
455+ x. mlir_data,
456+ dy. mlir_data;
457+ result_0= result_type,
458+ window_strides= collect (dilation),
459+ padding,
460+ dimension_numbers,
461+ rhs_dilation= collect (stride),
462+ feature_group_count,
463+ batch_group_count,
464+ )
465+
466+ dw. mlir_data = MLIR. IR. result (conv)
467+
468+ if ! NNlib. flipkernel (cdims)
469+ dw. mlir_data = Reactant. Ops. reverse (dw; dimensions= output_spatial_dims). mlir_data
470+ end
471+
472+ return dw
473+ end
474+
475+ # see lax._conv_general_dilated_transpose_lhs
476+ # https://github.com/jax-ml/jax/blob/a1dfdc1d6164ad49afb337da9effd269d430d68b/jax/_src/lax/convolution.py#L457
477+ function NNlib. ∇conv_data! (
478+ dx:: Reactant.TracedRArray{T,N} ,
479+ dy:: AnyTracedRArray ,
480+ w:: AnyTracedRArray ,
481+ cdims:: NNlib.DenseConvDims ,
482+ ) where {T,N}
483+ # (w, h, cout, b)
484+ # (w, h, cin, cout)
485+ # -> (w, h, cin, b)
486+
487+ dy = T .(materialize_traced_array (dy))
488+ w = T .(materialize_traced_array (w))
489+
490+ num_spatial_dims = N - 2
491+ input_batch_dim = N
492+ input_feature_dim = N - 1
493+
494+ kernel_input_dim = N
495+ kernel_output_dim = N - 1
496+
497+ output_batch_dim = N
498+ output_feature_dim = N - 1
499+
500+ output_spatial_dims = kernel_spatial_dims = input_spatial_dims = 1 : num_spatial_dims
501+
502+ padding = reshape (collect (NNlib. padding (cdims)), (2 , num_spatial_dims))
503+ stride = NNlib. stride (cdims)
504+ dilation = NNlib. dilation (cdims)
505+ feature_group_count = NNlib. groupcount (cdims)
506+
507+ # jax does
508+ # (cout, cin, h, w) -> (group, cout ÷ group, cin , h, w) -> (cout ÷ group, group, cin, h, w) -> (cout, cin * group, h, w)
509+ # we perform the same operation but in transposed form
510+ # (w, h, cin, cout) -> (w, h, cin, cout ÷ group, group) -> (w, h, cin, group, cout ÷ group) -> (w, h, cin * group, cout ÷ group)
511+ if feature_group_count > 1
512+ w = reshape (
513+ w,
514+ (size (w, i) for i in kernel_spatial_dims). .. ,
515+ size (w, N - 1 ),
516+ size (w, N) ÷ feature_group_count,
517+ feature_group_count,
518+ )
519+ w = permutedims (w, (kernel_spatial_dims... , N - 1 , N + 1 , N))
520+ w = reshape (
521+ w,
522+ (size (w, i) for i in kernel_spatial_dims). .. ,
523+ size (w, N - 1 ) * feature_group_count,
524+ size (w, N + 1 ),
525+ )
526+ end
527+
528+ padding =
529+ let lhs_shape = first (size (dx), num_spatial_dims),
530+ rhs_shape = dilate_shape .(first (size (w), num_spatial_dims), dilation),
531+ out_shape = dilate_shape .(first (size (dy), num_spatial_dims), stride),
532+
533+ padding = reduce (
534+ hcat,
535+ (
536+ let pad_before = rhs_shape[i] - padding[2 i - 1 ] - 1 ,
537+ pad_after =
538+ lhs_shape[i] + rhs_shape[i] - 1 - out_shape[i] - pad_before
539+
540+ [pad_before, pad_after]
541+ end for i in input_spatial_dims
542+ ),
543+ )
544+
545+ Reactant. MLIR. IR. DenseElementsAttribute (padding' )
546+ end
547+
548+ dimension_numbers = MLIR. API. stablehloConvDimensionNumbersGet (
549+ MLIR. IR. context (),
550+ Int64 (input_batch_dim - 1 ),
551+ Int64 (input_feature_dim - 1 ),
552+ length (input_spatial_dims),
553+ Int64[i - 1 for i in input_spatial_dims],
554+ Int64 (kernel_input_dim - 1 ),
555+ Int64 (kernel_output_dim - 1 ),
556+ length (kernel_spatial_dims),
557+ Int64[i - 1 for i in kernel_spatial_dims],
558+ Int64 (output_batch_dim - 1 ),
559+ Int64 (output_feature_dim - 1 ),
560+ length (output_spatial_dims),
561+ Int64[i - 1 for i in output_spatial_dims],
562+ )
563+
564+ result_type = Reactant. MLIR. IR. TensorType (size (dx), Reactant. MLIR. IR. Type (T))
565+
566+ if NNlib. flipkernel (cdims)
567+ w = Reactant. Ops. reverse (w; dimensions= kernel_spatial_dims)
568+ end
569+
570+ conv = MLIR. Dialects. stablehlo. convolution (
571+ dy. mlir_data,
572+ w. mlir_data;
573+ result_0= result_type,
574+ window_strides= 1 ,
575+ padding,
576+ lhs_dilation= collect (stride),
577+ rhs_dilation= collect (dilation),
578+ dimension_numbers,
579+ feature_group_count,
580+ batch_group_count= 1 ,
581+ )
582+
583+ dx. mlir_data = MLIR. IR. result (conv)
584+
585+ return dx
586+ end
587+
380588end # module ReactantNNlibExt
0 commit comments