Skip to content

Commit c9fd63c

Browse files
Pangorawglou-nes
authored andcommitted
Implement NNlib.∇conv_data! and NNlib.∇conv_filter! (EnzymeAD#318)
* implement NNlib.∇conv_data and NNlib.∇conv_filter * cond filter flipkernel
1 parent 1be04f7 commit c9fd63c

File tree

3 files changed

+260
-10
lines changed

3 files changed

+260
-10
lines changed

ext/ReactantNNlibExt.jl

Lines changed: 215 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,18 +113,14 @@ function NNlib.conv!(
113113
)
114114
result_type = Reactant.MLIR.IR.TensorType(size(y), Reactant.MLIR.IR.Type(T))
115115

116-
weight = W.mlir_data
116+
weight = W
117117
if !flipkernel
118-
weight = Reactant.MLIR.IR.result(
119-
Reactant.MLIR.Dialects.stablehlo.reverse(
120-
weight; dimensions=collect(kernel_spatial_dims .- 1)
121-
),
122-
)
118+
weight = Reactant.Ops.reverse(weight; dimensions=kernel_spatial_dims)
123119
end
124120

125121
conv = Reactant.MLIR.Dialects.stablehlo.convolution(
126122
x.mlir_data,
127-
weight;
123+
weight.mlir_data;
128124
result_0=result_type,
129125
window_strides=collect(stride),
130126
padding,
@@ -377,4 +373,216 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
377373
return dst
378374
end
379375

376+
dilate_shape(s, d) = max(0, 1 + d * (s - 1))
377+
378+
# see lax._conv_general_dilated_transpose_rhs
379+
# https://github.com/jax-ml/jax/blob/a1dfdc1d6164ad49afb337da9effd269d430d68b/jax/_src/lax/convolution.py#L495
380+
function NNlib.∇conv_filter!(
381+
dw::Reactant.TracedRArray{T,N},
382+
x::AnyTracedRArray,
383+
dy::AnyTracedRArray,
384+
cdims::NNlib.DenseConvDims,
385+
) where {T,N}
386+
# (w, h, cin, b)
387+
# (w, h, cout, b)
388+
# -> (w, h, cin, cout)
389+
390+
x = T.(materialize_traced_array(x))
391+
dy = T.(materialize_traced_array(dy))
392+
393+
num_spatial_dims = N - 2
394+
input_batch_dim = N - 1
395+
input_feature_dim = N
396+
397+
kernel_input_dim = N
398+
kernel_output_dim = N - 1
399+
400+
output_batch_dim = N - 1
401+
output_feature_dim = N
402+
403+
output_spatial_dims = kernel_spatial_dims = input_spatial_dims = 1:num_spatial_dims
404+
405+
padding = reshape(collect(NNlib.padding(cdims)), (2, num_spatial_dims))
406+
stride = NNlib.stride(cdims)
407+
dilation = NNlib.dilation(cdims)
408+
feature_group_count = NNlib.groupcount(cdims)
409+
410+
padding =
411+
let lhs_shape = first(size(x), num_spatial_dims),
412+
rhs_shape = dilate_shape.(first(size(dw), num_spatial_dims), dilation),
413+
out_shape = dilate_shape.(first(size(dy), num_spatial_dims), stride),
414+
415+
padding = reduce(
416+
hcat,
417+
(
418+
let pad_before = padding[1, i],
419+
pad_after = (
420+
out_shape[i] - lhs_shape[i] + rhs_shape[i] - pad_before - 1
421+
)
422+
423+
[pad_before, pad_after]
424+
end for i in 1:num_spatial_dims
425+
),
426+
)
427+
428+
Reactant.MLIR.IR.DenseElementsAttribute(padding')
429+
end
430+
431+
batch_group_count = 1
432+
if feature_group_count > 1
433+
batch_group_count = feature_group_count
434+
feature_group_count = 1
435+
end
436+
437+
dimension_numbers = MLIR.API.stablehloConvDimensionNumbersGet(
438+
MLIR.IR.context(),
439+
Int64(input_batch_dim - 1),
440+
Int64(input_feature_dim - 1),
441+
length(input_spatial_dims),
442+
Int64[i - 1 for i in input_spatial_dims],
443+
Int64(kernel_input_dim - 1),
444+
Int64(kernel_output_dim - 1),
445+
length(kernel_spatial_dims),
446+
Int64[i - 1 for i in kernel_spatial_dims],
447+
Int64(output_batch_dim - 1),
448+
Int64(output_feature_dim - 1),
449+
length(output_spatial_dims),
450+
Int64[i - 1 for i in output_spatial_dims],
451+
)
452+
453+
result_type = Reactant.MLIR.IR.TensorType(size(dw), Reactant.MLIR.IR.Type(T))
454+
conv = MLIR.Dialects.stablehlo.convolution(
455+
x.mlir_data,
456+
dy.mlir_data;
457+
result_0=result_type,
458+
window_strides=collect(dilation),
459+
padding,
460+
dimension_numbers,
461+
rhs_dilation=collect(stride),
462+
feature_group_count,
463+
batch_group_count,
464+
)
465+
466+
dw.mlir_data = MLIR.IR.result(conv)
467+
468+
if !NNlib.flipkernel(cdims)
469+
dw.mlir_data = Reactant.Ops.reverse(dw; dimensions=output_spatial_dims).mlir_data
470+
end
471+
472+
return dw
473+
end
474+
475+
# see lax._conv_general_dilated_transpose_lhs
476+
# https://github.com/jax-ml/jax/blob/a1dfdc1d6164ad49afb337da9effd269d430d68b/jax/_src/lax/convolution.py#L457
477+
function NNlib.∇conv_data!(
478+
dx::Reactant.TracedRArray{T,N},
479+
dy::AnyTracedRArray,
480+
w::AnyTracedRArray,
481+
cdims::NNlib.DenseConvDims,
482+
) where {T,N}
483+
# (w, h, cout, b)
484+
# (w, h, cin, cout)
485+
# -> (w, h, cin, b)
486+
487+
dy = T.(materialize_traced_array(dy))
488+
w = T.(materialize_traced_array(w))
489+
490+
num_spatial_dims = N - 2
491+
input_batch_dim = N
492+
input_feature_dim = N - 1
493+
494+
kernel_input_dim = N
495+
kernel_output_dim = N - 1
496+
497+
output_batch_dim = N
498+
output_feature_dim = N - 1
499+
500+
output_spatial_dims = kernel_spatial_dims = input_spatial_dims = 1:num_spatial_dims
501+
502+
padding = reshape(collect(NNlib.padding(cdims)), (2, num_spatial_dims))
503+
stride = NNlib.stride(cdims)
504+
dilation = NNlib.dilation(cdims)
505+
feature_group_count = NNlib.groupcount(cdims)
506+
507+
# jax does
508+
# (cout, cin, h, w) -> (group, cout ÷ group, cin , h, w) -> (cout ÷ group, group, cin, h, w) -> (cout, cin * group, h, w)
509+
# we perform the same operation but in transposed form
510+
# (w, h, cin, cout) -> (w, h, cin, cout ÷ group, group) -> (w, h, cin, group, cout ÷ group) -> (w, h, cin * group, cout ÷ group)
511+
if feature_group_count > 1
512+
w = reshape(
513+
w,
514+
(size(w, i) for i in kernel_spatial_dims)...,
515+
size(w, N - 1),
516+
size(w, N) ÷ feature_group_count,
517+
feature_group_count,
518+
)
519+
w = permutedims(w, (kernel_spatial_dims..., N - 1, N + 1, N))
520+
w = reshape(
521+
w,
522+
(size(w, i) for i in kernel_spatial_dims)...,
523+
size(w, N - 1) * feature_group_count,
524+
size(w, N + 1),
525+
)
526+
end
527+
528+
padding =
529+
let lhs_shape = first(size(dx), num_spatial_dims),
530+
rhs_shape = dilate_shape.(first(size(w), num_spatial_dims), dilation),
531+
out_shape = dilate_shape.(first(size(dy), num_spatial_dims), stride),
532+
533+
padding = reduce(
534+
hcat,
535+
(
536+
let pad_before = rhs_shape[i] - padding[2i - 1] - 1,
537+
pad_after =
538+
lhs_shape[i] + rhs_shape[i] - 1 - out_shape[i] - pad_before
539+
540+
[pad_before, pad_after]
541+
end for i in input_spatial_dims
542+
),
543+
)
544+
545+
Reactant.MLIR.IR.DenseElementsAttribute(padding')
546+
end
547+
548+
dimension_numbers = MLIR.API.stablehloConvDimensionNumbersGet(
549+
MLIR.IR.context(),
550+
Int64(input_batch_dim - 1),
551+
Int64(input_feature_dim - 1),
552+
length(input_spatial_dims),
553+
Int64[i - 1 for i in input_spatial_dims],
554+
Int64(kernel_input_dim - 1),
555+
Int64(kernel_output_dim - 1),
556+
length(kernel_spatial_dims),
557+
Int64[i - 1 for i in kernel_spatial_dims],
558+
Int64(output_batch_dim - 1),
559+
Int64(output_feature_dim - 1),
560+
length(output_spatial_dims),
561+
Int64[i - 1 for i in output_spatial_dims],
562+
)
563+
564+
result_type = Reactant.MLIR.IR.TensorType(size(dx), Reactant.MLIR.IR.Type(T))
565+
566+
if NNlib.flipkernel(cdims)
567+
w = Reactant.Ops.reverse(w; dimensions=kernel_spatial_dims)
568+
end
569+
570+
conv = MLIR.Dialects.stablehlo.convolution(
571+
dy.mlir_data,
572+
w.mlir_data;
573+
result_0=result_type,
574+
window_strides=1,
575+
padding,
576+
lhs_dilation=collect(stride),
577+
rhs_dilation=collect(dilation),
578+
dimension_numbers,
579+
feature_group_count,
580+
batch_group_count=1,
581+
)
582+
583+
dx.mlir_data = MLIR.IR.result(conv)
584+
585+
return dx
586+
end
587+
380588
end # module ReactantNNlibExt

src/Ops.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -815,7 +815,7 @@ function reverse(
815815
stablehlo.reverse(
816816
x.mlir_data;
817817
result=mlir_type(TracedRArray{T,N}, size(x)),
818-
dimensions=MLIR.IR.DenseArrayAttribute(dimensions .- 1),
818+
dimensions=MLIR.IR.DenseArrayAttribute(collect(dimensions .- 1)),
819819
location,
820820
),
821821
)

test/nn/nnlib.jl

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,26 @@ end
6767

6868
conv_dims = DenseConvDims(x, weight; stride, padding, dilation, groups)
6969

70+
output_size = (
71+
NNlib.output_size(conv_dims)...,
72+
size(weight, ndims(weight)),
73+
size(x, ndims(x)),
74+
)
75+
dy = randn(Float32, output_size)
76+
dy_reactant = Reactant.to_rarray(dy)
77+
7078
conv_compiled = Reactant.compile(
7179
NNlib.conv, (x_reactant, weight_reactant, conv_dims)
7280
)
7381

7482
@test conv_compiled(x_reactant, weight_reactant, conv_dims)
7583
NNlib.conv(x, weight, conv_dims)
76-
end
7784

78-
# TODO: test for gradients
85+
@test Reactant.@jit(NNlib.∇conv_data(dy_reactant, weight_reactant, conv_dims))
86+
NNlib.∇conv_data(dy, weight, conv_dims)
87+
@test Reactant.@jit(NNlib.∇conv_filter(x_reactant, dy_reactant, conv_dims))
88+
NNlib.∇conv_filter(x, dy, conv_dims)
89+
end
7990
end
8091

8192
@testset "conv 1d: flip" begin
@@ -351,3 +362,34 @@ end
351362
@test size(y) == (size(src)[1:(Nsrc - M)]..., size(index)...)
352363
end
353364
end
365+
366+
@testset "∇conv(D = $ndim)" for ndim in 1:3
367+
x_spatial_dim = 4
368+
batch_size = 2
369+
n_in_features = 3
370+
n_out_features = 4
371+
kernel_size = Tuple((2 for _ in 1:ndim))
372+
373+
x = randn(Float32, (x_spatial_dim for _ in 1:ndim)..., n_in_features, batch_size)
374+
x_reactant = Reactant.to_rarray(x)
375+
376+
w = randn(Float32, kernel_size..., n_in_features, n_out_features)
377+
w_reactant = Reactant.to_rarray(w)
378+
379+
@testset "conv: padding=$padding stride=$stride dilation=$dilation groups=$groups" for (
380+
padding, stride, dilation, groups
381+
) in Iterators.product(
382+
(0, 2), (1, 2), (1,), (1,)
383+
)
384+
conv_dims = NNlib.DenseConvDims(x, w; padding, stride, dilation, groups)
385+
386+
output_size = (NNlib.output_size(conv_dims)..., n_out_features, batch_size)
387+
dy = randn(Float32, output_size)
388+
dy_reactant = Reactant.to_rarray(dy)
389+
390+
@test Reactant.@jit(NNlib.∇conv_data(dy_reactant, w_reactant, conv_dims))
391+
NNlib.∇conv_data(dy, w, conv_dims)
392+
@test Reactant.@jit(NNlib.∇conv_filter(x_reactant, dy_reactant, conv_dims))
393+
NNlib.∇conv_filter(x, dy, conv_dims)
394+
end
395+
end

0 commit comments

Comments
 (0)