@@ -3,6 +3,76 @@ using LuxLib, Reactant, Enzyme, NNlib
3
3
@testset " Fused Dense" begin end
4
4
5
5
@testset " Bias Activation" begin
6
+ biasact (act, x, b) = bias_activation (act, x, b)
7
+ sumabs2biasact (act, x, b) = sum (abs2, biasact (act, x, b))
8
+ biasact!! (act, x, b) = bias_activation!! (act, copy (x), b)
9
+ sumabs2biasact!! (act, x, b) = sum (abs2, biasact!! (act, x, b))
10
+
11
+ function ∇biasact (act, x, b)
12
+ dx = Enzyme. make_zero (x)
13
+ db = Enzyme. make_zero (b)
14
+ Enzyme. autodiff (
15
+ Reverse,
16
+ sumabs2biasact,
17
+ Active,
18
+ Const (act),
19
+ Duplicated (x, dx),
20
+ Duplicated (b, db),
21
+ )
22
+ return dx, db
23
+ end
24
+
25
+ function ∇biasact!! (act, x, b)
26
+ dx = Enzyme. make_zero (x)
27
+ db = Enzyme. make_zero (b)
28
+ Enzyme. autodiff (
29
+ Reverse,
30
+ sumabs2biasact!!,
31
+ Active,
32
+ Const (act),
33
+ Duplicated (x, dx),
34
+ Duplicated (b, db),
35
+ )
36
+ return dx, db
37
+ end
38
+
39
+ @testset for act in (identity, relu, sigmoid, tanh, gelu)
40
+ x = randn (Float32, 10 , 10 )
41
+ b = randn (Float32, 10 )
42
+
43
+ x_ra = Reactant. ConcreteRArray (x)
44
+ b_ra = Reactant. ConcreteRArray (b)
45
+
46
+ f_compile = Reactant. compile (biasact, (act, x_ra, b_ra))
47
+ f_compile!! = Reactant. compile (biasact!!, (act, x_ra, b_ra))
48
+
49
+ y_simple = biasact (act, x, b)
50
+ y_simple!! = biasact!! (act, x, b)
51
+ y_compile = f_compile (act, x_ra, b_ra)
52
+ y_compile!! = f_compile!! (act, x_ra, b_ra)
53
+
54
+ @test y_simple ≈ y_compile broken = (act === gelu)
55
+ @test y_simple!! ≈ y_compile!! broken = (act === gelu)
56
+
57
+ # FIXME : Seems broken currently
58
+ @testset " Enzyme: bias_activation" begin
59
+ ∂x_enz, ∂b_enz = ∇biasact (act, x, b)
60
+ ∇biasact_compiled = Reactant. compile (∇biasact, (act, x_ra, b_ra))
61
+ ∂x_compile, ∂b_compile = ∇biasact_compiled (act, x_ra, b_ra)
62
+
63
+ @test ∂x_enz ≈ ∂x_compile broken = (act === gelu)
64
+ @test ∂b_enz ≈ ∂b_compile broken = (act === gelu)
65
+ end
66
+
67
+ @testset " Enzyme: bias_activation!!" begin
68
+ ∂x_enz!!, ∂b_enz!! = ∇biasact!! (act, x, b)
69
+ ∇biasact!!_compiled = Reactant. compile (∇biasact!!, (act, x_ra, b_ra))
70
+ ∂x_compile!!, ∂b_compile!! = ∇biasact!!_compiled (act, x_ra, b_ra)
71
+
72
+ @test ∂x_enz!! ≈ ∂x_compile!! broken = (act === gelu)
73
+ @test ∂b_enz!! ≈ ∂b_compile!! broken = (act === gelu)
74
+ end
75
+ end
6
76
end
7
77
8
78
@testset " Fast Activation" begin
18
88
19
89
function ∇sumabs2!! (f, x)
20
90
dx = Enzyme. make_zero (x)
21
- Enzyme. autodiff (Reverse, sumabs2, Active, Const (f), Duplicated (x, dx))
91
+ Enzyme. autodiff (Reverse, sumabs2!! , Active, Const (f), Duplicated (x, dx))
22
92
return dx
23
93
end
24
94
28
98
@testset " Activation: $act " for act in (
29
99
identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2
30
100
)
31
- f_compile = Reactant. compile (sumabs2, (act, x_act ))
32
- f_compile!! = Reactant. compile (sumabs2!!, (act, x_act ))
101
+ f_compile = Reactant. compile (sumabs2, (act, x_act_ca ))
102
+ f_compile!! = Reactant. compile (sumabs2!!, (act, x_act_ca ))
33
103
34
104
y_simple = sumabs2 (act, x_act)
35
105
y_simple!! = sumabs2!! (act, x_act)
0 commit comments