Skip to content

Commit 6abff33

Browse files
committed
test: add tolerances
1 parent d5c5cf1 commit 6abff33

File tree

2 files changed

+18
-20
lines changed

2 files changed

+18
-20
lines changed

test/nn/lux.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
using Reactant, Lux, Random, Statistics, Enzyme, Functors, OneHotArrays
22

3-
# Lux.Exprimental.TrainState is very specialized for Lux models, so we write out the
4-
# training loop manually:
53
function crossentropy(ŷ, y)
64
logŷ = log.(ŷ)
75
result = y .* logŷ
@@ -10,7 +8,7 @@ end
108

119
function loss_function(model, x, y, ps, st)
1210
y_hat, _ = model(x, ps, st)
13-
# return CrossEntropyLoss()(y_hat, y)
11+
# return CrossEntropyLoss()(y_hat, y) # <-- needs handling of xlogx xlogy from LuxOps
1412
return crossentropy(y_hat, y)
1513
end
1614

@@ -70,8 +68,8 @@ end
7068

7169
res_reactant, dps_reactant = compiled_gradient(cmodel, cnoisy, ctarget, cps, cst2)
7270

73-
@test res res_reactant
71+
@test res res_reactant atol = 1e-5 rtol = 1e-2
7472
for (dps1, dps2) in zip(fleaves(dps), fleaves(dps_reactant))
75-
@test dps1 dps2
73+
@test dps1 dps2 atol = 1e-5 rtol = 1e-2
7674
end
7775
end

test/nn/luxlib.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ using LuxLib, Reactant, Enzyme, NNlib
3737
y_res = fused_dense_bias_activation(act, weight, x, bias)
3838
y_compile = f_compile(act, weight_ra, x_ra, bias_ra)
3939

40-
@test y_res y_compile
40+
@test y_res y_compile atol = 1e-5 rtol = 1e-2
4141

4242
@testset "Enzyme: fused_dense_bias_activation" begin
4343
dw, dx, db = ∇fuseddense(act, weight, x, bias)
@@ -48,9 +48,9 @@ using LuxLib, Reactant, Enzyme, NNlib
4848
act, weight_ra, x_ra, bias_ra
4949
)
5050

51-
@test dw dw_compile
52-
@test dx dx_compile
53-
has_bias && @test db db_compile
51+
@test dw dw_compile atol = 1e-5 rtol = 1e-2
52+
@test dx dx_compile atol = 1e-5 rtol = 1e-2
53+
has_bias && @test db db_compile atol = 1e-5 rtol = 1e-2
5454
end
5555
end
5656
end
@@ -104,25 +104,25 @@ end
104104
y_compile = f_compile(act, x_ra, b_ra)
105105
y_compile!! = f_compile!!(act, x_ra, b_ra)
106106

107-
@test y_simple y_compile
108-
@test y_simple!! y_compile!!
107+
@test y_simple y_compile atol = 1e-5 rtol = 1e-2
108+
@test y_simple!! y_compile!! atol = 1e-5 rtol = 1e-2
109109

110110
@testset "Enzyme: bias_activation" begin
111111
∂x_enz, ∂b_enz = ∇biasact(act, x, b)
112112
∇biasact_compiled = Reactant.compile(∇biasact, (act, x_ra, b_ra))
113113
∂x_compile, ∂b_compile = ∇biasact_compiled(act, x_ra, b_ra)
114114

115-
@test ∂x_enz ∂x_compile
116-
@test ∂b_enz ∂b_compile
115+
@test ∂x_enz ∂x_compile atol = 1e-5 rtol = 1e-2
116+
@test ∂b_enz ∂b_compile atol = 1e-5 rtol = 1e-2
117117
end
118118

119119
@testset "Enzyme: bias_activation!!" begin
120120
∂x_enz!!, ∂b_enz!! = ∇biasact!!(act, x, b)
121121
∇biasact!!_compiled = Reactant.compile(∇biasact!!, (act, x_ra, b_ra))
122122
∂x_compile!!, ∂b_compile!! = ∇biasact!!_compiled(act, x_ra, b_ra)
123123

124-
@test ∂x_enz!! ∂x_compile!!
125-
@test ∂b_enz!! ∂b_compile!!
124+
@test ∂x_enz!! ∂x_compile!! atol = 1e-5 rtol = 1e-2
125+
@test ∂b_enz!! ∂b_compile!! atol = 1e-5 rtol = 1e-2
126126
end
127127
end
128128
end
@@ -158,8 +158,8 @@ end
158158
y_compile = f_compile(act, x_act_ca)
159159
y_compile!! = f_compile!!(act, x_act_ca)
160160

161-
@test y_simple y_compile
162-
@test y_simple!! y_compile!!
161+
@test y_simple y_compile atol = 1e-5 rtol = 1e-2
162+
@test y_simple!! y_compile!! atol = 1e-5 rtol = 1e-2
163163

164164
∂x_enz = Enzyme.make_zero(x_act)
165165
Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz))
@@ -173,8 +173,8 @@ end
173173
∇sumabs2!!_compiled = Reactant.compile(∇sumabs2!!, (act, x_act_ca))
174174
∂x_compile!! = ∇sumabs2!!_compiled(act, x_act_ca)
175175

176-
@test ∂x_enz ∂x_compile
177-
@test ∂x_enz!! ∂x_compile!!
176+
@test ∂x_enz ∂x_compile atol = 1e-5 rtol = 1e-2
177+
@test ∂x_enz!! ∂x_compile!! atol = 1e-5 rtol = 1e-2
178178
end
179179
end
180180

@@ -207,7 +207,7 @@ end
207207
)
208208
luxlib_res = fused_conv_bias_activation(act, weight, x, bias, conv_dims)
209209

210-
@test reactant_res luxlib_res
210+
@test reactant_res luxlib_res atol = 1e-5 rtol = 1e-2
211211
end
212212

213213
# TODO: test for gradients

0 commit comments

Comments
 (0)