Skip to content

Commit 0db11af

Browse files
Merge pull request #3940 from AayushSabharwal/as/fix-optprob
fix: fix sparse cost hessian
2 parents 9b564c1 + f60334a commit 0db11af

File tree

4 files changed

+30
-13
lines changed

4 files changed

+30
-13
lines changed

src/systems/codegen.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -667,10 +667,9 @@ function calculate_cost_hessian(sys::System; sparse = false, simplify = false)
667667
obj = cost(sys)
668668
dvs = unknowns(sys)
669669
if sparse
670-
exprs = Symbolics.sparsehessian(obj, dvs; simplify)::AbstractSparseArray
671-
sparsity = similar(exprs, Float64)
670+
return Symbolics.sparsehessian(obj, dvs; simplify)::AbstractSparseArray
672671
else
673-
exprs = Symbolics.hessian(obj, dvs; simplify)
672+
return Symbolics.hessian(obj, dvs; simplify)
674673
end
675674
end
676675

test/initializationsystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using SymbolicIndexingInterface, SciMLStructures
55
using SciMLStructures: Tunable
66
using ModelingToolkit: t_nounits as t, D_nounits as D, observed
77
using DynamicQuantities
8+
using DiffEqBase: BrownFullBasicInit
89

910
@parameters g
1011
@variables x(t) y(t) [state_priority = 10] λ(t)

test/odesystem.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -565,33 +565,32 @@ let
565565
eqs = [D(x[1]) ~ x[2]
566566
D(x[2]) ~ -x[1] - 0.5 * x[2] + k
567567
y ~ 0.9 * x[1] + x[2]]
568-
@named sys = System(eqs, t, vcat(x, [y]), [k], defaults = Dict(x .=> 0))
568+
@named sys = System(eqs, t, vcat(x, [y]), [k])
569569
sys = mtkcompile(sys)
570570

571571
u0 = x .=> [0.5, 0]
572572
du0 = D.(x) .=> 0.0
573-
prob = DAEProblem(sys, [du0; u0], (0, 50))
574-
@test prob[x] [0.5, 0.0]
573+
prob = DAEProblem(sys, du0, (0, 50); guesses = u0)
574+
@test prob[x] [0.5, 1.0]
575575
@test prob.du0 [0.0, 0.0]
576576
@test prob.p isa MTKParameters
577577
@test prob.ps[k] 1
578578
sol = solve(prob, IDA())
579579
@test sol[y] 0.9 * sol[x[1]] + sol[x[2]]
580580
@test isapprox(sol[x[1]][end], 1, atol = 1e-3)
581581

582-
prob = DAEProblem(sys, [D(y) => 0, D(x[1]) => 0, D(x[2]) => 0, x[1] => 0.5],
583-
(0, 50))
582+
prob = DAEProblem(sys, [D(y) => 0, D(x[1]) => 0, D(x[2]) => 0], (0, 50); guesses = u0)
584583

585-
@test prob[x] [0.5, 0]
584+
@test prob[x] [0.5, 1]
586585
@test prob.du0 [0, 0]
587586
@test prob.p isa MTKParameters
588587
@test prob.ps[k] 1
589588
sol = solve(prob, IDA())
590589
@test isapprox(sol[x[1]][end], 1, atol = 1e-3)
591590

592-
prob = DAEProblem(sys, [D(y) => 0, D(x[1]) => 0, D(x[2]) => 0, x[1] => 0.5, k => 2],
593-
(0, 50))
594-
@test prob[x] [0.5, 0]
591+
prob = DAEProblem(sys, [D(y) => 0, D(x[1]) => 0, D(x[2]) => 0, k => 2],
592+
(0, 50); guesses = u0)
593+
@test prob[x] [0.5, 3]
595594
@test prob.du0 [0, 0]
596595
@test prob.p isa MTKParameters
597596
@test prob.ps[k] 2
@@ -600,7 +599,7 @@ let
600599

601600
# no initial conditions for D(x[1]) and D(x[2]) provided
602601
@test_throws ModelingToolkit.MissingVariablesError prob=DAEProblem(
603-
sys, Pair[], (0, 50))
602+
sys, Pair[], (0, 50); guesses = u0)
604603

605604
prob = ODEProblem(sys, Pair[x[1] => 0], (0, 50))
606605
sol = solve(prob, Rosenbrock23())

test/optimizationsystem.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,21 @@ end
390390
obj = myeigvals_1(m)
391391
@test_nowarn OptimizationSystem(obj, p_free, []; name = :osys)
392392
end
393+
394+
@testset "Test sparse hessian" begin
395+
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
396+
@variables x[1:2]
397+
@named sys = OptimizationSystem(rosenbrock(x))
398+
sys = complete(sys)
399+
prob = OptimizationProblem(sys, [x => [42.0, 12.37]]; hess = true, sparse = true)
400+
401+
symbolic_hess = Symbolics.hessian(cost(sys), x)
402+
symbolic_hess_value = Symbolics.fast_substitute(symbolic_hess, Dict(x[1] => prob[x[1]], x[2] => prob[x[2]]))
403+
404+
oop_hess = prob.f.hess(prob.u0, prob.p)
405+
@test oop_hess symbolic_hess_value
406+
407+
iip_hess = similar(prob.f.hess_prototype)
408+
prob.f.hess(iip_hess, prob.u0, prob.p)
409+
@test iip_hess symbolic_hess_value
410+
end

0 commit comments

Comments
 (0)