11# # Pushforward
22
33function DI. prepare_pushforward (
4+ strict:: Val ,
45 f:: F ,
5- :: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
6+ backend :: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
67 x,
78 tx:: NTuple ,
8- contexts:: Vararg{DI.Context,C} ,
9+ contexts:: Vararg{DI.Context,C} ;
910) where {F,C}
10- return DI. NoPushforwardPrep ()
11+ _sig = DI. signature (f, backend, x, tx, contexts... ; strict)
12+ return DI. NoPushforwardPrep (_sig)
1113end
1214
1315function DI. value_and_pushforward (
1416 f:: F ,
15- :: DI.NoPushforwardPrep ,
17+ prep :: DI.NoPushforwardPrep ,
1618 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
1719 x,
1820 tx:: NTuple{1} ,
1921 contexts:: Vararg{DI.Context,C} ,
2022) where {F,C}
23+ DI. check_prep (f, prep, backend, x, tx, contexts... )
2124 mode = forward_withprimal (backend)
2225 f_and_df = get_f_and_df (f, backend, mode)
2326 dx = only (tx)
2932
3033function DI. value_and_pushforward (
3134 f:: F ,
32- :: DI.NoPushforwardPrep ,
35+ prep :: DI.NoPushforwardPrep ,
3336 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
3437 x,
3538 tx:: NTuple{B} ,
3639 contexts:: Vararg{DI.Context,C} ,
3740) where {F,B,C}
41+ DI. check_prep (f, prep, backend, x, tx, contexts... )
3842 mode = forward_withprimal (backend)
3943 f_and_df = get_f_and_df (f, backend, mode, Val (B))
4044 x_and_tx = BatchDuplicated (x, tx)
4549
4650function DI. pushforward (
4751 f:: F ,
48- :: DI.NoPushforwardPrep ,
52+ prep :: DI.NoPushforwardPrep ,
4953 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
5054 x,
5155 tx:: NTuple{1} ,
5256 contexts:: Vararg{DI.Context,C} ,
5357) where {F,C}
58+ DI. check_prep (f, prep, backend, x, tx, contexts... )
5459 mode = forward_noprimal (backend)
5560 f_and_df = get_f_and_df (f, backend, mode)
5661 dx = only (tx)
6267
6368function DI. pushforward (
6469 f:: F ,
65- :: DI.NoPushforwardPrep ,
70+ prep :: DI.NoPushforwardPrep ,
6671 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
6772 x,
6873 tx:: NTuple{B} ,
6974 contexts:: Vararg{DI.Context,C} ,
7075) where {F,B,C}
76+ DI. check_prep (f, prep, backend, x, tx, contexts... )
7177 mode = forward_noprimal (backend)
7278 f_and_df = get_f_and_df (f, backend, mode, Val (B))
7379 x_and_tx = BatchDuplicated (x, tx)
@@ -85,6 +91,7 @@ function DI.value_and_pushforward!(
8591 tx:: NTuple ,
8692 contexts:: Vararg{DI.Context,C} ,
8793) where {F,C}
94+ DI. check_prep (f, prep, backend, x, tx, contexts... )
8895 # dy cannot be passed anyway
8996 y, new_ty = DI. value_and_pushforward (f, prep, backend, x, tx, contexts... )
9097 foreach (copyto!, ty, new_ty)
@@ -100,6 +107,7 @@ function DI.pushforward!(
100107 tx:: NTuple ,
101108 contexts:: Vararg{DI.Context,C} ,
102109) where {F,C}
110+ DI. check_prep (f, prep, backend, x, tx, contexts... )
103111 # dy cannot be passed anyway
104112 new_ty = DI. pushforward (f, prep, backend, x, tx, contexts... )
105113 foreach (copyto!, ty, new_ty)
@@ -108,32 +116,33 @@ end
108116
109117# # Gradient
110118
111- struct EnzymeForwardGradientPrep{B,O} <: DI.GradientPrep
119+ struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG}
120+ _sig:: Val{SIG}
121+ _valB:: Val{B}
112122 shadows:: O
113123end
114124
115- function EnzymeForwardGradientPrep (:: Val{B} , shadows:: O ) where {B,O}
116- return EnzymeForwardGradientPrep {B,O} (shadows)
117- end
118-
119125function DI. prepare_gradient (
126+ strict:: Val ,
120127 f:: F ,
121128 backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
122129 x,
123- contexts:: Vararg{DI.Constant,C} ,
130+ contexts:: Vararg{DI.Constant,C} ;
124131) where {F,C}
132+ _sig = DI. signature (f, backend, x, contexts... ; strict)
125133 valB = to_val (DI. pick_batchsize (backend, x))
126134 shadows = create_shadows (valB, x)
127- return EnzymeForwardGradientPrep (valB, shadows)
135+ return EnzymeForwardGradientPrep (_sig, valB, shadows)
128136end
129137
130138function DI. gradient (
131139 f:: F ,
132- prep:: EnzymeForwardGradientPrep{B} ,
140+ prep:: EnzymeForwardGradientPrep{SIG, B} ,
133141 backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
134142 x,
135143 contexts:: Vararg{DI.Constant,C} ,
136- ) where {F,B,C}
144+ ) where {F,SIG,B,C}
145+ DI. check_prep (f, prep, backend, x, contexts... )
137146 mode = forward_noprimal (backend)
138147 f_and_df = get_f_and_df (f, backend, mode)
139148 annotated_contexts = translate (backend, mode, Val (B), contexts... )
@@ -145,11 +154,12 @@ end
145154
146155function DI. value_and_gradient (
147156 f:: F ,
148- prep:: EnzymeForwardGradientPrep{B} ,
157+ prep:: EnzymeForwardGradientPrep{SIG, B} ,
149158 backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
150159 x,
151160 contexts:: Vararg{DI.Constant,C} ,
152- ) where {F,B,C}
161+ ) where {F,SIG,B,C}
162+ DI. check_prep (f, prep, backend, x, contexts... )
153163 mode = forward_withprimal (backend)
154164 f_and_df = get_f_and_df (f, backend, mode)
155165 annotated_contexts = translate (backend, mode, Val (B), contexts... )
@@ -162,58 +172,59 @@ end
162172function DI. gradient! (
163173 f:: F ,
164174 grad,
165- prep:: EnzymeForwardGradientPrep{B} ,
175+ prep:: EnzymeForwardGradientPrep{SIG, B} ,
166176 backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
167177 x,
168178 contexts:: Vararg{DI.Constant,C} ,
169- ) where {F,B,C}
179+ ) where {F,SIG,B,C}
180+ DI. check_prep (f, prep, backend, x, contexts... )
170181 return copyto! (grad, DI. gradient (f, prep, backend, x, contexts... ))
171182end
172183
173184function DI. value_and_gradient! (
174185 f:: F ,
175186 grad,
176- prep:: EnzymeForwardGradientPrep{B} ,
187+ prep:: EnzymeForwardGradientPrep{SIG, B} ,
177188 backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
178189 x,
179190 contexts:: Vararg{DI.Constant,C} ,
180- ) where {F,B,C}
191+ ) where {F,SIG,B,C}
192+ DI. check_prep (f, prep, backend, x, contexts... )
181193 y, new_grad = DI. value_and_gradient (f, prep, backend, x, contexts... )
182194 return y, copyto! (grad, new_grad)
183195end
184196
185197# # Jacobian
186198
187- struct EnzymeForwardOneArgJacobianPrep{B,O} <: DI.JacobianPrep
199+ struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG}
200+ _sig:: Val{SIG}
201+ _valB:: Val{B}
188202 shadows:: O
189203 output_length:: Int
190204end
191205
192- function EnzymeForwardOneArgJacobianPrep (
193- :: Val{B} , shadows:: O , output_length:: Integer
194- ) where {B,O}
195- return EnzymeForwardOneArgJacobianPrep {B,O} (shadows, output_length)
196- end
197-
198206function DI. prepare_jacobian (
207+ strict:: Val ,
199208 f:: F ,
200209 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
201210 x,
202- contexts:: Vararg{DI.Constant,C} ,
211+ contexts:: Vararg{DI.Constant,C} ;
203212) where {F,C}
213+ _sig = DI. signature (f, backend, x, contexts... ; strict)
204214 y = f (x, map (DI. unwrap, contexts)... )
205215 valB = to_val (DI. pick_batchsize (backend, x))
206216 shadows = create_shadows (valB, x)
207- return EnzymeForwardOneArgJacobianPrep (valB, shadows, length (y))
217+ return EnzymeForwardOneArgJacobianPrep (_sig, valB, shadows, length (y))
208218end
209219
210220function DI. jacobian (
211221 f:: F ,
212- prep:: EnzymeForwardOneArgJacobianPrep{B} ,
222+ prep:: EnzymeForwardOneArgJacobianPrep{SIG, B} ,
213223 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
214224 x,
215225 contexts:: Vararg{DI.Constant,C} ,
216- ) where {F,B,C}
226+ ) where {F,SIG,B,C}
227+ DI. check_prep (f, prep, backend, x, contexts... )
217228 mode = forward_noprimal (backend)
218229 f_and_df = get_f_and_df (f, backend, mode)
219230 annotated_contexts = translate (backend, mode, Val (B), contexts... )
@@ -226,11 +237,12 @@ end
226237
227238function DI. value_and_jacobian (
228239 f:: F ,
229- prep:: EnzymeForwardOneArgJacobianPrep{B} ,
240+ prep:: EnzymeForwardOneArgJacobianPrep{SIG, B} ,
230241 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
231242 x,
232243 contexts:: Vararg{DI.Constant,C} ,
233- ) where {F,B,C}
244+ ) where {F,SIG,B,C}
245+ DI. check_prep (f, prep, backend, x, contexts... )
234246 mode = forward_withprimal (backend)
235247 f_and_df = get_f_and_df (f, backend, mode)
236248 annotated_contexts = translate (backend, mode, Val (B), contexts... )
@@ -249,6 +261,7 @@ function DI.jacobian!(
249261 x,
250262 contexts:: Vararg{DI.Constant,C} ,
251263) where {F,C}
264+ DI. check_prep (f, prep, backend, x, contexts... )
252265 return copyto! (jac, DI. jacobian (f, prep, backend, x, contexts... ))
253266end
254267
@@ -260,6 +273,7 @@ function DI.value_and_jacobian!(
260273 x,
261274 contexts:: Vararg{DI.Constant,C} ,
262275) where {F,C}
276+ DI. check_prep (f, prep, backend, x, contexts... )
263277 y, new_jac = DI. value_and_jacobian (f, prep, backend, x, contexts... )
264278 return y, copyto! (jac, new_jac)
265279end
0 commit comments