Skip to content

Commit 5de71d3

Browse files
authored
NNlib: handle kernel flip with reverse instead of window_reversal (#115)
* NNlib: handle kernel flip with reverse instead of window_reversal * add conv flip test
1 parent d6e5866 commit 5de71d3

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

ext/ReactantNNlibExt.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,20 @@ function NNlib.conv(
9393
)
9494
result_type = Reactant.MLIR.IR.TensorType(output_shape, Reactant.MLIR.IR.Type(T))
9595

96+
weight = W.mlir_data
97+
if !flipkernel
98+
weight = Reactant.MLIR.IR.result(
99+
Reactant.MLIR.Dialects.stablehlo.reverse(
100+
weight; dimensions=collect(kernel_spatial_dims .- 1)
101+
),
102+
)
103+
end
104+
96105
conv = Reactant.MLIR.Dialects.stablehlo.convolution(
97106
x.mlir_data,
98-
W.mlir_data;
107+
weight;
99108
result_0=result_type,
100109
window_strides=collect(stride),
101-
window_reversal=collect(fill(flipkernel, num_spatial_dims)),
102110
padding,
103111
dimension_numbers,
104112
lhs_dilation=1,

test/nn.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,22 @@ mean((out2[1, :] .> 0.5) .== truth) # accuracy 94% so far!
8888
@test res_reactant nn_conv(img)
8989
end
9090

91+
@testset "conv 1d: flip" begin
92+
x = [1; 2; 3;;;]
93+
W = [1; 2; 3;;;]
94+
95+
xx = Reactant.ConcreteRArray(x)
96+
WW = Reactant.ConcreteRArray(W)
97+
98+
conv_noflip(x, W) = NNlib.conv(x, W; pad=1, flipped=true)
99+
conv_flip(x, W) = NNlib.conv(x, W; pad=1, flipped=false)
100+
101+
@test Reactant.compile(conv_noflip, (xx, WW))(xx, WW) ==
102+
[0*1+1*2+2*3; 1*1+2*2+3*3; 1*2+2*3+3*0;;;]
103+
@test Reactant.compile(conv_flip, (xx, WW))(xx, WW) ==
104+
[3*0+2*1+1*2; 3*1+2*2+1*3; 3*2+2*3+1*0;;;]
105+
end
106+
91107
@testset "$f" for f in (NNlib.meanpool, NNlib.maxpool)
92108
img = randn(Float32, 224, 224, 3, 2)
93109
img_reactant = Reactant.ConcreteRArray(img)

0 commit comments

Comments
 (0)