Skip to content

Commit 3106f0a

Browse files
Pass hessian prototype to OptimizationFunction (#1606)
Co-authored-by: Christopher Rackauckas <[email protected]>
1 parent 57564bf commit 3106f0a

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

src/systems/optimization/optimizationsystem.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,10 +184,17 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
184184
_hess = nothing
185185
end
186186

187+
if sparse
188+
hess_prototype = hessian_sparsity(sys)
189+
else
190+
hess_prototype = nothing
191+
end
192+
187193
_f = DiffEqBase.OptimizationFunction{iip}(f,
188194
SciMLBase.NoAD();
189195
grad = _grad,
190-
hess = _hess)
196+
hess = _hess,
197+
hess_prototype = hess_prototype)
191198

192199
defs = defaults(sys)
193200
defs = mergedefaults(defs, parammap, ps)
@@ -251,6 +258,12 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0,
251258
_hess = :nothing
252259
end
253260

261+
if sparse
262+
hess_prototype = hessian_sparsity(sys)
263+
else
264+
hess_prototype = nothing
265+
end
266+
254267
defs = defaults(sys)
255268
defs = mergedefaults(defs, parammap, ps)
256269
defs = mergedefaults(defs, u0map, dvs)
@@ -269,7 +282,8 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0,
269282
ub = $ub
270283
_f = OptimizationFunction{iip}(f, SciMLBase.NoAD();
271284
grad = grad,
272-
hess = hess)
285+
hess = hess,
286+
hess_prototype = hess_prototype)
273287
OptimizationProblem{$iip}(_f, u0, p; lb = lb, ub = ub, kwargs...)
274288
end
275289
end

test/optimizationsystem.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ calculate_hessian(combinedsys)
2121
generate_function(combinedsys)
2222
generate_gradient(combinedsys)
2323
generate_hessian(combinedsys)
24-
ModelingToolkit.hessian_sparsity(combinedsys)
24+
hess_sparsity = ModelingToolkit.hessian_sparsity(sys1)
25+
sparse_prob = OptimizationProblem(sys1, [x, y], [a, b], grad = true, sparse = true)
26+
@test sparse_prob.f.hess_prototype.rowval == hess_sparsity.rowval
27+
@test sparse_prob.f.hess_prototype.colptr == hess_sparsity.colptr
2528

2629
u0 = [sys1.x => 1.0
2730
sys1.y => 2.0

0 commit comments

Comments
 (0)