@@ -37,7 +37,7 @@ using LuxLib, Reactant, Enzyme, NNlib
37
37
y_res = fused_dense_bias_activation (act, weight, x, bias)
38
38
y_compile = f_compile (act, weight_ra, x_ra, bias_ra)
39
39
40
- @test y_res ≈ y_compile broken = (act === gelu)
40
+ @test y_res ≈ y_compile
41
41
42
42
@testset " Enzyme: fused_dense_bias_activation" begin
43
43
dw, dx, db = ∇fuseddense (act, weight, x, bias)
@@ -48,9 +48,9 @@ using LuxLib, Reactant, Enzyme, NNlib
48
48
act, weight_ra, x_ra, bias_ra
49
49
)
50
50
51
- @test dw ≈ dw_compile broken = (act === gelu)
52
- @test dx ≈ dx_compile broken = (act === gelu)
53
- has_bias && @test db ≈ db_compile broken = (act === gelu)
51
+ @test dw ≈ dw_compile
52
+ @test dx ≈ dx_compile
53
+ has_bias && @test db ≈ db_compile
54
54
end
55
55
end
56
56
end
@@ -104,25 +104,25 @@ end
104
104
y_compile = f_compile (act, x_ra, b_ra)
105
105
y_compile!! = f_compile!! (act, x_ra, b_ra)
106
106
107
- @test y_simple ≈ y_compile broken = (act === gelu)
108
- @test y_simple!! ≈ y_compile!! broken = (act === gelu)
107
+ @test y_simple ≈ y_compile
108
+ @test y_simple!! ≈ y_compile!!
109
109
110
110
@testset " Enzyme: bias_activation" begin
111
111
∂x_enz, ∂b_enz = ∇biasact (act, x, b)
112
112
∇biasact_compiled = Reactant. compile (∇biasact, (act, x_ra, b_ra))
113
113
∂x_compile, ∂b_compile = ∇biasact_compiled (act, x_ra, b_ra)
114
114
115
- @test ∂x_enz ≈ ∂x_compile broken = (act === gelu)
116
- @test ∂b_enz ≈ ∂b_compile broken = (act === gelu)
115
+ @test ∂x_enz ≈ ∂x_compile
116
+ @test ∂b_enz ≈ ∂b_compile
117
117
end
118
118
119
119
@testset " Enzyme: bias_activation!!" begin
120
120
∂x_enz!!, ∂b_enz!! = ∇biasact!! (act, x, b)
121
121
∇biasact!!_compiled = Reactant. compile (∇biasact!!, (act, x_ra, b_ra))
122
122
∂x_compile!!, ∂b_compile!! = ∇biasact!!_compiled (act, x_ra, b_ra)
123
123
124
- @test ∂x_enz!! ≈ ∂x_compile!! broken = (act === gelu)
125
- @test ∂b_enz!! ≈ ∂b_compile!! broken = (act === gelu)
124
+ @test ∂x_enz!! ≈ ∂x_compile!!
125
+ @test ∂b_enz!! ≈ ∂b_compile!!
126
126
end
127
127
end
128
128
end
158
158
y_compile = f_compile (act, x_act_ca)
159
159
y_compile!! = f_compile!! (act, x_act_ca)
160
160
161
- @test y_simple ≈ y_compile broken = (act === gelu)
162
- @test y_simple!! ≈ y_compile!! broken = (act === gelu)
161
+ @test y_simple ≈ y_compile
162
+ @test y_simple!! ≈ y_compile!!
163
163
164
164
∂x_enz = Enzyme. make_zero (x_act)
165
165
Enzyme. autodiff (Reverse, sumabs2, Active, Const (act), Duplicated (x_act, ∂x_enz))
173
173
∇sumabs2!!_compiled = Reactant. compile (∇sumabs2!!, (act, x_act_ca))
174
174
∂x_compile!! = ∇sumabs2!!_compiled (act, x_act_ca)
175
175
176
- @test ∂x_enz ≈ ∂x_compile broken = (act === gelu)
177
- @test ∂x_enz!! ≈ ∂x_compile!! broken = (act === gelu)
176
+ @test ∂x_enz ≈ ∂x_compile
177
+ @test ∂x_enz!! ≈ ∂x_compile!!
178
178
end
179
179
end
180
180
207
207
)
208
208
luxlib_res = fused_conv_bias_activation (act, weight, x, bias, conv_dims)
209
209
210
- @test reactant_res ≈ luxlib_res broken = (act === gelu)
210
+ @test reactant_res ≈ luxlib_res
211
211
end
212
212
213
213
# TODO : test for gradients
0 commit comments