Skip to content

Commit 7361409

Browse files
committed
test: fused_convolution tests
1 parent 723dae1 commit 7361409

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

test/nn/luxlib.jl

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ using LuxLib, Reactant, Enzyme, NNlib
22

33
@testset "Fused Dense" begin end
44

5-
@testset "Bias Activation" begin end
5+
@testset "Bias Activation" begin
6+
end
67

78
@testset "Fast Activation" begin
89
# Here we are testing that fast_activation doesn't switch to the faster versions
@@ -55,4 +56,38 @@ using LuxLib, Reactant, Enzyme, NNlib
5556
end
5657
end
5758

58-
@testset "Fused Conv" begin end
59+
@testset "Fused Conv" begin
60+
@testset for groups in (1, 2, 4),
61+
has_bias in (true, false),
62+
act in (identity, relu, sigmoid, tanh, gelu)
63+
64+
weight = randn(Float32, 4, 4, 8 ÷ groups, 4)
65+
x = randn(Float32, 16, 16, 8, 2)
66+
bias = has_bias ? randn(Float32, 4) : nothing
67+
68+
weight_reactant = Reactant.ConcreteRArray(weight)
69+
x_reactant = Reactant.ConcreteRArray(x)
70+
bias_reactant = Reactant.to_rarray(bias)
71+
72+
@testset for stride in ((1, 1), (2, 2), (3, 3)),
73+
padding in ((0, 0), (1, 1), (2, 2), (0, 2), (2, 0), (0, 1), (1, 0)),
74+
dilation in ((1, 1), (2, 2), (1, 2), (2, 1))
75+
76+
conv_dims = DenseConvDims(x, weight; stride, padding, dilation, groups)
77+
78+
fused_conv_compiled = Reactant.compile(
79+
fused_conv_bias_activation,
80+
(act, weight_reactant, x_reactant, bias_reactant, conv_dims),
81+
)
82+
83+
reactant_res = fused_conv_compiled(
84+
act, weight_reactant, x_reactant, bias_reactant, conv_dims
85+
)
86+
luxlib_res = fused_conv_bias_activation(act, weight, x, bias, conv_dims)
87+
88+
@test reactant_res luxlib_res broken = (act === gelu)
89+
end
90+
91+
# TODO: test for gradients
92+
end
93+
end

test/nn/nnlib.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ end
5656

5757
@testset "Convolution" begin
5858
@testset for groups in (1, 2, 4)
59-
weight = randn(Float32, 4, 4, 8 ÷ groups, groups)
59+
weight = randn(Float32, 4, 4, 8 ÷ groups, 4)
6060
x = randn(Float32, 16, 16, 8, 2)
6161

6262
weight_reactant = Reactant.ConcreteRArray(weight)

0 commit comments

Comments
 (0)