Skip to content

Commit 29aff8e

Browse files
committed
test: conv and pool with more parameters
1 parent ac3c512 commit 29aff8e

File tree

3 files changed

+58
-55
lines changed

3 files changed

+58
-55
lines changed

test/nn/flux.jl

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -67,50 +67,3 @@ optim # parameters, momenta and output have all changed
6767
out2 = model(noisy) # first row is prob. of true, second row p(false)
6868

6969
mean((out2[1, :] .> 0.5) .== truth) # accuracy 94% so far!
70-
71-
@testset "conv: groups $groups" for groups in (1, 2, 4)
72-
nn_conv = Conv(
73-
randn(Float32, 10, 10, 8 ÷ groups, groups), randn(Float32, groups); groups
74-
)
75-
conv_reactant = Conv(
76-
Reactant.ConcreteRArray(nn_conv.weight),
77-
Reactant.ConcreteRArray(nn_conv.bias);
78-
groups,
79-
)
80-
81-
img = randn(Float32, 224, 224, 8, 2)
82-
img_reactant = Reactant.ConcreteRArray(img)
83-
84-
comp_conv = @compile conv_reactant(img_reactant)
85-
86-
res_reactant = Array{Float32,4}(comp_conv(img_reactant))
87-
88-
@test res_reactant nn_conv(img)
89-
end
90-
91-
@testset "conv 1d: flip" begin
92-
x = [1; 2; 3;;;]
93-
W = [1; 2; 3;;;]
94-
95-
xx = Reactant.ConcreteRArray(x)
96-
WW = Reactant.ConcreteRArray(W)
97-
98-
conv_noflip(x, W) = NNlib.conv(x, W; pad=1, flipped=true)
99-
conv_flip(x, W) = NNlib.conv(x, W; pad=1, flipped=false)
100-
101-
@test @compile(conv_noflip(xx, WW))(xx, WW) ==
102-
[0*1+1*2+2*3; 1*1+2*2+3*3; 1*2+2*3+3*0;;;]
103-
@test @compile(conv_flip(xx, WW))(xx, WW) == [3*0+2*1+1*2; 3*1+2*2+1*3; 3*2+2*3+1*0;;;]
104-
end
105-
106-
@testset "$f" for f in (NNlib.meanpool, NNlib.maxpool)
107-
img = randn(Float32, 224, 224, 3, 2)
108-
img_reactant = Reactant.ConcreteRArray(img)
109-
110-
f_reactant = @compile f(img_reactant, (3, 3))
111-
112-
res_reactant = f_reactant(img_reactant, (3, 3))
113-
res = f(img, (3, 3))
114-
115-
@test res_reactant res
116-
end

test/nn/luxlib.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
using LuxLib, Reactant, Enzyme
22

3-
@testset "Fused Dense" begin
4-
end
3+
@testset "Fused Dense" begin end
54

6-
@testset "Bias Activation" begin
7-
end
5+
@testset "Bias Activation" begin end
86

9-
@testset "Fast Activation" begin
10-
end
7+
@testset "Fast Activation" begin end
118

12-
@testset "Fused Conv" begin
13-
end
9+
@testset "Fused Conv" begin end

test/nn/nnlib.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,61 @@ using NNlib, Reactant, Enzyme
3232
end
3333

3434
@testset "Pooling" begin
35+
@testset for f in (NNlib.meanpool, NNlib.maxpool)
36+
x = randn(Float32, 32, 32, 3, 2)
37+
x_reactant = Reactant.ConcreteRArray(x)
38+
39+
@testset for window in ((2, 2), (3, 3), (4, 4)),
40+
stride in ((1, 1), (2, 2)),
41+
padding in ((0, 0), (1, 1), (2, 2), (0, 2), (2, 0))
42+
43+
pool_dims = PoolDims(x, window; stride, padding)
44+
45+
f_reactant = Reactant.compile(f, (x_reactant, pool_dims))
46+
47+
broken = any(==(2), padding) && f === NNlib.maxpool && window == (2, 2)
48+
@test f_reactant(x_reactant, pool_dims) f(x, pool_dims) broken = broken
49+
end
50+
51+
# TODO: test for gradients
52+
end
3553
end
3654

3755
@testset "Convolution" begin
56+
@testset for groups in (1, 2, 4)
57+
weight = randn(Float32, 10, 10, 8 ÷ groups, groups)
58+
x = randn(Float32, 16, 16, 8, 2)
59+
60+
weight_reactant = Reactant.ConcreteRArray(weight)
61+
x_reactant = Reactant.ConcreteRArray(x)
62+
63+
@testset for stride in ((1, 1), (2, 2), (3, 3)), padding in ((0, 0), (1, 1), (2, 2))
64+
conv_dims = DenseConvDims(x, weight; stride, padding, dilation=1, groups)
65+
66+
conv_compiled = Reactant.compile(
67+
NNlib.conv, (x_reactant, weight_reactant, conv_dims)
68+
)
69+
70+
@test conv_compiled(x_reactant, weight_reactant, conv_dims)
71+
NNlib.conv(x, weight, conv_dims)
72+
end
73+
74+
# TODO: test for gradients
75+
end
76+
77+
@testset "conv 1d: flip" begin
78+
x = [1; 2; 3;;;]
79+
W = [1; 2; 3;;;]
80+
81+
xx = Reactant.ConcreteRArray(x)
82+
WW = Reactant.ConcreteRArray(W)
83+
84+
conv_noflip(x, W) = NNlib.conv(x, W; pad=1, flipped=true)
85+
conv_flip(x, W) = NNlib.conv(x, W; pad=1, flipped=false)
86+
87+
@test Reactant.compile(conv_noflip, (xx, WW))(xx, WW) ==
88+
[0*1+1*2+2*3; 1*1+2*2+3*3; 1*2+2*3+3*0;;;]
89+
@test Reactant.compile(conv_flip, (xx, WW))(xx, WW) ==
90+
[3*0+2*1+1*2; 3*1+2*2+1*3; 3*2+2*3+1*0;;;]
91+
end
3892
end

0 commit comments

Comments
 (0)