49
49
50
50
# # Pullback
51
51
52
+ struct EnzymeReverseOneArgPullbackPrep{Y} <: DI.PullbackPrep
53
+ y_example:: Y # useful to create return activity
54
+ end
55
+
52
56
function DI. prepare_pullback (
53
57
f:: F ,
54
58
:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
55
59
x,
56
60
ty:: NTuple ,
57
61
contexts:: Vararg{DI.Context,C} ,
58
62
) where {F,C}
59
- return DI. NoPullbackPrep ()
63
+ y = f (x, map (DI. unwrap, contexts)... )
64
+ return EnzymeReverseOneArgPullbackPrep (y)
60
65
end
61
66
62
67
# ## Out-of-place
63
68
64
69
function DI. value_and_pullback (
65
70
f:: F ,
66
- :: DI.NoPullbackPrep ,
71
+ prep :: EnzymeReverseOneArgPullbackPrep ,
67
72
backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
68
73
x,
69
74
ty:: NTuple{1} ,
@@ -72,7 +77,7 @@ function DI.value_and_pullback(
72
77
mode = reverse_split_withprimal (backend)
73
78
f_and_df = force_annotation (get_f_and_df (f, backend, mode))
74
79
IA = guess_activity (typeof (x), mode)
75
- RA = guess_activity (eltype (ty ), mode)
80
+ RA = guess_activity (typeof (prep . y_example ), mode)
76
81
dx = make_zero (x)
77
82
annotated_contexts = translate (backend, mode, Val (1 ), contexts... )
78
83
dinputs, result = seeded_autodiff_thunk (
88
93
89
94
function DI. value_and_pullback (
90
95
f:: F ,
91
- :: DI.NoPullbackPrep ,
96
+ prep :: EnzymeReverseOneArgPullbackPrep ,
92
97
backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
93
98
x,
94
99
ty:: NTuple{B} ,
@@ -97,7 +102,7 @@ function DI.value_and_pullback(
97
102
mode = reverse_split_withprimal (backend)
98
103
f_and_df = force_annotation (get_f_and_df (f, backend, mode, Val (B)))
99
104
IA = batchify_activity (guess_activity (typeof (x), mode), Val (B))
100
- RA = batchify_activity (guess_activity (eltype (ty ), mode), Val (B))
105
+ RA = batchify_activity (guess_activity (typeof (prep . y_example ), mode), Val (B))
101
106
tx = ntuple (_ -> make_zero (x), Val (B))
102
107
annotated_contexts = translate (backend, mode, Val (B), contexts... )
103
108
dinputs, result = batch_seeded_autodiff_thunk (
113
118
114
119
function DI. pullback (
115
120
f:: F ,
116
- prep:: DI.NoPullbackPrep ,
121
+ prep:: EnzymeReverseOneArgPullbackPrep ,
117
122
backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
118
123
x,
119
124
ty:: NTuple ,
@@ -127,51 +132,51 @@ end
127
132
function DI. value_and_pullback! (
128
133
f:: F ,
129
134
tx:: NTuple{1} ,
130
- :: DI.NoPullbackPrep ,
135
+ prep :: EnzymeReverseOneArgPullbackPrep ,
131
136
backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
132
137
x,
133
138
ty:: NTuple{1} ,
134
139
contexts:: Vararg{DI.Context,C} ,
135
140
) where {F,C}
136
141
mode = reverse_split_withprimal (backend)
137
142
f_and_df = force_annotation (get_f_and_df (f, backend, mode))
138
- RA = guess_activity (eltype (ty ), mode)
143
+ RA = guess_activity (typeof (prep . y_example ), mode)
139
144
dx_righttype = convert (typeof (x), only (tx))
140
145
make_zero! (dx_righttype)
141
146
annotated_contexts = translate (backend, mode, Val (1 ), contexts... )
142
147
_, result = seeded_autodiff_thunk (
143
148
mode, only (ty), f_and_df, RA, Duplicated (x, dx_righttype), annotated_contexts...
144
149
)
145
- only (tx) === dx_righttype || copyto ! (only (tx), dx_righttype)
150
+ copyto_if_different_addresses ! (only (tx), dx_righttype)
146
151
return result, tx
147
152
end
148
153
149
154
function DI. value_and_pullback! (
150
155
f:: F ,
151
156
tx:: NTuple{B} ,
152
- :: DI.NoPullbackPrep ,
157
+ prep :: EnzymeReverseOneArgPullbackPrep ,
153
158
backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
154
159
x,
155
160
ty:: NTuple{B} ,
156
161
contexts:: Vararg{DI.Context,C} ,
157
162
) where {F,B,C}
158
163
mode = reverse_split_withprimal (backend)
159
164
f_and_df = force_annotation (get_f_and_df (f, backend, mode, Val (B)))
160
- RA = batchify_activity (guess_activity (eltype (ty ), mode), Val (B))
165
+ RA = batchify_activity (guess_activity (typeof (prep . y_example ), mode), Val (B))
161
166
tx_righttype = map (Fix1 (convert, typeof (x)), tx)
162
167
make_zero! (tx_righttype)
163
168
annotated_contexts = translate (backend, mode, Val (B), contexts... )
164
169
_, result = batch_seeded_autodiff_thunk (
165
170
mode, ty, f_and_df, RA, BatchDuplicated (x, tx_righttype), annotated_contexts...
166
171
)
167
- foreach (copyto !, tx, tx_righttype)
172
+ foreach (copyto_if_different_addresses !, tx, tx_righttype)
168
173
return result, tx
169
174
end
170
175
171
176
function DI. pullback! (
172
177
f:: F ,
173
178
tx:: NTuple ,
174
- prep:: DI.NoPullbackPrep ,
179
+ prep:: EnzymeReverseOneArgPullbackPrep ,
175
180
backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
176
181
x,
177
182
ty:: NTuple ,
@@ -265,7 +270,7 @@ function DI.gradient!(
265
270
make_zero! (grad_righttype)
266
271
annotated_contexts = translate (backend, mode, Val (1 ), contexts... )
267
272
autodiff (mode, f_and_df, Active, Duplicated (x, grad_righttype), annotated_contexts... )
268
- grad === grad_righttype || copyto ! (grad, grad_righttype)
273
+ copyto_if_different_addresses ! (grad, grad_righttype)
269
274
return grad
270
275
end
271
276
@@ -295,70 +300,6 @@ function DI.value_and_gradient!(
295
300
_, y = autodiff (
296
301
mode, f_and_df, Active, Duplicated (x, grad_righttype), annotated_contexts...
297
302
)
298
- grad === grad_righttype || copyto ! (grad, grad_righttype)
303
+ copyto_if_different_addresses ! (grad, grad_righttype)
299
304
return y, grad
300
305
end
301
-
302
- # # Jacobian
303
-
304
- # TODO : does not support static arrays
305
-
306
- #=
307
- struct EnzymeReverseOneArgJacobianPrep{Sy,B} <:DI.JacobianPrep end
308
-
309
- function EnzymeReverseOneArgJacobianPrep(::Val{Sy}, ::Val{B}) where {Sy,B}
310
- return EnzymeReverseOneArgJacobianPrep{Sy,B}()
311
- end
312
-
313
- function DI.prepare_jacobian(f::F, backend::AutoEnzyme{<:ReverseMode,Nothing}, x) where {F}
314
- y = f(x)
315
- Sy = size(y)
316
- valB = to_val(DI.pick_batchsize(backend, y))
317
- return EnzymeReverseOneArgJacobianPrep(Val(Sy), valB)
318
- end
319
-
320
- function DI.jacobian(
321
- f::F,
322
- ::EnzymeReverseOneArgJacobianPrep{Sy,B},
323
- backend::AutoEnzyme{<:ReverseMode,Nothing},
324
- x,
325
- ) where {F,Sy,B}
326
- derivs = jacobian(reverse_noprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B))
327
- jac_tensor = only(derivs)
328
- return maybe_reshape(jac_tensor, prod(Sy), length(x))
329
- end
330
-
331
- function DI.value_and_jacobian(
332
- f::F,
333
- ::EnzymeReverseOneArgJacobianPrep{Sy,B},
334
- backend::AutoEnzyme{<:ReverseMode,Nothing},
335
- x,
336
- ) where {F,Sy,B}
337
- (; derivs, val) = jacobian(
338
- reverse_withprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B)
339
- )
340
- jac_tensor = only(derivs)
341
- return val, maybe_reshape(jac_tensor, prod(Sy), length(x))
342
- end
343
-
344
- function DI.jacobian!(
345
- f::F,
346
- jac,
347
- prep::EnzymeReverseOneArgJacobianPrep,
348
- backend::AutoEnzyme{<:ReverseMode,Nothing},
349
- x,
350
- ) where {F}
351
- return copyto!(jac, DI.jacobian(f, prep, backend, x))
352
- end
353
-
354
- function DI.value_and_jacobian!(
355
- f::F,
356
- jac,
357
- prep::EnzymeReverseOneArgJacobianPrep,
358
- backend::AutoEnzyme{<:ReverseMode,Nothing},
359
- x,
360
- ) where {F}
361
- y, new_jac = DI.value_and_jacobian(f, prep, backend, x)
362
- return y, copyto!(jac, new_jac)
363
- end
364
- =#
0 commit comments