Skip to content

Commit 4c3cf10

Browse files
committed
test: with highest precision for accuracy
1 parent aad1cfe commit 4c3cf10

File tree

2 files changed

+71
-21
lines changed

2 files changed

+71
-21
lines changed

test/nn/flux.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,5 @@ using Reactant, Flux
2121
f = Reactant.compile((a, b) -> a(b), (cmodel, cnoisy))
2222

2323
comp = f(cmodel, cnoisy)
24-
@test origout comp atol = 1e-5 rtol = 1e-3
24+
@test origout comp atol = 1e-3 rtol = 1e-2
2525
end

test/nn/luxlib.jl

Lines changed: 70 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,12 @@ using LuxLib, Reactant, Enzyme, NNlib
3030
x_ra = Reactant.to_rarray(x)
3131
bias_ra = Reactant.to_rarray(bias)
3232

33-
f_compile = Reactant.compile(
34-
fused_dense_bias_activation, (act, weight_ra, x_ra, bias_ra)
35-
)
33+
f_compile = Reactant.with_config(;
34+
dot_general_precision=PrecisionConfig.HIGHEST,
35+
convolution_precision=PrecisionConfig.HIGHEST,
36+
) do
37+
@compile fused_dense_bias_activation(act, weight_ra, x_ra, bias_ra)
38+
end
3639

3740
y_res = fused_dense_bias_activation(act, weight, x, bias)
3841
y_compile = f_compile(act, weight_ra, x_ra, bias_ra)
@@ -41,9 +44,13 @@ using LuxLib, Reactant, Enzyme, NNlib
4144

4245
@testset "Enzyme: fused_dense_bias_activation" begin
4346
dw, dx, db = ∇fuseddense(act, weight, x, bias)
44-
∇fuseddense_compiled = Reactant.compile(
45-
∇fuseddense, (act, weight_ra, x_ra, bias_ra)
46-
)
47+
48+
∇fuseddense_compiled = Reactant.with_config(;
49+
dot_general_precision=PrecisionConfig.HIGHEST,
50+
convolution_precision=PrecisionConfig.HIGHEST,
51+
) do
52+
@compile ∇fuseddense(act, weight_ra, x_ra, bias_ra)
53+
end
4754
dw_compile, dx_compile, db_compile = ∇fuseddense_compiled(
4855
act, weight_ra, x_ra, bias_ra
4956
)
@@ -96,8 +103,19 @@ end
96103
x_ra = Reactant.to_rarray(x)
97104
b_ra = Reactant.to_rarray(b)
98105

99-
f_compile = Reactant.compile(biasact, (act, x_ra, b_ra))
100-
f_compile!! = Reactant.compile(biasact!!, (act, x_ra, b_ra))
106+
f_compile = Reactant.with_config(;
107+
dot_general_precision=PrecisionConfig.HIGHEST,
108+
convolution_precision=PrecisionConfig.HIGHEST,
109+
) do
110+
@compile biasact(act, x_ra, b_ra)
111+
end
112+
113+
f_compile!! = Reactant.with_config(;
114+
dot_general_precision=PrecisionConfig.HIGHEST,
115+
convolution_precision=PrecisionConfig.HIGHEST,
116+
) do
117+
@compile biasact!!(act, x_ra, b_ra)
118+
end
101119

102120
y_simple = biasact(act, x, b)
103121
y_simple!! = biasact!!(act, x, b)
@@ -109,7 +127,12 @@ end
109127

110128
@testset "Enzyme: bias_activation" begin
111129
∂x_enz, ∂b_enz = ∇biasact(act, x, b)
112-
∇biasact_compiled = Reactant.compile(∇biasact, (act, x_ra, b_ra))
130+
∇biasact_compiled = Reactant.with_config(;
131+
dot_general_precision=PrecisionConfig.HIGHEST,
132+
convolution_precision=PrecisionConfig.HIGHEST,
133+
) do
134+
@compile ∇biasact(act, x_ra, b_ra)
135+
end
113136
∂x_compile, ∂b_compile = ∇biasact_compiled(act, x_ra, b_ra)
114137

115138
@test ∂x_enz ∂x_compile atol = 1e-5 rtol = 1e-2
@@ -118,7 +141,12 @@ end
118141

