Skip to content

Commit dca0f83

Browse files
Merge pull request #2492 from SciML/optproboop
OptimizationSystem: Generate oop constraint derivatives
2 parents a132efd + e7eddb7 commit dca0f83

File tree

2 files changed

+50
-4
lines changed

2 files changed

+50
-4
lines changed

src/systems/optimization/optimizationsystem.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -364,14 +364,32 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
364364
linenumbers = linenumbers,
365365
expression = Val{false})
366366
if cons_j
367-
_cons_j = generate_jacobian(cons_sys; expression = Val{false},
368-
sparse = cons_sparse)[2]
367+
_cons_j = let (cons_jac_oop, cons_jac_iip) = generate_jacobian(cons_sys;
368+
checkbounds = checkbounds,
369+
linenumbers = linenumbers,
370+
parallel = parallel, expression = Val{false},
371+
sparse = cons_sparse)
372+
_cons_j(u, p) = cons_jac_oop(u, p)
373+
_cons_j(J, u, p) = (cons_jac_iip(J, u, p); J)
374+
_cons_j(u, p::MTKParameters) = cons_jac_oop(u, p...)
375+
_cons_j(J, u, p::MTKParameters) = (cons_jac_iip(J, u, p...); J)
376+
_cons_j
377+
end
369378
else
370379
_cons_j = nothing
371380
end
372381
if cons_h
373-
_cons_h = generate_hessian(cons_sys; expression = Val{false},
374-
sparse = cons_sparse)[2]
382+
_cons_h = let (cons_hess_oop, cons_hess_iip) = generate_hessian(
383+
cons_sys, checkbounds = checkbounds,
384+
linenumbers = linenumbers,
385+
sparse = cons_sparse, parallel = parallel,
386+
expression = Val{false})
387+
_cons_h(u, p) = cons_hess_oop(u, p)
388+
_cons_h(J, u, p) = (cons_hess_iip(J, u, p); J)
389+
_cons_h(u, p::MTKParameters) = cons_hess_oop(u, p...)
390+
_cons_h(J, u, p::MTKParameters) = (cons_hess_iip(J, u, p...); J)
391+
_cons_h
392+
end
375393
else
376394
_cons_h = nothing
377395
end

test/optimizationsystem.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,31 @@ end
312312
@test all(prob.f.cons_expr[i].lhs isa Symbolics.Symbolic
313313
for i in 1:length(prob.f.cons_expr))
314314
end
315+
316+
@testset "Derivatives, iip and oop" begin
317+
@variables x y
318+
@parameters a b
319+
loss = (a - x)^2 + b * (y - x^2)^2
320+
cons2 = [x^2 + y^2 ~ 0, y * sin(x) - x ~ 0]
321+
sys = complete(OptimizationSystem(
322+
loss, [x, y], [a, b], name = :sys2, constraints = cons2))
323+
prob = OptimizationProblem(sys, [x => 0.0, y => 0.0], [a => 1.0, b => 100.0],
324+
grad = true, hess = true, cons_j = true, cons_h = true)
325+
326+
G1 = Array{Float64}(undef, 2)
327+
H1 = Array{Float64}(undef, 2, 2)
328+
J = Array{Float64}(undef, 2, 2)
329+
H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)]
330+
331+
prob.f.grad(G1, [1.0, 1.0], [1.0, 100.0])
332+
@test prob.f.grad([1.0, 1.0], [1.0, 100.0]) == G1
333+
334+
prob.f.hess(H1, [1.0, 1.0], [1.0, 100.0])
335+
@test prob.f.hess([1.0, 1.0], [1.0, 100.0]) == H1
336+
337+
prob.f.cons_j(J, [1.0, 1.0], [1.0, 100.0])
338+
@test prob.f.cons_j([1.0, 1.0], [1.0, 100.0]) == J
339+
340+
prob.f.cons_h(H3, [1.0, 1.0], [1.0, 100.0])
341+
@test prob.f.cons_h([1.0, 1.0], [1.0, 100.0]) == H3
342+
end

0 commit comments

Comments
 (0)