@@ -30,30 +30,26 @@ using LuxLib, Reactant, Enzyme, NNlib
30
30
x_ra = Reactant. to_rarray (x)
31
31
bias_ra = Reactant. to_rarray (bias)
32
32
33
- f_compile = Reactant. with_config (;
33
+ y_compile = Reactant. with_config (;
34
34
dot_general_precision= PrecisionConfig. HIGHEST,
35
35
convolution_precision= PrecisionConfig. HIGHEST,
36
36
) do
37
- @compile fused_dense_bias_activation (act, weight_ra, x_ra, bias_ra)
37
+ @jit fused_dense_bias_activation (act, weight_ra, x_ra, bias_ra)
38
38
end
39
39
40
40
y_res = fused_dense_bias_activation (act, weight, x, bias)
41
- y_compile = f_compile (act, weight_ra, x_ra, bias_ra)
42
41
43
42
@test y_res ≈ y_compile atol = 1e-5 rtol = 1e-2
44
43
45
44
@testset " Enzyme: fused_dense_bias_activation" begin
46
45
dw, dx, db = ∇fuseddense (act, weight, x, bias)
47
46
48
- ∇fuseddense_compiled = Reactant. with_config (;
47
+ dw_compile, dx_compile, db_compile = Reactant. with_config (;
49
48
dot_general_precision= PrecisionConfig. HIGHEST,
50
49
convolution_precision= PrecisionConfig. HIGHEST,
51
50
) do
52
- @compile ∇fuseddense (act, weight_ra, x_ra, bias_ra)
51
+ @jit ∇fuseddense (act, weight_ra, x_ra, bias_ra)
53
52
end
54
- dw_compile, dx_compile, db_compile = ∇fuseddense_compiled (
55
- act, weight_ra, x_ra, bias_ra
56
- )
57
53
58
54
@test dw ≈ dw_compile atol = 1e-5 rtol = 1e-2
59
55
@test dx ≈ dx_compile atol = 1e-5 rtol = 1e-2
103
99
x_ra = Reactant. to_rarray (x)
104
100
b_ra = Reactant. to_rarray (b)
105
101
106
- f_compile = Reactant. with_config (;
102
+ y_compile = Reactant. with_config (;
107
103
dot_general_precision= PrecisionConfig. HIGHEST,
108
104
convolution_precision= PrecisionConfig. HIGHEST,
109
105
) do
110
- @compile biasact (act, x_ra, b_ra)
106
+ @jit biasact (act, x_ra, b_ra)
111
107
end
112
108
113
- f_compile !! = Reactant. with_config (;
109
+ y_compile !! = Reactant. with_config (;
114
110
dot_general_precision= PrecisionConfig. HIGHEST,
115
111
convolution_precision= PrecisionConfig. HIGHEST,
116
112
) do
117
- @compile biasact!! (act, x_ra, b_ra)
113
+ @jit biasact!! (act, x_ra, b_ra)
118
114
end
119
115
120
116
y_simple = biasact (act, x, b)
121
117
y_simple!! = biasact!! (act, x, b)
122
- y_compile = f_compile (act, x_ra, b_ra)
123
- y_compile!! = f_compile!! (act, x_ra, b_ra)
124
118
125
119
@test y_simple ≈ y_compile atol = 1e-5 rtol = 1e-2
126
120
@test y_simple!! ≈ y_compile!! atol = 1e-5 rtol = 1e-2
127
121
128
122
@testset " Enzyme: bias_activation" begin
129
123
∂x_enz, ∂b_enz = ∇biasact (act, x, b)
130
- ∇biasact_compiled = Reactant. with_config (;
124
+ ∂x_compile, ∂b_compile = Reactant. with_config (;
131
125
dot_general_precision= PrecisionConfig. HIGHEST,
132
126
convolution_precision= PrecisionConfig. HIGHEST,
133
127
) do
134
- @compile ∇biasact (act, x_ra, b_ra)
128
+ @jit ∇biasact (act, x_ra, b_ra)
135
129
end
136
- ∂x_compile, ∂b_compile = ∇biasact_compiled (act, x_ra, b_ra)
137
130
138
131
@test ∂x_enz ≈ ∂x_compile atol = 1e-5 rtol = 1e-2
139
132
@test ∂b_enz ≈ ∂b_compile atol = 1e-5 rtol = 1e-2
140
133
end
141
134
142
135
@testset " Enzyme: bias_activation!!" begin
143
136
∂x_enz!!, ∂b_enz!! = ∇biasact!! (act, x, b)
144
- ∇biasact!!_compiled = Reactant. with_config (;
137
+ ∂x_compile!!, ∂b_compile!! = Reactant. with_config (;
145
138
dot_general_precision= PrecisionConfig. HIGHEST,
146
139
convolution_precision= PrecisionConfig. HIGHEST,
147
140
) do
148
- @compile ∇biasact!! (act, x_ra, b_ra)
141
+ @jit ∇biasact!! (act, x_ra, b_ra)
149
142
end
150
- ∂x_compile!!, ∂b_compile!! = ∇biasact!!_compiled (act, x_ra, b_ra)
151
143
152
144
@test ∂x_enz!! ≈ ∂x_compile!! atol = 1e-5 rtol = 1e-2
153
145
@test ∂b_enz!! ≈ ∂b_compile!! atol = 1e-5 rtol = 1e-2
@@ -178,24 +170,21 @@ end
178
170
@testset " Activation: $act " for act in (
179
171
identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2
180
172
)
181
- f_compile = Reactant. with_config (;
173
+ y_simple = sumabs2 (act, x_act)
174
+ y_simple!! = sumabs2!! (act, x_act)
175
+ y_compile = Reactant. with_config (;
182
176
dot_general_precision= PrecisionConfig. HIGHEST,
183
177
convolution_precision= PrecisionConfig. HIGHEST,
184
178
) do
185
- @compile sumabs2 (act, x_act_ca)
179
+ @jit sumabs2 (act, x_act_ca)
186
180
end
187
- f_compile !! = Reactant. with_config (;
181
+ y_compile !! = Reactant. with_config (;
188
182
dot_general_precision= PrecisionConfig. HIGHEST,
189
183
convolution_precision= PrecisionConfig. HIGHEST,
190
184
) do
191
- @compile sumabs2!! (act, x_act_ca)
185
+ @jit sumabs2!! (act, x_act_ca)
192
186
end
193
187
194
- y_simple = sumabs2 (act, x_act)
195
- y_simple!! = sumabs2!! (act, x_act)
196
- y_compile = f_compile (act, x_act_ca)
197
- y_compile!! = f_compile!! (act, x_act_ca)
198
-
199
188
@test y_simple ≈ y_compile atol = 1e-5 rtol = 1e-2
200
189
@test y_simple!! ≈ y_compile!! atol = 1e-5 rtol = 1e-2
201
190
@@ -205,21 +194,19 @@ end
205
194
∂x_enz!! = Enzyme. make_zero (x_act)
206
195
Enzyme. autodiff (Reverse, sumabs2!!, Active, Const (act), Duplicated (x_act, ∂x_enz!!))
207
196
208
- ∇sumabs2 = Reactant. with_config (;
197
+ ∂x_compile = Reactant. with_config (;
209
198
dot_general_precision= PrecisionConfig. HIGHEST,
210
199
convolution_precision= PrecisionConfig. HIGHEST,
211
200
) do
212
- @compile ∇sumabs2 (act, x_act_ca)
201
+ @jit ∇sumabs2 (act, x_act_ca)
213
202
end
214
- ∂x_compile = ∇sumabs2 (act, x_act_ca)
215
203
216
- ∇sumabs2 !! = Reactant. with_config (;
204
+ ∂x_compile !! = Reactant. with_config (;
217
205
dot_general_precision= PrecisionConfig. HIGHEST,
218
206
convolution_precision= PrecisionConfig. HIGHEST,
219
207
) do
220
- @compile ∇sumabs2!! (act, x_act_ca)
208
+ @jit ∇sumabs2!! (act, x_act_ca)
221
209
end
222
- ∂x_compile!! = ∇sumabs2!! (act, x_act_ca)
223
210
224
211
@test ∂x_enz ≈ ∂x_compile atol = 1e-5 rtol = 1e-2
225
212
@test ∂x_enz!! ≈ ∂x_compile!! atol = 1e-5 rtol = 1e-2
@@ -242,16 +229,15 @@ end
242
229
243
230
conv_dims = DenseConvDims (x, weight; stride, padding, dilation, groups)
244
231
245
- fused_conv_compiled = Reactant. with_config (;
232
+ reactant_res = Reactant. with_config (;
246
233
dot_general_precision= PrecisionConfig. HIGHEST,
247
234
convolution_precision= PrecisionConfig. HIGHEST,
248
235
) do
249
- @compile fused_conv_bias_activation (act, weight_reactant, x_reactant, bias_reactant, conv_dims)
236
+ @jit fused_conv_bias_activation (
237
+ act, weight_reactant, x_reactant, bias_reactant, conv_dims
238
+ )
250
239
end
251
240
252
- reactant_res = fused_conv_compiled (
253
- act, weight_reactant, x_reactant, bias_reactant, conv_dims
254
- )
255
241
luxlib_res = fused_conv_bias_activation (act, weight, x, bias, conv_dims)
256
242
257
243
@test reactant_res ≈ luxlib_res atol = 1e-5 rtol = 1e-2
0 commit comments