119142
@testset "Enzyme: bias_activation!!" begin
120143
∂x_enz!!, ∂b_enz!! = ∇biasact!!(act, x, b)
121-
∇biasact!!_compiled = Reactant.compile(∇biasact!!, (act, x_ra, b_ra))
144+
∇biasact!!_compiled = Reactant.with_config(;
145+
dot_general_precision=PrecisionConfig.HIGHEST,
146+
convolution_precision=PrecisionConfig.HIGHEST,
147+
) do
148+
@compile ∇biasact!!(act, x_ra, b_ra)
149+
end
122150
∂x_compile!!, ∂b_compile!! = ∇biasact!!_compiled(act, x_ra, b_ra)
123151

124152
@test ∂x_enz!! ∂x_compile!! atol = 1e-5 rtol = 1e-2
@@ -150,8 +178,18 @@ end
150178
@testset "Activation: $act" for act in (
151179
identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2
152180
)
153-
f_compile = Reactant.compile(sumabs2, (act, x_act_ca))
154-
f_compile!! = Reactant.compile(sumabs2!!, (act, x_act_ca))
181+
f_compile = Reactant.with_config(;
182+
dot_general_precision=PrecisionConfig.HIGHEST,
183+
convolution_precision=PrecisionConfig.HIGHEST,
184+
) do
185+
@compile sumabs2(act, x_act_ca)
186+
end
187+
f_compile!! = Reactant.with_config(;
188+
dot_general_precision=PrecisionConfig.HIGHEST,
189+
convolution_precision=PrecisionConfig.HIGHEST,
190+
) do
191+
@compile sumabs2!!(act, x_act_ca)
192+
end
155193

156194
y_simple = sumabs2(act, x_act)
157195
y_simple!! = sumabs2!!(act, x_act)
@@ -167,11 +205,21 @@ end
167205
∂x_enz!! = Enzyme.make_zero(x_act)
168206
Enzyme.autodiff(Reverse, sumabs2!!, Active, Const(act), Duplicated(x_act, ∂x_enz!!))
169207

170-
∇sumabs2_compiled = Reactant.compile(∇sumabs2, (act, x_act_ca))
171-
∂x_compile = ∇sumabs2_compiled(act, x_act_ca)
208+
∇sumabs2 = Reactant.with_config(;
209+
dot_general_precision=PrecisionConfig.HIGHEST,
210+
convolution_precision=PrecisionConfig.HIGHEST,
211+
) do
212+
@compile ∇sumabs2(act, x_act_ca)
213+
end
214+
∂x_compile = ∇sumabs2(act, x_act_ca)
172215

173-
∇sumabs2!!_compiled = Reactant.compile(∇sumabs2!!, (act, x_act_ca))
174-
∂x_compile!! = ∇sumabs2!!_compiled(act, x_act_ca)
216+
∇sumabs2!! = Reactant.with_config(;
217+
dot_general_precision=PrecisionConfig.HIGHEST,
218+
convolution_precision=PrecisionConfig.HIGHEST,
219+
) do
220+
@compile ∇sumabs2!!(act, x_act_ca)
221+
end
222+
∂x_compile!! = ∇sumabs2!!(act, x_act_ca)
175223

176224
@test ∂x_enz ∂x_compile atol = 1e-5 rtol = 1e-2
177225
@test ∂x_enz!! ∂x_compile!! atol = 1e-5 rtol = 1e-2
@@ -194,10 +242,12 @@ end
194242

195243
conv_dims = DenseConvDims(x, weight; stride, padding, dilation, groups)
196244

197-
fused_conv_compiled = Reactant.compile(
198-
fused_conv_bias_activation,
199-
(act, weight_reactant, x_reactant, bias_reactant, conv_dims),
200-
)
245+
fused_conv_compiled = Reactant.with_config(;
246+
dot_general_precision=PrecisionConfig.HIGHEST,
247+
convolution_precision=PrecisionConfig.HIGHEST,
248+
) do
249+
@compile fused_conv_bias_activation(act, weight_reactant, x_reactant, bias_reactant, conv_dims)
250+
end
201251

202252
reactant_res = fused_conv_compiled(
203253
act, weight_reactant, x_reactant, bias_reactant, conv_dims

0 commit comments

Comments
 (0)