@@ -12,7 +12,7 @@ function DI.prepare_pushforward_nokwarg(
1212 strict:: Val , f, backend:: AutoFiniteDiff , x, tx:: NTuple , contexts:: Vararg{DI.Context,C} ;
1313) where {C}
1414 _sig = DI. signature (f, backend, x, tx, contexts... ; strict)
15- fc = DI. with_contexts (f, contexts... )
15+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
1616 y = fc (x)
1717 cache = if x isa Number || y isa Number
1818 nothing
@@ -89,7 +89,7 @@ function DI.pushforward(
8989) where {SIG,C}
9090 DI. check_prep (f, prep, backend, x, tx, contexts... )
9191 (; relstep, absstep, dir) = prep
92- fc = DI. with_contexts (f, contexts... )
92+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
9393 ty = map (tx) do dx
9494 finite_difference_jvp (fc, x, dx, prep. cache; relstep, absstep, dir)
9595 end
@@ -106,7 +106,7 @@ function DI.value_and_pushforward(
106106) where {SIG,C}
107107 DI. check_prep (f, prep, backend, x, tx, contexts... )
108108 (; relstep, absstep, dir) = prep
109- fc = DI. with_contexts (f, contexts... )
109+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
110110 y = fc (x)
111111 ty = map (tx) do dx
112112 finite_difference_jvp (fc, x, dx, prep. cache, y; relstep, absstep, dir)
@@ -128,7 +128,7 @@ function DI.prepare_derivative_nokwarg(
128128 strict:: Val , f, backend:: AutoFiniteDiff , x, contexts:: Vararg{DI.Context,C}
129129) where {C}
130130 _sig = DI. signature (f, backend, x, contexts... ; strict)
131- fc = DI. with_contexts (f, contexts... )
131+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
132132 y = fc (x)
133133 cache = if y isa Number
134134 nothing
@@ -161,7 +161,7 @@ function DI.derivative(
161161) where {SIG,C}
162162 DI. check_prep (f, prep, backend, x, contexts... )
163163 (; relstep, absstep, dir) = prep
164- fc = DI. with_contexts (f, contexts... )
164+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
165165 return finite_difference_derivative (fc, x, fdtype (backend); relstep, absstep, dir)
166166end
167167
@@ -174,7 +174,7 @@ function DI.value_and_derivative(
174174) where {SIG,C}
175175 DI. check_prep (f, prep, backend, x, contexts... )
176176 (; relstep, absstep, dir) = prep
177- fc = DI. with_contexts (f, contexts... )
177+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
178178 y = fc (x)
179179 return (
180180 y,
@@ -195,7 +195,7 @@ function DI.derivative(
195195) where {SIG,C}
196196 DI. check_prep (f, prep, backend, x, contexts... )
197197 (; relstep, absstep, dir) = prep
198- fc = DI. with_contexts (f, contexts... )
198+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
199199 return finite_difference_gradient (fc, x, prep. cache; relstep, absstep, dir)
200200end
201201
@@ -209,7 +209,7 @@ function DI.derivative!(
209209) where {SIG,C}
210210 DI. check_prep (f, prep, backend, x, contexts... )
211211 (; relstep, absstep, dir) = prep
212- fc = DI. with_contexts (f, contexts... )
212+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
213213 return finite_difference_gradient! (der, fc, x, prep. cache; relstep, absstep, dir)
214214end
215215
@@ -221,7 +221,7 @@ function DI.value_and_derivative(
221221 contexts:: Vararg{DI.Context,C} ,
222222) where {SIG,C}
223223 DI. check_prep (f, prep, backend, x, contexts... )
224- fc = DI. with_contexts (f, contexts... )
224+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
225225 (; relstep, absstep, dir) = prep
226226 y = fc (x)
227227 return (y, finite_difference_gradient (fc, x, prep. cache; relstep, absstep, dir))
@@ -237,7 +237,7 @@ function DI.value_and_derivative!(
237237) where {SIG,C}
238238 DI. check_prep (f, prep, backend, x, contexts... )
239239 (; relstep, absstep, dir) = prep
240- fc = DI. with_contexts (f, contexts... )
240+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
241241 return (
242242 fc (x), finite_difference_gradient! (der, fc, x, prep. cache; relstep, absstep, dir)
243243 )
@@ -257,7 +257,7 @@ function DI.prepare_gradient_nokwarg(
257257 strict:: Val , f, backend:: AutoFiniteDiff , x, contexts:: Vararg{DI.Context,C}
258258) where {C}
259259 _sig = DI. signature (f, backend, x, contexts... ; strict)
260- fc = DI. with_contexts (f, contexts... )
260+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
261261 y = fc (x)
262262 df = zero (y) .* x
263263 cache = GradientCache (df, x, fdtype (backend))
@@ -284,7 +284,7 @@ function DI.gradient(
284284) where {C}
285285 DI. check_prep (f, prep, backend, x, contexts... )
286286 (; relstep, absstep, dir) = prep
287- fc = DI. with_contexts (f, contexts... )
287+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
288288 return finite_difference_gradient (fc, x, prep. cache; relstep, absstep, dir)
289289end
290290
@@ -297,7 +297,7 @@ function DI.value_and_gradient(
297297) where {C}
298298 DI. check_prep (f, prep, backend, x, contexts... )
299299 (; relstep, absstep, dir) = prep
300- fc = DI. with_contexts (f, contexts... )
300+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
301301 return fc (x), finite_difference_gradient (fc, x, prep. cache; relstep, absstep, dir)
302302end
303303
@@ -311,7 +311,7 @@ function DI.gradient!(
311311) where {C}
312312 DI. check_prep (f, prep, backend, x, contexts... )
313313 (; relstep, absstep, dir) = prep
314- fc = DI. with_contexts (f, contexts... )
314+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
315315 return finite_difference_gradient! (grad, fc, x, prep. cache; relstep, absstep, dir)
316316end
317317
@@ -325,7 +325,7 @@ function DI.value_and_gradient!(
325325) where {C}
326326 DI. check_prep (f, prep, backend, x, contexts... )
327327 (; relstep, absstep, dir) = prep
328- fc = DI. with_contexts (f, contexts... )
328+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
329329 return (
330330 fc (x), finite_difference_gradient! (grad, fc, x, prep. cache; relstep, absstep, dir)
331331 )
@@ -345,7 +345,7 @@ function DI.prepare_jacobian_nokwarg(
345345 strict:: Val , f, backend:: AutoFiniteDiff , x, contexts:: Vararg{DI.Context,C}
346346) where {C}
347347 _sig = DI. signature (f, backend, x, contexts... ; strict)
348- fc = DI. with_contexts (f, contexts... )
348+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
349349 y = fc (x)
350350 x1 = similar (x)
351351 fx = similar (y)
@@ -374,7 +374,7 @@ function DI.jacobian(
374374) where {C}
375375 DI. check_prep (f, prep, backend, x, contexts... )
376376 (; relstep, absstep, dir) = prep
377- fc = DI. with_contexts (f, contexts... )
377+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
378378 return finite_difference_jacobian (fc, x, prep. cache; relstep, absstep, dir)
379379end
380380
@@ -386,7 +386,7 @@ function DI.value_and_jacobian(
386386 contexts:: Vararg{DI.Context,C} ,
387387) where {C}
388388 DI. check_prep (f, prep, backend, x, contexts... )
389- fc = DI. with_contexts (f, contexts... )
389+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
390390 (; relstep, absstep, dir) = prep
391391 y = fc (x)
392392 return (y, finite_difference_jacobian (fc, x, prep. cache, y; relstep, absstep, dir))
@@ -402,7 +402,7 @@ function DI.jacobian!(
402402) where {C}
403403 DI. check_prep (f, prep, backend, x, contexts... )
404404 (; relstep, absstep, dir) = prep
405- fc = DI. with_contexts (f, contexts... )
405+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
406406 return copyto! (
407407 jac,
408408 finite_difference_jacobian (
@@ -421,7 +421,7 @@ function DI.value_and_jacobian!(
421421) where {C}
422422 DI. check_prep (f, prep, backend, x, contexts... )
423423 (; relstep, absstep, dir) = prep
424- fc = DI. with_contexts (f, contexts... )
424+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
425425 y = fc (x)
426426 return (
427427 y,
@@ -450,7 +450,7 @@ function DI.prepare_hessian_nokwarg(
450450 strict:: Val , f, backend:: AutoFiniteDiff , x, contexts:: Vararg{DI.Context,C}
451451) where {C}
452452 _sig = DI. signature (f, backend, x, contexts... ; strict)
453- fc = DI. with_contexts (f, contexts... )
453+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
454454 y = fc (x)
455455 df = zero (y) .* x
456456 gradient_cache = GradientCache (df, x, fdtype (backend))
@@ -481,7 +481,7 @@ function DI.hessian(
481481) where {C}
482482 DI. check_prep (f, prep, backend, x, contexts... )
483483 (; relstep_h, absstep_h) = prep
484- fc = DI. with_contexts (f, contexts... )
484+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
485485 return finite_difference_hessian (
486486 fc, x, prep. hessian_cache; relstep= relstep_h, absstep= absstep_h
487487 )
@@ -497,7 +497,7 @@ function DI.hessian!(
497497) where {C}
498498 DI. check_prep (f, prep, backend, x, contexts... )
499499 (; relstep_h, absstep_h) = prep
500- fc = DI. with_contexts (f, contexts... )
500+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
501501 return finite_difference_hessian! (
502502 hess, fc, x, prep. hessian_cache; relstep= relstep_h, absstep= absstep_h
503503 )
@@ -512,7 +512,7 @@ function DI.value_gradient_and_hessian(
512512) where {C}
513513 DI. check_prep (f, prep, backend, x, contexts... )
514514 (; relstep_g, absstep_g, relstep_h, absstep_h) = prep
515- fc = DI. with_contexts (f, contexts... )
515+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
516516 grad = finite_difference_gradient (
517517 fc, x, prep. gradient_cache; relstep= relstep_g, absstep= absstep_g
518518 )
@@ -533,7 +533,7 @@ function DI.value_gradient_and_hessian!(
533533) where {C}
534534 DI. check_prep (f, prep, backend, x, contexts... )
535535 (; relstep_g, absstep_g, relstep_h, absstep_h) = prep
536- fc = DI. with_contexts (f, contexts... )
536+ fc = DI. fix_tail (f, map (DI . unwrap, contexts) ... )
537537 finite_difference_gradient! (
538538 grad, fc, x, prep. gradient_cache; relstep= relstep_g, absstep= absstep_g
539539 )
0 commit comments