Skip to content

Commit 3c15405

Browse files
chore: fmt
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 4c3cf10 commit 3c15405

File tree

1 file changed

+26
-40
lines changed

1 file changed

+26
-40
lines changed

test/nn/luxlib.jl

Lines changed: 26 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -30,30 +30,26 @@ using LuxLib, Reactant, Enzyme, NNlib
3030
x_ra = Reactant.to_rarray(x)
3131
bias_ra = Reactant.to_rarray(bias)
3232

33-
f_compile = Reactant.with_config(;
33+
y_compile = Reactant.with_config(;
3434
dot_general_precision=PrecisionConfig.HIGHEST,
3535
convolution_precision=PrecisionConfig.HIGHEST,
3636
) do
37-
@compile fused_dense_bias_activation(act, weight_ra, x_ra, bias_ra)
37+
@jit fused_dense_bias_activation(act, weight_ra, x_ra, bias_ra)
3838
end
3939

4040
y_res = fused_dense_bias_activation(act, weight, x, bias)
41-
y_compile = f_compile(act, weight_ra, x_ra, bias_ra)
4241

4342
@test y_res y_compile atol = 1e-5 rtol = 1e-2
4443

4544
@testset "Enzyme: fused_dense_bias_activation" begin
4645
dw, dx, db = ∇fuseddense(act, weight, x, bias)
4746

48-
∇fuseddense_compiled = Reactant.with_config(;
47+
dw_compile, dx_compile, db_compile = Reactant.with_config(;
4948
dot_general_precision=PrecisionConfig.HIGHEST,
5049
convolution_precision=PrecisionConfig.HIGHEST,
5150
) do
52-
@compile ∇fuseddense(act, weight_ra, x_ra, bias_ra)
51+
@jit ∇fuseddense(act, weight_ra, x_ra, bias_ra)
5352
end
54-
dw_compile, dx_compile, db_compile = ∇fuseddense_compiled(
55-
act, weight_ra, x_ra, bias_ra
56-
)
5753

5854
@test dw dw_compile atol = 1e-5 rtol = 1e-2
5955
@test dx dx_compile atol = 1e-5 rtol = 1e-2
@@ -103,51 +99,47 @@ end
10399
x_ra = Reactant.to_rarray(x)
104100
b_ra = Reactant.to_rarray(b)
105101

106-
f_compile = Reactant.with_config(;
102+
y_compile = Reactant.with_config(;
107103
dot_general_precision=PrecisionConfig.HIGHEST,
108104
convolution_precision=PrecisionConfig.HIGHEST,
109105
) do
110-
@compile biasact(act, x_ra, b_ra)
106+
@jit biasact(act, x_ra, b_ra)
111107
end
112108

113-
f_compile!! = Reactant.with_config(;
109+
y_compile!! = Reactant.with_config(;
114110
dot_general_precision=PrecisionConfig.HIGHEST,
115111
convolution_precision=PrecisionConfig.HIGHEST,
116112
) do
117-
@compile biasact!!(act, x_ra, b_ra)
113+
@jit biasact!!(act, x_ra, b_ra)
118114
end
119115

120116
y_simple = biasact(act, x, b)
121117
y_simple!! = biasact!!(act, x, b)
122-
y_compile = f_compile(act, x_ra, b_ra)
123-
y_compile!! = f_compile!!(act, x_ra, b_ra)
124118

125119
@test y_simple y_compile atol = 1e-5 rtol = 1e-2
126120
@test y_simple!! y_compile!! atol = 1e-5 rtol = 1e-2
127121

128122
@testset "Enzyme: bias_activation" begin
129123
∂x_enz, ∂b_enz = ∇biasact(act, x, b)
130-
∇biasact_compiled = Reactant.with_config(;
124+
∂x_compile, ∂b_compile = Reactant.with_config(;
131125
dot_general_precision=PrecisionConfig.HIGHEST,
132126
convolution_precision=PrecisionConfig.HIGHEST,
133127
) do
134-
@compile ∇biasact(act, x_ra, b_ra)
128+
@jit ∇biasact(act, x_ra, b_ra)
135129
end
136-
∂x_compile, ∂b_compile = ∇biasact_compiled(act, x_ra, b_ra)
137130

138131
@test ∂x_enz ∂x_compile atol = 1e-5 rtol = 1e-2
139132
@test ∂b_enz ∂b_compile atol = 1e-5 rtol = 1e-2
140133
end
141134

142135
@testset "Enzyme: bias_activation!!" begin
143136
∂x_enz!!, ∂b_enz!! = ∇biasact!!(act, x, b)
144-
∇biasact!!_compiled = Reactant.with_config(;
137+
∂x_compile!!, ∂b_compile!! = Reactant.with_config(;
145138
dot_general_precision=PrecisionConfig.HIGHEST,
146139
convolution_precision=PrecisionConfig.HIGHEST,
147140
) do
148-
@compile ∇biasact!!(act, x_ra, b_ra)
141+
@jit ∇biasact!!(act, x_ra, b_ra)
149142
end
150-
∂x_compile!!, ∂b_compile!! = ∇biasact!!_compiled(act, x_ra, b_ra)
151143

152144
@test ∂x_enz!! ∂x_compile!! atol = 1e-5 rtol = 1e-2
153145
@test ∂b_enz!! ∂b_compile!! atol = 1e-5 rtol = 1e-2
@@ -178,24 +170,21 @@ end
178170
@testset "Activation: $act" for act in (
179171
identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2
180172
)
181-
f_compile = Reactant.with_config(;
173+
y_simple = sumabs2(act, x_act)
174+
y_simple!! = sumabs2!!(act, x_act)
175+
y_compile = Reactant.with_config(;
182176
dot_general_precision=PrecisionConfig.HIGHEST,
183177
convolution_precision=PrecisionConfig.HIGHEST,
184178
) do
185-
@compile sumabs2(act, x_act_ca)
179+
@jit sumabs2(act, x_act_ca)
186180
end
187-
f_compile!! = Reactant.with_config(;
181+
y_compile!! = Reactant.with_config(;
188182
dot_general_precision=PrecisionConfig.HIGHEST,
189183
convolution_precision=PrecisionConfig.HIGHEST,
190184
) do
191-
@compile sumabs2!!(act, x_act_ca)
185+
@jit sumabs2!!(act, x_act_ca)
192186
end
193187

194-
y_simple = sumabs2(act, x_act)
195-
y_simple!! = sumabs2!!(act, x_act)
196-
y_compile = f_compile(act, x_act_ca)
197-
y_compile!! = f_compile!!(act, x_act_ca)
198-
199188
@test y_simple y_compile atol = 1e-5 rtol = 1e-2
200189
@test y_simple!! y_compile!! atol = 1e-5 rtol = 1e-2
201190

@@ -205,21 +194,19 @@ end
205194
∂x_enz!! = Enzyme.make_zero(x_act)
206195
Enzyme.autodiff(Reverse, sumabs2!!, Active, Const(act), Duplicated(x_act, ∂x_enz!!))
207196

208-
∇sumabs2 = Reactant.with_config(;
197+
∂x_compile = Reactant.with_config(;
209198
dot_general_precision=PrecisionConfig.HIGHEST,
210199
convolution_precision=PrecisionConfig.HIGHEST,
211200
) do
212-
@compile ∇sumabs2(act, x_act_ca)
201+
@jit ∇sumabs2(act, x_act_ca)
213202
end
214-
∂x_compile = ∇sumabs2(act, x_act_ca)
215203

216-
∇sumabs2!! = Reactant.with_config(;
204+
∂x_compile!! = Reactant.with_config(;
217205
dot_general_precision=PrecisionConfig.HIGHEST,
218206
convolution_precision=PrecisionConfig.HIGHEST,
219207
) do
220-
@compile ∇sumabs2!!(act, x_act_ca)
208+
@jit ∇sumabs2!!(act, x_act_ca)
221209
end
222-
∂x_compile!! = ∇sumabs2!!(act, x_act_ca)
223210

224211
@test ∂x_enz ∂x_compile atol = 1e-5 rtol = 1e-2
225212
@test ∂x_enz!! ∂x_compile!! atol = 1e-5 rtol = 1e-2
@@ -242,16 +229,15 @@ end
242229

243230
conv_dims = DenseConvDims(x, weight; stride, padding, dilation, groups)
244231

245-
fused_conv_compiled = Reactant.with_config(;
232+
reactant_res = Reactant.with_config(;
246233
dot_general_precision=PrecisionConfig.HIGHEST,
247234
convolution_precision=PrecisionConfig.HIGHEST,
248235
) do
249-
@compile fused_conv_bias_activation(act, weight_reactant, x_reactant, bias_reactant, conv_dims)
236+
@jit fused_conv_bias_activation(
237+
act, weight_reactant, x_reactant, bias_reactant, conv_dims
238+
)
250239
end
251240

252-
reactant_res = fused_conv_compiled(
253-
act, weight_reactant, x_reactant, bias_reactant, conv_dims
254-
)
255241
luxlib_res = fused_conv_bias_activation(act, weight, x, bias, conv_dims)
256242

257243
@test reactant_res luxlib_res atol = 1e-5 rtol = 1e-2

0 commit comments

Comments
 (0)