@@ -39,9 +39,9 @@ function instantiate_function(
3939 adtype, soadtype = generate_adtype (adtype)
4040
4141 if g == true && f. grad === nothing
42- extras_grad = prepare_gradient (_f, adtype, x)
42+ prep_grad = prepare_gradient (_f, adtype, x)
4343 function grad (res, θ)
44- gradient! (_f, res, adtype, θ, extras_grad )
44+ gradient! (_f, res, prep_grad, adtype, θ )
4545 end
4646 if p != = SciMLBase. NullParameters () && p != = nothing
4747 function grad (res, θ, p)
@@ -57,10 +57,10 @@ function instantiate_function(
5757
5858 if fg == true && f. fg === nothing
5959 if g == false
60- extras_grad = prepare_gradient (_f, adtype, x)
60+ prep_grad = prepare_gradient (_f, adtype, x)
6161 end
6262 function fg! (res, θ)
63- (y, _) = value_and_gradient! (_f, res, adtype, θ, extras_grad )
63+ (y, _) = value_and_gradient! (_f, res, prep_grad, adtype, θ )
6464 return y
6565 end
6666 if p != = SciMLBase. NullParameters () && p != = nothing
@@ -79,9 +79,9 @@ function instantiate_function(
7979 hess_sparsity = f. hess_prototype
8080 hess_colors = f. hess_colorvec
8181 if h == true && f. hess === nothing
82- extras_hess = prepare_hessian (_f, soadtype, x)
82+ prep_hess = prepare_hessian (_f, soadtype, x)
8383 function hess (res, θ)
84- hessian! (_f, res, soadtype, θ, extras_hess )
84+ hessian! (_f, res, prep_hess, soadtype, θ )
8585 end
8686 if p != = SciMLBase. NullParameters () && p != = nothing
8787 function hess (res, θ, p)
@@ -98,7 +98,7 @@ function instantiate_function(
9898 if fgh == true && f. fgh === nothing
9999 function fgh! (G, H, θ)
100100 (y, _, _) = value_derivative_and_second_derivative! (
101- _f, G, H, soadtype, θ, extras_hess )
101+ _f, G, H, prep_hess, soadtype, θ )
102102 return y
103103 end
104104 if p != = SciMLBase. NullParameters () && p != = nothing
@@ -116,14 +116,14 @@ function instantiate_function(
116116 end
117117
118118 if hv == true && f. hv === nothing
119- extras_hvp = prepare_hvp (_f, soadtype, x, zeros (eltype (x), size (x)))
119+ prep_hvp = prepare_hvp (_f, soadtype, x, ( zeros (eltype (x), size (x)), ))
120120 function hv! (H, θ, v)
121- hvp! (_f, H, soadtype, θ, v, extras_hvp )
121+ only ( hvp! (_f, (H,), prep_hvp, soadtype, θ, (v,)) )
122122 end
123123 if p != = SciMLBase. NullParameters () && p != = nothing
124124 function hv! (H, θ, v, p)
125125 global _p = p
126- hvp! (_f, H, soadtype, θ, v )
126+ only ( hvp! (_f, (H,), soadtype, θ, (v,)) )
127127 end
128128 end
129129 elseif hv == true
@@ -156,9 +156,9 @@ function instantiate_function(
156156 cons_jac_prototype = f. cons_jac_prototype
157157 cons_jac_colorvec = f. cons_jac_colorvec
158158 if cons != = nothing && cons_j == true && f. cons_j === nothing
159- extras_jac = prepare_jacobian (cons_oop, adtype, x)
159+ prep_jac = prepare_jacobian (cons_oop, adtype, x)
160160 function cons_j! (J, θ)
161- jacobian! (cons_oop, J, adtype, θ, extras_jac )
161+ jacobian! (cons_oop, J, prep_jac, adtype, θ )
162162 if size (J, 1 ) == 1
163163 J = vec (J)
164164 end
@@ -170,9 +170,9 @@ function instantiate_function(
170170 end
171171
172172 if f. cons_vjp === nothing && cons_vjp == true && cons != = nothing
173- extras_pullback = prepare_pullback (cons_oop, adtype, x, ones (eltype (x), num_cons))
173+ prep_pullback = prepare_pullback (cons_oop, adtype, x, ( ones (eltype (x), num_cons), ))
174174 function cons_vjp! (J, θ, v)
175- pullback! (cons_oop, J, adtype, θ, v, extras_pullback )
175+ only ( pullback! (cons_oop, (J,), prep_pullback, adtype, θ, (v,)) )
176176 end
177177 elseif cons_vjp == true && cons != = nothing
178178 cons_vjp! = (J, θ, v) -> f. cons_vjp (J, θ, v, p)
@@ -181,10 +181,10 @@ function instantiate_function(
181181 end
182182
183183 if f. cons_jvp === nothing && cons_jvp == true && cons != = nothing
184- extras_pushforward = prepare_pushforward (
185- cons_oop, adtype, x, ones (eltype (x), length (x)))
184+ prep_pushforward = prepare_pushforward (
185+ cons_oop, adtype, x, ( ones (eltype (x), length (x)), ))
186186 function cons_jvp! (J, θ, v)
187- pushforward! (cons_oop, J, adtype, θ, v, extras_pushforward )
187+ only ( pushforward! (cons_oop, (J,), prep_pushforward, adtype, θ, (v,)) )
188188 end
189189 elseif cons_jvp == true && cons != = nothing
190190 cons_jvp! = (J, θ, v) -> f. cons_jvp (J, θ, v, p)
@@ -196,11 +196,11 @@ function instantiate_function(
196196 conshess_colors = f. cons_hess_colorvec
197197 if cons != = nothing && f. cons_h === nothing && cons_h == true
198198 fncs = [(x) -> cons_oop (x)[i] for i in 1 : num_cons]
199- extras_cons_hess = prepare_hessian .(fncs, Ref (soadtype), Ref (x))
199+ prep_cons_hess = prepare_hessian .(fncs, Ref (soadtype), Ref (x))
200200
201201 function cons_h! (H, θ)
202202 for i in 1 : num_cons
203- hessian! (fncs[i], H[i], soadtype, θ, extras_cons_hess[i] )
203+ hessian! (fncs[i], H[i], prep_cons_hess[i], soadtype, θ )
204204 end
205205 end
206206 elseif cons_h == true && cons != = nothing
@@ -212,7 +212,7 @@ function instantiate_function(
212212 lag_hess_prototype = f. lag_hess_prototype
213213
214214 if cons != = nothing && lag_h == true && f. lag_h === nothing
215- lag_extras = prepare_hessian (
215+ lag_prep = prepare_hessian (
216216 lagrangian, soadtype, vcat (x, [one (eltype (x))], ones (eltype (x), num_cons)))
217217 lag_hess_prototype = zeros (Bool, length (x) + num_cons + 1 , length (x) + num_cons + 1 )
218218
@@ -221,13 +221,13 @@ function instantiate_function(
221221 cons_h (H, θ)
222222 H *= λ
223223 else
224- H .= @view (hessian (lagrangian, soadtype, vcat (θ, [σ], λ), lag_extras )[
224+ H .= @view (hessian (lagrangian, lag_prep, soadtype, vcat (θ, [σ], λ))[
225225 1 : length (θ), 1 : length (θ)])
226226 end
227227 end
228228
229229 function lag_h! (h:: AbstractVector , θ, σ, λ)
230- H = hessian (lagrangian, soadtype, vcat (θ, [σ], λ), lag_extras )
230+ H = hessian (lagrangian, lag_prep, soadtype, vcat (θ, [σ], λ))
231231 k = 0
232232 for i in 1 : length (θ)
233233 for j in 1 : i
@@ -244,14 +244,14 @@ function instantiate_function(
244244 H *= λ
245245 else
246246 global _p = p
247- H .= @view (hessian (lagrangian, soadtype, vcat (θ, [σ], λ), lag_extras )[
247+ H .= @view (hessian (lagrangian, lag_prep, soadtype, vcat (θ, [σ], λ))[
248248 1 : length (θ), 1 : length (θ)])
249249 end
250250 end
251251
252252 function lag_h! (h:: AbstractVector , θ, σ, λ, p)
253253 global _p = p
254- H = hessian (lagrangian, soadtype, vcat (θ, [σ], λ), lag_extras )
254+ H = hessian (lagrangian, lag_prep, soadtype, vcat (θ, [σ], λ))
255255 k = 0
256256 for i in 1 : length (θ)
257257 for j in 1 : i
@@ -308,9 +308,9 @@ function instantiate_function(
308308 adtype, soadtype = generate_adtype (adtype)
309309
310310 if g == true && f. grad === nothing
311- extras_grad = prepare_gradient (_f, adtype, x)
311+ prep_grad = prepare_gradient (_f, adtype, x)
312312 function grad (θ)
313- gradient (_f, adtype, θ, extras_grad )
313+ gradient (_f, prep_grad, adtype, θ )
314314 end
315315 if p != = SciMLBase. NullParameters () && p != = nothing
316316 function grad (θ, p)
@@ -326,10 +326,10 @@ function instantiate_function(
326326
327327 if fg == true && f. fg === nothing
328328 if g == false
329- extras_grad = prepare_gradient (_f, adtype, x)
329+ prep_grad = prepare_gradient (_f, adtype, x)
330330 end
331331 function fg! (θ)
332- (y, res) = value_and_gradient (_f, adtype, θ, extras_grad )
332+ (y, res) = value_and_gradient (_f, prep_grad, adtype, θ )
333333 return y, res
334334 end
335335 if p != = SciMLBase. NullParameters () && p != = nothing
@@ -348,9 +348,9 @@ function instantiate_function(
348348 hess_sparsity = f. hess_prototype
349349 hess_colors = f. hess_colorvec
350350 if h == true && f. hess === nothing
351- extras_hess = prepare_hessian (_f, soadtype, x)
351+ prep_hess = prepare_hessian (_f, soadtype, x)
352352 function hess (θ)
353- hessian (_f, soadtype, θ, extras_hess )
353+ hessian (_f, prep_hess, soadtype, θ )
354354 end
355355 if p != = SciMLBase. NullParameters () && p != = nothing
356356 function hess (θ, p)
@@ -366,7 +366,7 @@ function instantiate_function(
366366
367367 if fgh == true && f. fgh === nothing
368368 function fgh! (θ)
369- (y, G, H) = value_derivative_and_second_derivative (_f, adtype, θ, extras_hess )
369+ (y, G, H) = value_derivative_and_second_derivative (_f, prep_hess, adtype, θ )
370370 return y, G, H
371371 end
372372 if p != = SciMLBase. NullParameters () && p != = nothing
@@ -383,14 +383,14 @@ function instantiate_function(
383383 end
384384
385385 if hv == true && f. hv === nothing
386- extras_hvp = prepare_hvp (_f, soadtype, x, zeros (eltype (x), size (x)))
386+ prep_hvp = prepare_hvp (_f, soadtype, x, ( zeros (eltype (x), size (x)), ))
387387 function hv! (θ, v)
388- hvp (_f, soadtype, θ, v, extras_hvp )
388+ only ( hvp (_f, prep_hvp, soadtype, θ, (v)) )
389389 end
390390 if p != = SciMLBase. NullParameters () && p != = nothing
391391 function hv! (θ, v, p)
392392 global _p = p
393- hvp ( _f, soadtype, θ, v, extras_hvp )
393+ only ( vp ( _f, prep_hvp, soadtype, θ, (v,)) )
394394 end
395395 end
396396 elseif hv == true
@@ -417,9 +417,9 @@ function instantiate_function(
417417 cons_jac_prototype = f. cons_jac_prototype
418418 cons_jac_colorvec = f. cons_jac_colorvec
419419 if cons != = nothing && cons_j == true && f. cons_j === nothing
420- extras_jac = prepare_jacobian (cons, adtype, x)
420+ prep_jac = prepare_jacobian (cons, adtype, x)
421421 function cons_j! (θ)
422- J = jacobian (cons, adtype, θ, extras_jac )
422+ J = jacobian (cons, prep_jac, adtype, θ )
423423 if size (J, 1 ) == 1
424424 J = vec (J)
425425 end
@@ -432,9 +432,9 @@ function instantiate_function(
432432 end
433433
434434 if f. cons_vjp === nothing && cons_vjp == true && cons != = nothing
435- extras_pullback = prepare_pullback (cons, adtype, x, ones (eltype (x), num_cons))
435+ prep_pullback = prepare_pullback (cons, adtype, x, ( ones (eltype (x), num_cons), ))
436436 function cons_vjp! (θ, v)
437- return pullback (cons, adtype, θ, v, extras_pullback )
437+ return only ( pullback (cons, prep_pullback, adtype, θ, (v,)) )
438438 end
439439 elseif cons_vjp == true && cons != = nothing
440440 cons_vjp! = (θ, v) -> f. cons_vjp (θ, v, p)
@@ -443,10 +443,10 @@ function instantiate_function(
443443 end
444444
445445 if f. cons_jvp === nothing && cons_jvp == true && cons != = nothing
446- extras_pushforward = prepare_pushforward (
447- cons, adtype, x, ones (eltype (x), length (x)))
446+ prep_pushforward = prepare_pushforward (
447+ cons, adtype, x, ( ones (eltype (x), length (x)), ))
448448 function cons_jvp! (θ, v)
449- return pushforward (cons, adtype, θ, v, extras_pushforward )
449+ return only ( pushforward (cons, prep_pushforward, adtype, θ, (v,)) )
450450 end
451451 elseif cons_jvp == true && cons != = nothing
452452 cons_jvp! = (θ, v) -> f. cons_jvp (θ, v, p)
@@ -458,11 +458,11 @@ function instantiate_function(
458458 conshess_colors = f. cons_hess_colorvec
459459 if cons != = nothing && cons_h == true && f. cons_h === nothing
460460 fncs = [(x) -> cons (x)[i] for i in 1 : num_cons]
461- extras_cons_hess = prepare_hessian .(fncs, Ref (soadtype), Ref (x))
461+ prep_cons_hess = prepare_hessian .(fncs, Ref (soadtype), Ref (x))
462462
463463 function cons_h! (θ)
464464 H = map (1 : num_cons) do i
465- hessian (fncs[i], soadtype, θ, extras_cons_hess[i] )
465+ hessian (fncs[i], prep_cons_hess[i], soadtype, θ )
466466 end
467467 return H
468468 end
@@ -475,15 +475,15 @@ function instantiate_function(
475475 lag_hess_prototype = f. lag_hess_prototype
476476
477477 if cons != = nothing && lag_h == true && f. lag_h === nothing
478- lag_extras = prepare_hessian (
478+ lag_prep = prepare_hessian (
479479 lagrangian, soadtype, vcat (x, [one (eltype (x))], ones (eltype (x), num_cons)))
480480 lag_hess_prototype = zeros (Bool, length (x) + num_cons + 1 , length (x) + num_cons + 1 )
481481
482482 function lag_h! (θ, σ, λ)
483483 if σ == zero (eltype (θ))
484484 return λ .* cons_h (θ)
485485 else
486- return hessian (lagrangian, soadtype, vcat (θ, [σ], λ), lag_extras )[
486+ return hessian (lagrangian, lag_prep, soadtype, vcat (θ, [σ], λ))[
487487 1 : length (θ), 1 : length (θ)]
488488 end
489489 end
@@ -494,7 +494,7 @@ function instantiate_function(
494494 return λ .* cons_h (θ)
495495 else
496496 global _p = p
497- return hessian (lagrangian, soadtype, vcat (θ, [σ], λ), lag_extras )[
497+ return hessian (lagrangian, lag_prep, soadtype, vcat (θ, [σ], λ))[
498498 1 : length (θ), 1 : length (θ)]
499499 end
500500 end
0 commit comments