@@ -138,12 +138,14 @@ struct PushforwardJacobianPrep{
138138 BS<: BatchSizeSettings ,
139139 S<: AbstractVector{<:NTuple} ,
140140 R<: AbstractVector{<:NTuple} ,
141+ SE<: NTuple ,
141142 E<: PushforwardPrep ,
142143} <: StandardJacobianPrep{SIG}
143144 _sig:: Val{SIG}
144145 batch_size_settings:: BS
145146 batched_seeds:: S
146147 batched_results:: R
148+ seed_example:: SE
147149 pushforward_prep:: E
148150end
149151
@@ -152,12 +154,14 @@ struct PullbackJacobianPrep{
152154 BS<: BatchSizeSettings ,
153155 S<: AbstractVector{<:NTuple} ,
154156 R<: AbstractVector{<:NTuple} ,
157+ SE<: NTuple ,
155158 E<: PullbackPrep ,
156159} <: StandardJacobianPrep{SIG}
157160 _sig:: Val{SIG}
158161 batch_size_settings:: BS
159162 batched_seeds:: S
160163 batched_results:: R
164+ seed_example:: SE
161165 pullback_prep:: E
162166end
163167
@@ -211,11 +215,17 @@ function _prepare_jacobian_aux(
211215 ntuple (b -> seeds[1 + ((a - 1 ) * B + (b - 1 )) % N], Val (B)) for a in 1 : A
212216 ]
213217 batched_results = [ntuple (b -> similar (y), Val (B)) for _ in batched_seeds]
218+ seed_example = ntuple (b -> zero (x), Val (B))
214219 pushforward_prep = prepare_pushforward_nokwarg (
215- strict, f_or_f!y... , backend, x, batched_seeds[ 1 ] , contexts...
220+ strict, f_or_f!y... , backend, x, seed_example , contexts...
216221 )
217222 return PushforwardJacobianPrep (
218- _sig, batch_size_settings, batched_seeds, batched_results, pushforward_prep
223+ _sig,
224+ batch_size_settings,
225+ batched_seeds,
226+ batched_results,
227+ seed_example,
228+ pushforward_prep,
219229 )
220230end
221231
@@ -236,11 +246,17 @@ function _prepare_jacobian_aux(
236246 ntuple (b -> seeds[1 + ((a - 1 ) * B + (b - 1 )) % N], Val (B)) for a in 1 : A
237247 ]
238248 batched_results = [ntuple (b -> similar (x), Val (B)) for _ in batched_seeds]
249+ seed_example = ntuple (b -> zero (y), Val (B))
239250 pullback_prep = prepare_pullback_nokwarg (
240- strict, f_or_f!y... , backend, x, batched_seeds[ 1 ] , contexts...
251+ strict, f_or_f!y... , backend, x, seed_example , contexts...
241252 )
242253 return PullbackJacobianPrep (
243- _sig, batch_size_settings, batched_seeds, batched_results, pullback_prep
254+ _sig,
255+ batch_size_settings,
256+ batched_seeds,
257+ batched_results,
258+ seed_example,
259+ pullback_prep,
244260 )
245261end
246262
@@ -363,11 +379,11 @@ function _jacobian_aux(
363379 x,
364380 contexts:: Vararg{Context,C} ,
365381) where {FY,SIG,B,aligned,C}
366- (; batch_size_settings, batched_seeds, pushforward_prep) = prep
382+ (; batch_size_settings, batched_seeds, seed_example, pushforward_prep) = prep
367383 (; A, B_last) = batch_size_settings
368384
369385 pushforward_prep_same = prepare_pushforward_same_point (
370- f_or_f!y... , pushforward_prep, backend, x, batched_seeds[ 1 ] , contexts...
386+ f_or_f!y... , pushforward_prep, backend, x, seed_example , contexts...
371387 )
372388
373389 jac = mapreduce (hcat, eachindex (batched_seeds)) do a
@@ -419,11 +435,11 @@ function _jacobian_aux(
419435 x,
420436 contexts:: Vararg{Context,C} ,
421437) where {FY,SIG,B,aligned,C}
422- (; batch_size_settings, batched_seeds, pullback_prep) = prep
438+ (; batch_size_settings, batched_seeds, seed_example, pullback_prep) = prep
423439 (; A, B_last) = batch_size_settings
424440
425441 pullback_prep_same = prepare_pullback_same_point (
426- f_or_f!y... , prep . pullback_prep, backend, x, batched_seeds[ 1 ] , contexts...
442+ f_or_f!y... , pullback_prep, backend, x, seed_example , contexts...
427443 )
428444
429445 jac = mapreduce (vcat, eachindex (batched_seeds)) do a
@@ -451,11 +467,13 @@ function _jacobian_aux!(
451467 x,
452468 contexts:: Vararg{Context,C} ,
453469) where {FY,SIG,B,C}
454- (; batch_size_settings, batched_seeds, batched_results, pushforward_prep) = prep
470+ (;
471+ batch_size_settings, batched_seeds, batched_results, seed_example, pushforward_prep
472+ ) = prep
455473 (; N) = batch_size_settings
456474
457475 pushforward_prep_same = prepare_pushforward_same_point (
458- f_or_f!y... , pushforward_prep, backend, x, batched_seeds[ 1 ] , contexts...
476+ f_or_f!y... , pushforward_prep, backend, x, seed_example , contexts...
459477 )
460478
461479 for a in eachindex (batched_seeds, batched_results)
@@ -487,11 +505,12 @@ function _jacobian_aux!(
487505 x,
488506 contexts:: Vararg{Context,C} ,
489507) where {FY,SIG,B,C}
490- (; batch_size_settings, batched_seeds, batched_results, pullback_prep) = prep
508+ (; batch_size_settings, batched_seeds, batched_results, seed_example, pullback_prep) =
509+ prep
491510 (; N) = batch_size_settings
492511
493512 pullback_prep_same = prepare_pullback_same_point (
494- f_or_f!y... , pullback_prep, backend, x, batched_seeds[ 1 ] , contexts...
513+ f_or_f!y... , pullback_prep, backend, x, seed_example , contexts...
495514 )
496515
497516 for a in eachindex (batched_seeds, batched_results)
0 commit comments