diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index 39c5c1e7a7..801e7b05f3 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -284,6 +284,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map, linenumbers = true, parallel = SerialForm(), eval_expression = false, eval_module = @__MODULE__, use_union = false, + checks = true, kwargs...) where {iip} if !iscomplete(sys) error("A completed `OptimizationSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `OptimizationProblem`") @@ -393,12 +394,17 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map, observedfun = ObservedFunctionCache(sys; eval_expression, eval_module) if length(cstr) > 0 - @named cons_sys = ConstraintsSystem(cstr, dvs, ps) + @named cons_sys = ConstraintsSystem(cstr, dvs, ps; checks) cons_sys = complete(cons_sys) cons, lcons_, ucons_ = generate_function(cons_sys, checkbounds = checkbounds, linenumbers = linenumbers, expression = Val{true}) - cons = eval_or_rgf.(cons; eval_expression, eval_module) + cons = let (cons_oop, cons_iip) = eval_or_rgf.(cons; eval_expression, eval_module) + _cons(u, p) = cons_oop(u, p) + _cons(resid, u, p) = cons_iip(resid, u, p) + _cons(u, p::MTKParameters) = cons_oop(u, p...) + _cons(resid, u, p::MTKParameters) = cons_iip(resid, u, p...) + end if cons_j _cons_j = let (cons_jac_oop, cons_jac_iip) = eval_or_rgf.( generate_jacobian(cons_sys; @@ -464,7 +470,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map, grad = _grad, hess = _hess, hess_prototype = hess_prototype, - cons = cons[2], + cons = cons, cons_j = _cons_j, cons_h = _cons_h, cons_jac_prototype = cons_jac_prototype, diff --git a/test/optimizationsystem.jl b/test/optimizationsystem.jl index a8e2be936d..bb59fb09d9 100644 --- a/test/optimizationsystem.jl +++ b/test/optimizationsystem.jl @@ -368,3 +368,12 @@ end @test is_variable(sys, x[2]) @test is_variable(sys, x[3]) end + +@testset "Constraints work with nonnumeric parameters" begin + @variables x + @parameters p f(::Real) + @mtkbuild sys = OptimizationSystem( + x^2 + f(x) * p, [x], [f, p]; constraints = [2.0 ≲ f(x) + p]) + prob = OptimizationProblem(sys, [x => 1.0], [p => 1.0, f => (x -> 2x)]) + @test abs(prob.f.cons(prob.u0, prob.p)[1]) ≈ 1.0 +end