@@ -30,9 +30,12 @@ using LuxLib, Reactant, Enzyme, NNlib
30
30
x_ra = Reactant. to_rarray (x)
31
31
bias_ra = Reactant. to_rarray (bias)
32
32
33
- f_compile = Reactant. compile (
34
- fused_dense_bias_activation, (act, weight_ra, x_ra, bias_ra)
35
- )
33
+ f_compile = Reactant. with_config (;
34
+ dot_general_precision= PrecisionConfig. HIGHEST,
35
+ convolution_precision= PrecisionConfig. HIGHEST,
36
+ ) do
37
+ @compile fused_dense_bias_activation (act, weight_ra, x_ra, bias_ra)
38
+ end
36
39
37
40
y_res = fused_dense_bias_activation (act, weight, x, bias)
38
41
y_compile = f_compile (act, weight_ra, x_ra, bias_ra)
@@ -41,9 +44,13 @@ using LuxLib, Reactant, Enzyme, NNlib
41
44
42
45
@testset " Enzyme: fused_dense_bias_activation" begin
43
46
dw, dx, db = ∇fuseddense (act, weight, x, bias)
44
- ∇fuseddense_compiled = Reactant. compile (
45
- ∇fuseddense, (act, weight_ra, x_ra, bias_ra)
46
- )
47
+
48
+ ∇fuseddense_compiled = Reactant. with_config (;
49
+ dot_general_precision= PrecisionConfig. HIGHEST,
50
+ convolution_precision= PrecisionConfig. HIGHEST,
51
+ ) do
52
+ @compile ∇fuseddense (act, weight_ra, x_ra, bias_ra)
53
+ end
47
54
dw_compile, dx_compile, db_compile = ∇fuseddense_compiled (
48
55
act, weight_ra, x_ra, bias_ra
49
56
)
96
103
x_ra = Reactant. to_rarray (x)
97
104
b_ra = Reactant. to_rarray (b)
98
105
99
- f_compile = Reactant. compile (biasact, (act, x_ra, b_ra))
100
- f_compile!! = Reactant. compile (biasact!!, (act, x_ra, b_ra))
106
+ f_compile = Reactant. with_config (;
107
+ dot_general_precision= PrecisionConfig. HIGHEST,
108
+ convolution_precision= PrecisionConfig. HIGHEST,
109
+ ) do
110
+ @compile biasact (act, x_ra, b_ra)
111
+ end
112
+
113
+ f_compile!! = Reactant. with_config (;
114
+ dot_general_precision= PrecisionConfig. HIGHEST,
115
+ convolution_precision= PrecisionConfig. HIGHEST,
116
+ ) do
117
+ @compile biasact!! (act, x_ra, b_ra)
118
+ end
101
119
102
120
y_simple = biasact (act, x, b)
103
121
y_simple!! = biasact!! (act, x, b)
109
127
110
128
@testset " Enzyme: bias_activation" begin
111
129
∂x_enz, ∂b_enz = ∇biasact (act, x, b)
112
- ∇biasact_compiled = Reactant. compile (∇biasact, (act, x_ra, b_ra))
130
+ ∇biasact_compiled = Reactant. with_config (;
131
+ dot_general_precision= PrecisionConfig. HIGHEST,
132
+ convolution_precision= PrecisionConfig. HIGHEST,
133
+ ) do
134
+ @compile ∇biasact (act, x_ra, b_ra)
135
+ end
113
136
∂x_compile, ∂b_compile = ∇biasact_compiled (act, x_ra, b_ra)
114
137
115
138
@test ∂x_enz ≈ ∂x_compile atol = 1e-5 rtol = 1e-2
118
141
119
142
@testset " Enzyme: bias_activation!!" begin
120
143
∂x_enz!!, ∂b_enz!! = ∇biasact!! (act, x, b)
121
- ∇biasact!!_compiled = Reactant. compile (∇biasact!!, (act, x_ra, b_ra))
144
+ ∇biasact!!_compiled = Reactant. with_config (;
145
+ dot_general_precision= PrecisionConfig. HIGHEST,
146
+ convolution_precision= PrecisionConfig. HIGHEST,
147
+ ) do
148
+ @compile ∇biasact!! (act, x_ra, b_ra)
149
+ end
122
150
∂x_compile!!, ∂b_compile!! = ∇biasact!!_compiled (act, x_ra, b_ra)
123
151
124
152
@test ∂x_enz!! ≈ ∂x_compile!! atol = 1e-5 rtol = 1e-2
150
178
@testset " Activation: $act " for act in (
151
179
identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2
152
180
)
153
- f_compile = Reactant. compile (sumabs2, (act, x_act_ca))
154
- f_compile!! = Reactant. compile (sumabs2!!, (act, x_act_ca))
181
+ f_compile = Reactant. with_config (;
182
+ dot_general_precision= PrecisionConfig. HIGHEST,
183
+ convolution_precision= PrecisionConfig. HIGHEST,
184
+ ) do
185
+ @compile sumabs2 (act, x_act_ca)
186
+ end
187
+ f_compile!! = Reactant. with_config (;
188
+ dot_general_precision= PrecisionConfig. HIGHEST,
189
+ convolution_precision= PrecisionConfig. HIGHEST,
190
+ ) do
191
+ @compile sumabs2!! (act, x_act_ca)
192
+ end
155
193
156
194
y_simple = sumabs2 (act, x_act)
157
195
y_simple!! = sumabs2!! (act, x_act)
@@ -167,11 +205,21 @@ end
167
205
∂x_enz!! = Enzyme. make_zero (x_act)
168
206
Enzyme. autodiff (Reverse, sumabs2!!, Active, Const (act), Duplicated (x_act, ∂x_enz!!))
169
207
170
- ∇sumabs2_compiled = Reactant. compile (∇sumabs2, (act, x_act_ca))
171
- ∂x_compile = ∇sumabs2_compiled (act, x_act_ca)
208
+ ∇sumabs2 = Reactant. with_config (;
209
+ dot_general_precision= PrecisionConfig. HIGHEST,
210
+ convolution_precision= PrecisionConfig. HIGHEST,
211
+ ) do
212
+ @compile ∇sumabs2 (act, x_act_ca)
213
+ end
214
+ ∂x_compile = ∇sumabs2 (act, x_act_ca)
172
215
173
- ∇sumabs2!!_compiled = Reactant. compile (∇sumabs2!!, (act, x_act_ca))
174
- ∂x_compile!! = ∇sumabs2!!_compiled (act, x_act_ca)
216
+ ∇sumabs2!! = Reactant. with_config (;
217
+ dot_general_precision= PrecisionConfig. HIGHEST,
218
+ convolution_precision= PrecisionConfig. HIGHEST,
219
+ ) do
220
+ @compile ∇sumabs2!! (act, x_act_ca)
221
+ end
222
+ ∂x_compile!! = ∇sumabs2!! (act, x_act_ca)
175
223
176
224
@test ∂x_enz ≈ ∂x_compile atol = 1e-5 rtol = 1e-2
177
225
@test ∂x_enz!! ≈ ∂x_compile!! atol = 1e-5 rtol = 1e-2
@@ -194,10 +242,12 @@ end
194
242
195
243
conv_dims = DenseConvDims (x, weight; stride, padding, dilation, groups)
196
244
197
- fused_conv_compiled = Reactant. compile (
198
- fused_conv_bias_activation,
199
- (act, weight_reactant, x_reactant, bias_reactant, conv_dims),
200
- )
245
+ fused_conv_compiled = Reactant. with_config (;
246
+ dot_general_precision= PrecisionConfig. HIGHEST,
247
+ convolution_precision= PrecisionConfig. HIGHEST,
248
+ ) do
249
+ @compile fused_conv_bias_activation (act, weight_reactant, x_reactant, bias_reactant, conv_dims)
250
+ end
201
251
202
252
reactant_res = fused_conv_compiled (
203
253
act, weight_reactant, x_reactant, bias_reactant, conv_dims
0 commit comments