Skip to content

Commit 138146b

Browse files
fix: fix sparse cost hessian
1 parent 9b564c1 commit 138146b

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

src/systems/codegen.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -667,10 +667,9 @@ function calculate_cost_hessian(sys::System; sparse = false, simplify = false)
667667
obj = cost(sys)
668668
dvs = unknowns(sys)
669669
if sparse
670-
exprs = Symbolics.sparsehessian(obj, dvs; simplify)::AbstractSparseArray
671-
sparsity = similar(exprs, Float64)
670+
return Symbolics.sparsehessian(obj, dvs; simplify)::AbstractSparseArray
672671
else
673-
exprs = Symbolics.hessian(obj, dvs; simplify)
672+
return Symbolics.hessian(obj, dvs; simplify)
674673
end
675674
end
676675

test/optimizationsystem.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,21 @@ end
390390
obj = myeigvals_1(m)
391391
@test_nowarn OptimizationSystem(obj, p_free, []; name = :osys)
392392
end
393+
394+
@testset "Test sparse hessian" begin
395+
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
396+
@variables x[1:2]
397+
@named sys = OptimizationSystem(rosenbrock(x))
398+
sys = complete(sys)
399+
prob = OptimizationProblem(sys, [x => [42.0, 12.37]]; hess = true, sparse = true)
400+
401+
symbolic_hess = Symbolics.hessian(cost(sys), x)
402+
symbolic_hess_value = Symbolics.fast_substitute(symbolic_hess, Dict(x[1] => prob[x[1]], x[2] => prob[x[2]]))
403+
404+
oop_hess = prob.f.hess(prob.u0, prob.p)
405+
@test oop_hess symbolic_hess_value
406+
407+
iip_hess = similar(prob.f.hess_prototype)
408+
prob.f.hess(iip_hess, prob.u0, prob.p)
409+
@test iip_hess symbolic_hess_value
410+
end

0 commit comments

Comments
 (0)