Skip to content

Commit 1b46612

Browse files
Merge pull request #3241 from aml5600/pass-checks-opt-prob
OptimizationProblem updates
2 parents eda23d4 + d6add0e commit 1b46612

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

src/systems/optimization/optimizationsystem.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
284284
linenumbers = true, parallel = SerialForm(),
285285
eval_expression = false, eval_module = @__MODULE__,
286286
use_union = false,
287+
checks = true,
287288
kwargs...) where {iip}
288289
if !iscomplete(sys)
289290
error("A completed `OptimizationSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `OptimizationProblem`")
@@ -393,12 +394,17 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
393394
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
394395

395396
if length(cstr) > 0
396-
@named cons_sys = ConstraintsSystem(cstr, dvs, ps)
397+
@named cons_sys = ConstraintsSystem(cstr, dvs, ps; checks)
397398
cons_sys = complete(cons_sys)
398399
cons, lcons_, ucons_ = generate_function(cons_sys, checkbounds = checkbounds,
399400
linenumbers = linenumbers,
400401
expression = Val{true})
401-
cons = eval_or_rgf.(cons; eval_expression, eval_module)
402+
cons = let (cons_oop, cons_iip) = eval_or_rgf.(cons; eval_expression, eval_module)
403+
_cons(u, p) = cons_oop(u, p)
404+
_cons(resid, u, p) = cons_iip(resid, u, p)
405+
_cons(u, p::MTKParameters) = cons_oop(u, p...)
406+
_cons(resid, u, p::MTKParameters) = cons_iip(resid, u, p...)
407+
end
402408
if cons_j
403409
_cons_j = let (cons_jac_oop, cons_jac_iip) = eval_or_rgf.(
404410
generate_jacobian(cons_sys;
@@ -464,7 +470,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
464470
grad = _grad,
465471
hess = _hess,
466472
hess_prototype = hess_prototype,
467-
cons = cons[2],
473+
cons = cons,
468474
cons_j = _cons_j,
469475
cons_h = _cons_h,
470476
cons_jac_prototype = cons_jac_prototype,

test/optimizationsystem.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,3 +368,12 @@ end
368368
@test is_variable(sys, x[2])
369369
@test is_variable(sys, x[3])
370370
end
371+
372+
@testset "Constraints work with nonnumeric parameters" begin
373+
@variables x
374+
@parameters p f(::Real)
375+
@mtkbuild sys = OptimizationSystem(
376+
x^2 + f(x) * p, [x], [f, p]; constraints = [2.0 f(x) + p])
377+
prob = OptimizationProblem(sys, [x => 1.0], [p => 1.0, f => (x -> 2x)])
378+
@test abs(prob.f.cons(prob.u0, prob.p)[1]) 1.0
379+
end

0 commit comments

Comments
 (0)