@@ -2,7 +2,8 @@ using LuxLib, Reactant, Enzyme, NNlib

@testset "Fused Dense" begin end

- @testset "Bias Activation" begin end
+ @testset "Bias Activation" begin
+ end

@testset "Fast Activation" begin
    # Here we are testing that fast_activation doesn't switch to the faster versions
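Aside, not part of the diff: the comment above refers to LuxLib's `fast_activation`, which on ordinary arrays is allowed to swap an activation for a cheaper approximation (e.g. NNlib's `tanh_fast`); the elided tests assert that the Reactant-traced path keeps the exact function. A minimal sketch of the substitution being guarded against, with illustrative shapes:

```julia
using LuxLib, NNlib

x = randn(Float32, 4, 3)

# On plain CPU arrays, LuxLib may rewrite tanh into the faster,
# slightly less accurate tanh_fast variant.
y_fast  = fast_activation(tanh, x)
y_exact = tanh.(x)

# The two agree only approximately; the tests below check that the
# compiled (Reactant) path does not inherit this substitution.
@assert y_fast ≈ y_exact
```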
@@ -55,4 +56,38 @@ using LuxLib, Reactant, Enzyme, NNlib
    end
end

- @testset "Fused Conv" begin end
+ @testset "Fused Conv" begin
+     @testset for groups in (1, 2, 4),
+         has_bias in (true, false),
+         act in (identity, relu, sigmoid, tanh, gelu)
+
+         weight = randn(Float32, 4, 4, 8 ÷ groups, 4)
+         x = randn(Float32, 16, 16, 8, 2)
+         bias = has_bias ? randn(Float32, 4) : nothing
+
+         weight_reactant = Reactant.ConcreteRArray(weight)
+         x_reactant = Reactant.ConcreteRArray(x)
+         bias_reactant = Reactant.to_rarray(bias)
+
+         @testset for stride in ((1, 1), (2, 2), (3, 3)),
+             padding in ((0, 0), (1, 1), (2, 2), (0, 2), (2, 0), (0, 1), (1, 0)),
+             dilation in ((1, 1), (2, 2), (1, 2), (2, 1))
+
+             conv_dims = DenseConvDims(x, weight; stride, padding, dilation, groups)
+
+             fused_conv_compiled = Reactant.compile(
+                 fused_conv_bias_activation,
+                 (act, weight_reactant, x_reactant, bias_reactant, conv_dims),
+             )
+
+             reactant_res = fused_conv_compiled(
+                 act, weight_reactant, x_reactant, bias_reactant, conv_dims
+             )
+             luxlib_res = fused_conv_bias_activation(act, weight, x, bias, conv_dims)
+
+             @test reactant_res ≈ luxlib_res broken = (act === gelu)
+         end
+
+         # TODO: test for gradients
+     end
+ end
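On the gradient TODO: purely as a hypothetical sketch (the wrapper name, shapes, and activity annotations below are illustrative, not the PR's implementation), a CPU-side reference gradient for the fused op could be computed with Enzyme's reverse mode, with the Reactant half then compiling the same call via `Reactant.compile` and comparing the shadow buffers against this reference with `≈`:

```julia
using LuxLib, Reactant, Enzyme, NNlib

# Scalar objective wrapping the fused op, so reverse mode has one output.
function conv_loss(act, weight, x, bias, conv_dims)
    return sum(abs2, fused_conv_bias_activation(act, weight, x, bias, conv_dims))
end

weight = randn(Float32, 4, 4, 8, 4)
x = randn(Float32, 16, 16, 8, 2)
bias = randn(Float32, 4)
conv_dims = DenseConvDims(x, weight)  # defaults: unit stride, no padding, groups=1

# Shadow buffers that Enzyme fills with the gradients.
dweight = Enzyme.make_zero(weight)
dx = Enzyme.make_zero(x)
dbias = Enzyme.make_zero(bias)

Enzyme.autodiff(
    Reverse, conv_loss, Active,
    Const(relu),
    Duplicated(weight, dweight),
    Duplicated(x, dx),
    Duplicated(bias, dbias),
    Const(conv_dims),
)
```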