Skip to content

Commit 41d6fab

Browse files
fix: support AutoSymbolics and AutoSymbolics in OptimizationDIExt
1 parent d3c4836 commit 41d6fab

File tree

2 files changed

+40
-7
lines changed

2 files changed

+40
-7
lines changed

lib/OptimizationBase/src/OptimizationDIExt.jl

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,19 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
1414
hvp, jacobian, Constant
1515
using ADTypes, SciMLBase
1616

17+
function instantiate_function(
18+
f::OptimizationFunction{true}, x, ::ADTypes.AutoSparse{<:ADTypes.AutoSymbolics},
19+
args...; kwargs...)
20+
instantiate_function(f, x, ADTypes.AutoSymbolics(), args...; kwargs...)
21+
end
22+
function instantiate_function(
23+
f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
24+
::ADTypes.AutoSparse{<:ADTypes.AutoSymbolics}, args...; kwargs...)
25+
x = cache.u0
26+
p = cache.p
27+
28+
return instantiate_function(f, x, ADTypes.AutoSymbolics(), p, args...; kwargs...)
29+
end
1730
function instantiate_function(
1831
f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType,
1932
p = SciMLBase.NullParameters(), num_cons = 0;
@@ -180,8 +193,16 @@ function instantiate_function(
180193

181194
# Prepare constraint Hessian preparations if needed by lag_h or cons_h
182195
if f.cons !== nothing && f.cons_h === nothing && (cons_h == true || lag_h == true)
183-
prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i))
184-
for i in 1:num_cons]
196+
# This is necessary because DI will create a symbolic index for `Constant(i)`
197+
# to trace into the function, since it assumes `Constant` can change between
198+
# DI calls.
199+
if adtype isa ADTypes.AutoSymbolics
200+
prep_cons_hess = [prepare_hessian(Base.Fix2(cons_oop, i), soadtype, x)
201+
for i in 1:num_cons]
202+
else
203+
prep_cons_hess = [prepare_hessian(cons_oop, soadtype, x, Constant(i))
204+
for i in 1:num_cons]
205+
end
185206
else
186207
prep_cons_hess = nothing
187208
end
@@ -190,9 +211,17 @@ function instantiate_function(
190211
if f.cons !== nothing && f.cons_h === nothing && prep_cons_hess !== nothing
191212
# Standard cons_h! that returns array of matrices
192213
if cons_h == true
193-
cons_h! = function (H, θ)
194-
for i in 1:num_cons
195-
hessian!(cons_oop, H[i], prep_cons_hess[i], soadtype, θ, Constant(i))
214+
if adtype isa ADTypes.AutoSymbolics
215+
cons_h! = function (H, θ)
216+
for i in 1:num_cons
217+
hessian!(Base.Fix2(cons_oop, i), H[i], prep_cons_hess[i], soadtype, θ)
218+
end
219+
end
220+
else
221+
cons_h! = function (H, θ)
222+
for i in 1:num_cons
223+
hessian!(cons_oop, H[i], prep_cons_hess[i], soadtype, θ, Constant(i))
224+
end
196225
end
197226
end
198227
else

lib/OptimizationBase/src/adtypes.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,9 @@ Hessian is not defined via Zygote.
182182
AutoZygote
183183

184184
function generate_adtype(adtype)
185-
if !(adtype isa SciMLBase.NoAD || adtype isa DifferentiationInterface.SecondOrder ||
185+
if adtype isa AutoSymbolics || adtype isa AutoSparse{<:AutoSymbolics}
186+
soadtype = adtype
187+
elseif !(adtype isa SciMLBase.NoAD || adtype isa DifferentiationInterface.SecondOrder ||
186188
adtype isa AutoZygote)
187189
soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)
188190
elseif adtype isa AutoZygote
@@ -233,7 +235,9 @@ function filled_spad(adtype)
233235
end
234236

235237
function generate_sparse_adtype(adtype)
236-
if !(adtype.dense_ad isa DifferentiationInterface.SecondOrder)
238+
if adtype isa AutoSparse{<:AutoSymbolics}
239+
soadtype = adtype
240+
elseif !(adtype.dense_ad isa DifferentiationInterface.SecondOrder)
237241
adtype = filled_spad(adtype)
238242
soadtype = spadtype_to_spsoadtype(adtype)
239243
else

0 commit comments

Comments
 (0)