Skip to content

Commit d8e4810

Browse files
Improves OptimizationSystem (#1787)
* improves nested system handling for OptimizationSystem Co-authored-by: Yingbo Ma <[email protected]>
1 parent a457b3f commit d8e4810

File tree

5 files changed

+83
-28
lines changed

5 files changed

+83
-28
lines changed

.github/workflows/FormatCheck.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
- name: Format check
3333
run: |
3434
julia -e '
35-
out = Cmd(`git diff --name-only`) |> read |> String
35+
out = Cmd(`git diff`) |> read |> String
3636
if out == ""
3737
exit(0)
3838
else

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ export JumpSystem
167167
export ODEProblem, SDEProblem
168168
export NonlinearFunction, NonlinearFunctionExpr
169169
export NonlinearProblem, BlockNonlinearProblem, NonlinearProblemExpr
170-
export OptimizationProblem, OptimizationProblemExpr
170+
export OptimizationProblem, OptimizationProblemExpr, constraints
171171
export AutoModelingToolkit
172172
export SteadyStateProblem, SteadyStateProblemExpr
173173
export JumpProblem, DiscreteProblem

src/systems/abstractsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,7 @@ for prop in [:eqs
179179
:systems
180180
:structure
181181
:op
182-
:equality_constraints
183-
:inequality_constraints
182+
:constraints
184183
:controls
185184
:loss
186185
:bcs
@@ -1227,7 +1226,8 @@ function linearize(sys, lin_fun; t = 0.0, op = Dict(), allow_input_derivatives =
12271226
(; A, B, C, D)
12281227
end
12291228

1230-
function linearize(sys, inputs, outputs; op = Dict(), t = 0.0, allow_input_derivatives = false,
1229+
function linearize(sys, inputs, outputs; op = Dict(), t = 0.0,
1230+
allow_input_derivatives = false,
12311231
kwargs...)
12321232
lin_fun, ssys = linearization_function(sys, inputs, outputs; kwargs...)
12331233
linearize(ssys, lin_fun; op, t, allow_input_derivatives), ssys

src/systems/optimization/optimizationsystem.jl

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ $(FIELDS)
1010
1111
```julia
1212
@variables x y z
13-
@parameters σ ρ β
13+
@parameters a b c
1414
15-
op = σ*(y-x) + x*(ρ-z)-y + x*y - β*z
16-
@named os = OptimizationSystem(op, [x,y,z],[σ,ρ,β])
15+
op = a*(y-x) + x*(b-z)-y + x*y - c*z
16+
@named os = OptimizationSystem(op, [x,y,z], [a,b,c])
1717
```
1818
"""
1919
struct OptimizationSystem <: AbstractTimeIndependentSystem
20-
"""Vector of equations defining the system."""
20+
"""Objective function of the system."""
2121
op::Any
2222
"""Unknown variables."""
2323
states::Vector
@@ -26,18 +26,15 @@ struct OptimizationSystem <: AbstractTimeIndependentSystem
2626
"""Array variables."""
2727
var_to_name::Any
2828
observed::Vector{Equation}
29-
constraints::Vector
30-
"""
31-
Name: the name of the system. These are required to have unique names.
32-
"""
29+
"""List of constraint equations of the system."""
30+
constraints::Vector # {Union{Equation,Inequality}}
31+
"""The unique name of the system."""
3332
name::Symbol
34-
"""
35-
systems: The internal systems
36-
"""
33+
"""The internal systems."""
3734
systems::Vector{OptimizationSystem}
3835
"""
39-
defaults: The default values to use when initial conditions and/or
40-
parameters are not supplied in `ODEProblem`.
36+
The default values to use when initial guess and/or
37+
parameters are not supplied in `OptimizationProblem`.
4138
"""
4239
defaults::Dict
4340
"""
@@ -48,7 +45,7 @@ struct OptimizationSystem <: AbstractTimeIndependentSystem
4845
constraints, name, systems, defaults, metadata = nothing;
4946
checks::Union{Bool, Int} = true)
5047
if checks == true || (checks & CheckUnits) > 0
51-
check_units(op)
48+
unwrap(op) isa Symbolic && check_units(op)
5249
check_units(observed)
5350
all_dimensionless([states; ps]) || check_units(constraints)
5451
end
@@ -69,6 +66,11 @@ function OptimizationSystem(op, states, ps;
6966
metadata = nothing)
7067
name === nothing &&
7168
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
69+
70+
constraints = value.(scalarize(constraints))
71+
states′ = value.(scalarize(states))
72+
ps′ = value.(scalarize(ps))
73+
7274
if !(isempty(default_u0) && isempty(default_p))
7375
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
7476
:OptimizationSystem, force = true)
@@ -80,12 +82,12 @@ function OptimizationSystem(op, states, ps;
8082
defaults = todict(defaults)
8183
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
8284

83-
states, ps = value.(states), value.(ps)
8485
var_to_name = Dict()
85-
process_variables!(var_to_name, defaults, states)
86-
process_variables!(var_to_name, defaults, ps)
86+
process_variables!(var_to_name, defaults, states)
87+
process_variables!(var_to_name, defaults, ps)
8788
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
88-
OptimizationSystem(value(op), states, ps, var_to_name,
89+
90+
OptimizationSystem(value(op), states′, ps′, var_to_name,
8991
observed,
9092
constraints,
9193
name, systems, defaults, metadata; checks = checks)
@@ -124,10 +126,38 @@ function generate_function(sys::OptimizationSystem, vs = states(sys), ps = param
124126
end
125127

126128
function equations(sys::OptimizationSystem)
127-
isempty(get_systems(sys)) ? get_op(sys) :
128-
get_op(sys) + reduce(+, namespace_expr.(get_systems(sys)))
129+
op = get_op(sys)
130+
systems = get_systems(sys)
131+
if isempty(systems)
132+
op
133+
else
134+
op + reduce(+, map(sys_ -> namespace_expr(get_op(sys_), sys_), systems))
135+
end
136+
end
137+
138+
namespace_constraint(eq::Equation, sys) = namespace_equation(eq, sys)
139+
140+
# namespace_constraint(ineq::Inequality, sys) = namespace_inequality(ineq, sys)
141+
142+
# function namespace_inequality(ineq::Inequality, sys, n = nameof(sys))
143+
# _lhs = namespace_expr(ineq.lhs, sys, n)
144+
# _rhs = namespace_expr(ineq.rhs, sys, n)
145+
# Inequality(
146+
# namespace_expr(_lhs, sys, n),
147+
# namespace_expr(_rhs, sys, n),
148+
# ineq.relational_op,
149+
# )
150+
# end
151+
152+
function namespace_constraints(sys::OptimizationSystem)
153+
namespace_constraint.(get_constraints(sys), Ref(sys))
154+
end
155+
156+
function constraints(sys::OptimizationSystem)
157+
cs = get_constraints(sys)
158+
systems = get_systems(sys)
159+
isempty(systems) ? cs : [cs; reduce(vcat, namespace_constraints.(systems))]
129160
end
130-
namespace_expr(sys::OptimizationSystem) = namespace_expr(get_op(sys), sys)
131161

132162
hessian_sparsity(sys::OptimizationSystem) = hessian_sparsity(get_op(sys), states(sys))
133163

@@ -168,6 +198,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
168198
kwargs...) where {iip}
169199
dvs = states(sys)
170200
ps = parameters(sys)
201+
cstr = constraints(sys)
171202

172203
defs = defaults(sys)
173204
defs = mergedefaults(defs, parammap, ps)
@@ -216,8 +247,8 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
216247
hess_prototype = nothing
217248
end
218249

219-
if length(sys.constraints) > 0
220-
@named cons_sys = NonlinearSystem(sys.constraints, dvs, ps)
250+
if length(cstr) > 0
251+
@named cons_sys = NonlinearSystem(cstr, dvs, ps)
221252
cons = generate_function(cons_sys, checkbounds = checkbounds,
222253
linenumbers = linenumbers,
223254
expression = Val{false})[2]
@@ -237,6 +268,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
237268

238269
_f = DiffEqBase.OptimizationFunction{iip}(f,
239270
sys = sys,
271+
syms = nameof.(states(sys)),
240272
SciMLBase.NoAD();
241273
grad = _grad,
242274
hess = _hess,
@@ -251,6 +283,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
251283
else
252284
_f = DiffEqBase.OptimizationFunction{iip}(f,
253285
sys = sys,
286+
syms = nameof.(states(sys)),
254287
SciMLBase.NoAD();
255288
grad = _grad,
256289
hess = _hess,

test/optimizationsystem.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,25 @@ OBS2 = OBS
103103
@test isequal(OBS2, @nonamespace sys2.OBS)
104104
@unpack OBS = sys2
105105
@test isequal(OBS2, OBS)
106+
107+
# nested constraints
108+
@testset "nested systems" begin
109+
@variables x y
110+
o1 = (x - 1)^2
111+
o2 = (y - 1 / 2)^2
112+
c1 = [
113+
x ~ 1,
114+
]
115+
c2 = [
116+
y ~ 1,
117+
]
118+
sys1 = OptimizationSystem(o1, [x], [], name = :sys1, constraints = c1)
119+
sys2 = OptimizationSystem(o2, [y], [], name = :sys2, constraints = c2)
120+
sys = OptimizationSystem(0, [], []; name = :sys, systems = [sys1, sys2],
121+
constraints = [sys1.x + sys2.y ~ 2], checks = false)
122+
prob = OptimizationProblem(sys, [0.0, 0.0])
123+
124+
@test isequal(constraints(sys), vcat(sys1.x + sys2.y ~ 2, sys1.x ~ 1, sys2.y ~ 1))
125+
@test isequal(equations(sys), (sys1.x - 1)^2 + (sys2.y - 1 / 2)^2)
126+
@test isequal(states(sys), [sys1.x, sys2.y])
127+
end

0 commit comments

Comments
 (0)