@@ -8,8 +8,7 @@ function seeded_autodiff_thunk(
8
8
forward, reverse = autodiff_thunk (rmode, FA, RA, typeof .(args)... )
9
9
tape, result, shadow_result = forward (f, args... )
10
10
if RA <: Active
11
- dresult_righttype = convert (typeof (result), dresult)
12
- dinputs = only (reverse (f, args... , dresult_righttype, tape))
11
+ dinputs = only (reverse (f, args... , dresult, tape))
13
12
else
14
13
shadow_result .+ = dresult # TODO : generalize beyond arrays
15
14
dinputs = only (reverse (f, args... , tape))
@@ -32,8 +31,7 @@ function batch_seeded_autodiff_thunk(
32
31
forward, reverse = autodiff_thunk (rmode_rightwidth, FA, RA, typeof .(args)... )
33
32
tape, result, shadow_results = forward (f, args... )
34
33
if RA <: Active
35
- dresults_righttype = map (Fix1 (convert, typeof (result)), dresults)
36
- dinputs = only (reverse (f, args... , dresults_righttype, tape))
34
+ dinputs = only (reverse (f, args... , dresults, tape))
37
35
else
38
36
foreach (shadow_results, dresults) do d0, d
39
37
d0 .+ = d # use recursive_add here?
@@ -141,13 +139,12 @@ function DI.value_and_pullback!(
141
139
mode = reverse_split_withprimal (backend)
142
140
f_and_df = force_annotation (get_f_and_df (f, backend, mode))
143
141
RA = guess_activity (typeof (prep. y_example), mode)
144
- dx_righttype = convert ( typeof (x), only (tx) )
145
- make_zero! (dx_righttype )
142
+ dx = only (tx)
143
+ make_zero! (dx )
146
144
annotated_contexts = translate (backend, mode, Val (1 ), contexts... )
147
145
_, result = seeded_autodiff_thunk (
148
- mode, only (ty), f_and_df, RA, Duplicated (x, dx_righttype ), annotated_contexts...
146
+ mode, only (ty), f_and_df, RA, Duplicated (x, dx ), annotated_contexts...
149
147
)
150
- copyto_if_different_addresses! (only (tx), dx_righttype)
151
148
return result, tx
152
149
end
153
150
@@ -163,13 +160,11 @@ function DI.value_and_pullback!(
163
160
mode = reverse_split_withprimal (backend)
164
161
f_and_df = force_annotation (get_f_and_df (f, backend, mode, Val (B)))
165
162
RA = batchify_activity (guess_activity (typeof (prep. y_example), mode), Val (B))
166
- tx_righttype = map (Fix1 (convert, typeof (x)), tx)
167
- make_zero! (tx_righttype)
163
+ make_zero! (tx)
168
164
annotated_contexts = translate (backend, mode, Val (B), contexts... )
169
165
_, result = batch_seeded_autodiff_thunk (
170
- mode, ty, f_and_df, RA, BatchDuplicated (x, tx_righttype ), annotated_contexts...
166
+ mode, ty, f_and_df, RA, BatchDuplicated (x, tx ), annotated_contexts...
171
167
)
172
- foreach (copyto_if_different_addresses!, tx, tx_righttype)
173
168
return result, tx
174
169
end
175
170
@@ -187,10 +182,73 @@ end
187
182
188
183
# # Gradient
189
184
190
- # ## Without preparation
185
+ function DI. prepare_gradient (
186
+ f:: F , :: AutoEnzyme{<:Union{ReverseMode,Nothing}} , x, contexts:: Vararg{DI.Context,C}
187
+ ) where {F,C}
188
+ return DI. NoGradientPrep ()
189
+ end
190
+
191
+ # ## Enzyme gradient API (only constants)
192
+
193
+ function DI. gradient (
194
+ f:: F ,
195
+ :: DI.NoGradientPrep ,
196
+ backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}} ,
197
+ x,
198
+ contexts:: Vararg{DI.Constant,C} ,
199
+ ) where {F,C}
200
+ mode = reverse_noprimal (backend)
201
+ f_and_df = get_f_and_df (f, backend, mode)
202
+ annotated_contexts = translate (backend, mode, Val (1 ), contexts... )
203
+ grads = gradient (mode, f_and_df, x, annotated_contexts... )
204
+ return first (grads)
205
+ end
206
+
207
+ function DI. value_and_gradient (
208
+ f:: F ,
209
+ :: DI.NoGradientPrep ,
210
+ backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}} ,
211
+ x,
212
+ contexts:: Vararg{DI.Constant,C} ,
213
+ ) where {F,C}
214
+ mode = reverse_withprimal (backend)
215
+ f_and_df = get_f_and_df (f, backend, mode)
216
+ annotated_contexts = translate (backend, mode, Val (1 ), contexts... )
217
+ grads, result = gradient (mode, f_and_df, x, annotated_contexts... )
218
+ return result, first (grads)
219
+ end
220
+
221
+ function DI. gradient! (
222
+ f:: F ,
223
+ grad,
224
+ :: DI.NoGradientPrep ,
225
+ backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}} ,
226
+ x,
227
+ ) where {F}
228
+ mode = reverse_noprimal (backend)
229
+ f_and_df = get_f_and_df (f, backend, mode)
230
+ gradient! (mode, grad, f_and_df, x)
231
+ return grad
232
+ end
233
+
234
+ function DI. value_and_gradient! (
235
+ f:: F ,
236
+ grad,
237
+ :: DI.NoGradientPrep ,
238
+ backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}} ,
239
+ x,
240
+ ) where {F}
241
+ mode = reverse_withprimal (backend)
242
+ f_and_df = get_f_and_df (f, backend, mode)
243
+ _, result = gradient! (mode, grad, f_and_df, x)
244
+ return result, grad
245
+ end
246
+
247
+ # ## Generic
191
248
192
249
function DI. gradient (
193
250
f:: F ,
251
+ :: DI.NoGradientPrep ,
194
252
backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
195
253
x,
196
254
contexts:: Vararg{DI.Context,C} ,
213
271
214
272
function DI. value_and_gradient (
215
273
f:: F ,
274
+ :: DI.NoGradientPrep ,
216
275
backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
217
276
x,
218
277
contexts:: Vararg{DI.Context,C} ,
@@ -233,73 +292,34 @@ function DI.value_and_gradient(
233
292
end
234
293
end
235
294
236
- # ## With preparation
237
-
238
- struct EnzymeGradientPrep{G} <: DI.GradientPrep
239
- grad_righttype:: G
240
- end
241
-
242
- function DI. prepare_gradient (
243
- f:: F , :: AutoEnzyme{<:Union{ReverseMode,Nothing}} , x, contexts:: Vararg{DI.Context,C}
244
- ) where {F,C}
245
- grad_righttype = make_zero (x)
246
- return EnzymeGradientPrep (grad_righttype)
247
- end
248
-
249
- function DI. gradient (
250
- f:: F ,
251
- :: EnzymeGradientPrep ,
252
- backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
253
- x,
254
- contexts:: Vararg{DI.Context,C} ,
255
- ) where {F,C}
256
- return DI. gradient (f, backend, x, contexts... )
257
- end
258
-
259
295
function DI. gradient! (
260
296
f:: F ,
261
297
grad,
262
- prep :: EnzymeGradientPrep ,
298
+ :: DI.NoGradientPrep ,
263
299
backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
264
300
x,
265
301
contexts:: Vararg{DI.Context,C} ,
266
302
) where {F,C}
267
303
mode = reverse_noprimal (backend)
268
304
f_and_df = get_f_and_df (f, backend, mode)
269
- grad_righttype = grad isa typeof (x) ? grad : prep. grad_righttype
270
- make_zero! (grad_righttype)
271
305
annotated_contexts = translate (backend, mode, Val (1 ), contexts... )
272
- autodiff (mode, f_and_df, Active, Duplicated (x, grad_righttype), annotated_contexts ... )
273
- copyto_if_different_addresses! ( grad, grad_righttype )
306
+ make_zero! (grad )
307
+ autodiff (mode, f_and_df, Active, Duplicated (x, grad), annotated_contexts ... )
274
308
return grad
275
309
end
276
310
277
- function DI. value_and_gradient (
278
- f:: F ,
279
- :: EnzymeGradientPrep ,
280
- backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
281
- x,
282
- contexts:: Vararg{DI.Context,C} ,
283
- ) where {F,C}
284
- return DI. value_and_gradient (f, backend, x, contexts... )
285
- end
286
-
287
311
function DI. value_and_gradient! (
288
312
f:: F ,
289
313
grad,
290
- prep :: EnzymeGradientPrep ,
314
+ :: DI.NoGradientPrep ,
291
315
backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
292
316
x,
293
317
contexts:: Vararg{DI.Context,C} ,
294
318
) where {F,C}
295
319
mode = reverse_withprimal (backend)
296
320
f_and_df = get_f_and_df (f, backend, mode)
297
- grad_righttype = grad isa typeof (x) ? grad : prep. grad_righttype
298
- make_zero! (grad_righttype)
299
321
annotated_contexts = translate (backend, mode, Val (1 ), contexts... )
300
- _, y = autodiff (
301
- mode, f_and_df, Active, Duplicated (x, grad_righttype), annotated_contexts...
302
- )
303
- copyto_if_different_addresses! (grad, grad_righttype)
322
+ make_zero! (grad)
323
+ _, y = autodiff (mode, f_and_df, Active, Duplicated (x, grad), annotated_contexts... )
304
324
return y, grad
305
325
end
0 commit comments