Skip to content

Commit e5e55c1

Browse files
fix hessian calculations
1 parent 997f973 commit e5e55c1

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

src/systems/optimization/optimizationsystem.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,12 @@ end
6363

6464
function generate_hessian(sys::OptimizationSystem, vs = states(sys), ps = parameters(sys);
6565
sparse = false, kwargs...)
66-
hes = calculate_hessian(sys)
6766
if sparse
68-
hes = sparse(hes)
67+
hess = sparsehessian(equations(sys),[dv() for dv in states(sys)])
68+
else
69+
hess = calculate_hessian(sys)
6970
end
70-
return build_function(hes, convert.(Variable,vs), convert.(Variable,ps);
71+
return build_function(hess, convert.(Variable,vs), convert.(Variable,ps);
7172
conv = AbstractSysToExpr(sys),kwargs...)
7273
end
7374

@@ -89,6 +90,10 @@ struct AutoModelingToolkit <: DiffEqBase.AbstractADType end
8990
DiffEqBase.OptimizationProblem(sys::OptimizationSystem,args...;kwargs...) =
9091
DiffEqBase.OptimizationProblem{true}(sys::OptimizationSystem,args...;kwargs...)
9192

93+
OptimizationProblemExpr(sys::OptimizationSystem,args...;kwargs...) =
94+
OptimizationProblemExpr{true}(sys::OptimizationSystem,args...;kwargs...)
95+
96+
9297
"""
9398
```julia
9499
function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem,
@@ -166,8 +171,8 @@ struct OptimizationProblemExpr{iip} end
166171
function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0,
167172
parammap=DiffEqBase.NullParameters();
168173
lb=nothing, ub=nothing,
169-
grad = true,
170-
hes = false, sparse = false,
174+
grad = false,
175+
hess = false, sparse = false,
171176
checkbounds = false,
172177
linenumbers = false, parallel=SerialForm(),
173178
kwargs...) where iip

test/modelingtoolkitize.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,14 @@ x0 = zeros(2)
4848
p = [1.0,100.0]
4949

5050
prob = OptimizationProblem(rosenbrock,x0,p)
51-
sys = modelingtoolkitize(prob)
52-
x0map = states(sys) .=> x0
53-
parammap = parameters(sys) .=> p
51+
sys = modelingtoolkitize(prob) # symbolicitize me captain!
5452

55-
prob = OptimizationProblem(sys,x0map,parammap,grad=true)
53+
prob = OptimizationProblem(sys,x0,p,grad=true,hess=true)
5654
sol = solve(prob,NelderMead())
5755
@test sol.minimum < 1e-8
5856

5957
sol = solve(prob,BFGS())
6058
@test sol.minimum < 1e-8
59+
60+
sol = solve(prob,Newton())
61+
@test sol.minimum < 1e-8

0 commit comments

Comments
 (0)