Skip to content

Commit 7854da1

Browse files
authored
Merge pull request #1927 from SciML/myb/opt
Add observed support for optimization systems
2 parents 2fa9b52 + 7758910 commit 7854da1

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

src/systems/optimization/optimizationsystem.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,21 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
288288
hess_prototype = nothing
289289
end
290290

291+
observedfun = let sys = sys, dict = Dict()
292+
function generated_observed(obsvar, args...)
293+
obs = get!(dict, value(obsvar)) do
294+
build_explicit_observed_function(sys, obsvar)
295+
end
296+
if args === ()
297+
let obs = obs
298+
(u, p) -> obs(u, p)
299+
end
300+
else
301+
obs(args...)
302+
end
303+
end
304+
end
305+
291306
if length(cstr) > 0
292307
@named cons_sys = ConstraintsSystem(cstr, dvs, ps)
293308
cons, lcons_, ucons_ = generate_function(cons_sys, checkbounds = checkbounds,
@@ -334,7 +349,8 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
334349
cons_jac_prototype = cons_jac_prototype,
335350
cons_hess_prototype = cons_hess_prototype,
336351
expr = obj_expr,
337-
cons_expr = cons_expr)
352+
cons_expr = cons_expr,
353+
observed = observedfun)
338354
OptimizationProblem{iip}(_f, u0, p; lb = lb, ub = ub, int = int,
339355
lcons = lcons, ucons = ucons, kwargs...)
340356
else
@@ -346,7 +362,8 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
346362
syms = Symbol.(states(sys)),
347363
paramsyms = Symbol.(parameters(sys)),
348364
hess_prototype = hess_prototype,
349-
expr = obj_expr)
365+
expr = obj_expr,
366+
observed = observedfun)
350367
OptimizationProblem{iip}(_f, u0, p; lb = lb, ub = ub, int = int,
351368
kwargs...)
352369
end

test/optimizationsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,15 @@ end
8181
sol = solve(prob, IPNewton())
8282
@test sol.minimum < 1.0
8383
@test sol.u[0.808, -0.064] atol=1e-3
84-
@test_broken sol[x]^2 + sol[y]^2 1.0
84+
@test sol[x]^2 + sol[y]^2 1.0
8585
sol = solve(prob, Ipopt.Optimizer(); print_level = 0)
8686
@test sol.minimum < 1.0
8787
@test sol.u[0.808, -0.064] atol=1e-3
88-
@test_broken sol[x]^2 + sol[y]^2 1.0
88+
@test sol[x]^2 + sol[y]^2 1.0
8989
sol = solve(prob, AmplNLWriter.Optimizer(Ipopt_jll.amplexe))
9090
@test sol.minimum < 1.0
9191
@test sol.u[0.808, -0.064] atol=1e-3
92-
@test_broken sol[x]^2 + sol[y]^2 1.0
92+
@test sol[x]^2 + sol[y]^2 1.0
9393
end
9494

9595
@testset "rosenbrock" begin

0 commit comments

Comments
 (0)