@@ -117,7 +117,7 @@ function _prepare_hvp_aux(
117117 rewrap = Rewrap (contexts... )
118118 # Outer pushforward
119119 new_contexts = (
120- FunctionContext (f), BackendContext (inner (backend)), Constant (rewrap), contexts...
120+ FunctionContext (f), Constant (inner (backend)), Constant (rewrap), contexts...
121121 )
122122 outer_pushforward_prep = prepare_pushforward_nokwarg (
123123 strict, shuffled_gradient, outer (backend), x, tx, new_contexts...
@@ -161,15 +161,15 @@ function _prepare_hvp_aux(
161161 # Outer pushforward
162162 new_contexts = (
163163 FunctionContext (f),
164- PrepContext (inner_gradient_prep),
165- BackendContext (inner (backend)),
164+ ConstantOrCache (inner_gradient_prep),
165+ Constant (inner (backend)),
166166 Constant (rewrap),
167167 contexts... ,
168168 )
169169 new_contexts_in = (
170170 FunctionContext (f),
171- PrepContext (inner_gradient_in_prep),
172- BackendContext (inner (backend)),
171+ ConstantOrCache (inner_gradient_in_prep),
172+ Constant (inner (backend)),
173173 Constant (rewrap),
174174 contexts... ,
175175 )
@@ -228,15 +228,15 @@ function _prepare_hvp_aux(
228228 # Outer pushforward
229229 new_contexts = (
230230 FunctionContext (f),
231- PrepContext (inner_gradient_prep),
232- BackendContext (inner (backend)),
231+ ConstantOrCache (inner_gradient_prep),
232+ Constant (inner (backend)),
233233 Constant (rewrap),
234234 contexts... ,
235235 )
236236 new_contexts_in = (
237237 FunctionContext (f),
238- PrepContext (inner_gradient_in_prep),
239- BackendContext (inner (backend)),
238+ ConstantOrCache (inner_gradient_in_prep),
239+ Constant (inner (backend)),
240240 Constant (rewrap),
241241 contexts... ,
242242 )
@@ -279,8 +279,8 @@ function hvp(
279279 rewrap = Rewrap (contexts... )
280280 new_contexts = (
281281 FunctionContext (f),
282- map (PrepContext , maybe_inner_gradient_prep)... ,
283- BackendContext (inner (backend)),
282+ map (ConstantOrCache , maybe_inner_gradient_prep)... ,
283+ Constant (inner (backend)),
284284 Constant (rewrap),
285285 contexts... ,
286286 )
@@ -318,8 +318,8 @@ function _hvp_aux!(
318318 rewrap = Rewrap (contexts... )
319319 new_contexts = (
320320 FunctionContext (f),
321- map (PrepContext , maybe_inner_gradient_in_prep)... ,
322- BackendContext (inner (backend)),
321+ map (ConstantOrCache , maybe_inner_gradient_in_prep)... ,
322+ Constant (inner (backend)),
323323 Constant (rewrap),
324324 contexts... ,
325325 )
@@ -349,8 +349,8 @@ function _hvp_aux!(
349349 rewrap = Rewrap (contexts... )
350350 new_contexts = (
351351 FunctionContext (f),
352- map (PrepContext , maybe_inner_gradient_prep)... ,
353- BackendContext (inner (backend)),
352+ map (ConstantOrCache , maybe_inner_gradient_prep)... ,
353+ Constant (inner (backend)),
354354 Constant (rewrap),
355355 contexts... ,
356356 )
@@ -378,8 +378,8 @@ function gradient_and_hvp(
378378 rewrap = Rewrap (contexts... )
379379 new_contexts = (
380380 FunctionContext (f),
381- map (PrepContext , maybe_inner_gradient_prep)... ,
382- BackendContext (inner (backend)),
381+ map (ConstantOrCache , maybe_inner_gradient_prep)... ,
382+ Constant (inner (backend)),
383383 Constant (rewrap),
384384 contexts... ,
385385 )
@@ -419,8 +419,8 @@ function _gradient_and_hvp_aux!(
419419 rewrap = Rewrap (contexts... )
420420 new_contexts = (
421421 FunctionContext (f),
422- map (PrepContext , maybe_inner_gradient_in_prep)... ,
423- BackendContext (inner (backend)),
422+ map (ConstantOrCache , maybe_inner_gradient_in_prep)... ,
423+ Constant (inner (backend)),
424424 Constant (rewrap),
425425 contexts... ,
426426 )
@@ -452,8 +452,8 @@ function _gradient_and_hvp_aux!(
452452 rewrap = Rewrap (contexts... )
453453 new_contexts = (
454454 FunctionContext (f),
455- map (PrepContext , maybe_inner_gradient_prep)... ,
456- BackendContext (inner (backend)),
455+ map (ConstantOrCache , maybe_inner_gradient_prep)... ,
456+ Constant (inner (backend)),
457457 Constant (rewrap),
458458 contexts... ,
459459 )
@@ -492,7 +492,7 @@ function _prepare_hvp_aux(
492492 rewrap = Rewrap (contexts... )
493493 new_contexts = (
494494 FunctionContext (f),
495- BackendContext (inner (backend)),
495+ Constant (inner (backend)),
496496 Constant (first (tx)),
497497 Constant (rewrap),
498498 contexts... ,
@@ -522,7 +522,7 @@ function hvp(
522522 outer (backend),
523523 x,
524524 FunctionContext (f),
525- BackendContext (inner (backend)),
525+ Constant (inner (backend)),
526526 Constant (dx),
527527 Constant (rewrap),
528528 contexts... ,
@@ -551,7 +551,7 @@ function hvp!(
551551 outer (backend),
552552 x,
553553 FunctionContext (f),
554- BackendContext (inner (backend)),
554+ Constant (inner (backend)),
555555 Constant (tx[b]),
556556 Constant (rewrap),
557557 contexts... ,
@@ -613,7 +613,7 @@ function _prepare_hvp_aux(
613613 _sig = signature (f, backend, x, tx, contexts... ; strict)
614614 rewrap = Rewrap (contexts... )
615615 new_contexts = (
616- FunctionContext (f), BackendContext (inner (backend)), Constant (rewrap), contexts...
616+ FunctionContext (f), Constant (inner (backend)), Constant (rewrap), contexts...
617617 )
618618 grad_buffer = similar (x)
619619 outer_pullback_prep = prepare_pullback_nokwarg (
@@ -649,7 +649,7 @@ function hvp(
649649 (; outer_pullback_prep) = prep
650650 rewrap = Rewrap (contexts... )
651651 new_contexts = (
652- FunctionContext (f), BackendContext (inner (backend)), Constant (rewrap), contexts...
652+ FunctionContext (f), Constant (inner (backend)), Constant (rewrap), contexts...
653653 )
654654 return pullback (
655655 shuffled_gradient, outer_pullback_prep, outer (backend), x, tx, new_contexts...
@@ -684,7 +684,7 @@ function _hvp_aux!(
684684 (; grad_buffer, outer_pullback_in_prep) = prep
685685 rewrap = Rewrap (contexts... )
686686 new_contexts = (
687- FunctionContext (f), BackendContext (inner (backend)), Constant (rewrap), contexts...
687+ FunctionContext (f), Constant (inner (backend)), Constant (rewrap), contexts...
688688 )
689689 return pullback! (
690690 shuffled_gradient!,
@@ -711,7 +711,7 @@ function _hvp_aux!(
711711 (; outer_pullback_prep) = prep
712712 rewrap = Rewrap (contexts... )
713713 new_contexts = (
714- FunctionContext (f), BackendContext (inner (backend)), Constant (rewrap), contexts...
714+ FunctionContext (f), Constant (inner (backend)), Constant (rewrap), contexts...
715715 )
716716 return pullback! (
717717 shuffled_gradient, tg, outer_pullback_prep, outer (backend), x, tx, new_contexts...
@@ -730,7 +730,7 @@ function gradient_and_hvp(
730730 (; outer_pullback_prep) = prep
731731 rewrap = Rewrap (contexts... )
732732 new_contexts = (
733- FunctionContext (f), BackendContext (inner (backend)), Constant (rewrap), contexts...
733+ FunctionContext (f), Constant (inner (backend)), Constant (rewrap), contexts...
734734 )
735735 return value_and_pullback (
736736 shuffled_gradient, outer_pullback_prep, outer (backend), x, tx, new_contexts...
@@ -767,7 +767,7 @@ function _gradient_and_hvp_aux!(
767767 (; outer_pullback_in_prep) = prep
768768 rewrap = Rewrap (contexts... )
769769 new_contexts = (
770- FunctionContext (f), BackendContext (inner (backend)), Constant (rewrap), contexts...
770+ FunctionContext (f), Constant (inner (backend)), Constant (rewrap), contexts...
771771 )
772772 new_grad, _ = value_and_pullback! (
773773 shuffled_gradient!,
@@ -796,7 +796,7 @@ function _gradient_and_hvp_aux!(
796796 (; outer_pullback_prep) = prep
797797 rewrap = Rewrap (contexts... )
798798 new_contexts = (
799- FunctionContext (f), BackendContext (inner (backend)), Constant (rewrap), contexts...
799+ FunctionContext (f), Constant (inner (backend)), Constant (rewrap), contexts...
800800 )
801801 new_grad, _ = value_and_pullback! (
802802 shuffled_gradient, tg, outer_pullback_prep, outer (backend), x, tx, new_contexts...
0 commit comments