@@ -14,6 +14,19 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
1414 hvp, jacobian, Constant
1515using ADTypes, SciMLBase
1616
17+ function instantiate_function (
18+ f:: OptimizationFunction{true} , x, :: ADTypes.AutoSparse{<:ADTypes.AutoSymbolics} ,
19+ args... ; kwargs... )
20+ instantiate_function (f, x, ADTypes. AutoSymbolics (), args... ; kwargs... )
21+ end
22+ function instantiate_function (
23+ f:: OptimizationFunction{true} , cache:: OptimizationBase.ReInitCache ,
24+ :: ADTypes.AutoSparse{<:ADTypes.AutoSymbolics} , args... ; kwargs... )
25+ x = cache. u0
26+ p = cache. p
27+
28+ return instantiate_function (f, x, ADTypes. AutoSymbolics (), p, args... ; kwargs... )
29+ end
1730function instantiate_function (
1831 f:: OptimizationFunction{true} , x, adtype:: ADTypes.AbstractADType ,
1932 p = SciMLBase. NullParameters (), num_cons = 0 ;
@@ -180,8 +193,16 @@ function instantiate_function(
180193
181194 # Prepare constraint Hessian preparations if needed by lag_h or cons_h
182195 if f. cons != = nothing && f. cons_h === nothing && (cons_h == true || lag_h == true )
183- prep_cons_hess = [prepare_hessian (cons_oop, soadtype, x, Constant (i))
184- for i in 1 : num_cons]
196+ # This is necessary because DI will create a symbolic index for `Constant(i)`
197+ # to trace into the function, since it assumes `Constant` can change between
198+ # DI calls.
199+ if adtype isa ADTypes. AutoSymbolics
200+ prep_cons_hess = [prepare_hessian (Base. Fix2 (cons_oop, i), soadtype, x)
201+ for i in 1 : num_cons]
202+ else
203+ prep_cons_hess = [prepare_hessian (cons_oop, soadtype, x, Constant (i))
204+ for i in 1 : num_cons]
205+ end
185206 else
186207 prep_cons_hess = nothing
187208 end
@@ -190,9 +211,17 @@ function instantiate_function(
190211 if f. cons != = nothing && f. cons_h === nothing && prep_cons_hess != = nothing
191212 # Standard cons_h! that returns array of matrices
192213 if cons_h == true
193- cons_h! = function (H, θ)
194- for i in 1 : num_cons
195- hessian! (cons_oop, H[i], prep_cons_hess[i], soadtype, θ, Constant (i))
214+ if adtype isa ADTypes. AutoSymbolics
215+ cons_h! = function (H, θ)
216+ for i in 1 : num_cons
217+ hessian! (Base. Fix2 (cons_oop, i), H[i], prep_cons_hess[i], soadtype, θ)
218+ end
219+ end
220+ else
221+ cons_h! = function (H, θ)
222+ for i in 1 : num_cons
223+ hessian! (cons_oop, H[i], prep_cons_hess[i], soadtype, θ, Constant (i))
224+ end
196225 end
197226 end
198227 else
0 commit comments