Skip to content

Commit 50e3461

Browse files
authored
Merge pull request #1866 from SciML/myb/repeated
Repeated components optimization and make `@named` more intuitive
2 parents d9b4b4a + 6c3c0d3 commit 50e3461

File tree

10 files changed

+114
-50
lines changed

10 files changed

+114
-50
lines changed

src/systems/abstractsystem.jl

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
const SYSTEM_COUNT = Threads.Atomic{UInt}(0)
2+
13
"""
24
```julia
35
calculate_tgrad(sys::AbstractTimeDependentSystem)
@@ -175,6 +177,7 @@ function complete(sys::AbstractSystem)
175177
end
176178

177179
for prop in [:eqs
180+
:tag
178181
:noiseeqs
179182
:iv
180183
:states
@@ -255,7 +258,7 @@ end
255258
end
256259
end
257260

258-
rename(x::AbstractSystem, name) = @set x.name = name
261+
rename(x, name) = @set x.name = name
259262

260263
function Base.propertynames(sys::AbstractSystem; private = false)
261264
if private
@@ -842,12 +845,26 @@ function _named(name, call, runtime = false)
842845
end
843846
end
844847

848+
is_sys_construction = Symbol("###__is_system_construction###")
845849
kws = call.args[2].args
850+
for (i, kw) in enumerate(kws)
851+
if Meta.isexpr(kw, (:(=), :kw))
852+
kw.args[2] = :($is_sys_construction ? $(kw.args[2]) :
853+
$default_to_parentscope($(kw.args[2])))
854+
elseif kw isa Symbol
855+
rhs = :($is_sys_construction ? $(kw) : $default_to_parentscope($(kw)))
856+
kws[i] = Expr(:kw, kw, rhs)
857+
end
858+
end
846859

847860
if !any(kw -> (kw isa Symbol ? kw : kw.args[1]) == :name, kws) # don't overwrite `name` kwarg
848861
pushfirst!(kws, Expr(:kw, :name, runtime ? name : Meta.quot(name)))
849862
end
850-
call
863+
op = call.args[1]
864+
quote
865+
$is_sys_construction = ($op isa $DataType) && ($op <: $AbstractSystem)
866+
$call
867+
end
851868
end
852869

853870
function _named_idxs(name::Symbol, idxs, call)
@@ -872,46 +889,44 @@ end
872889
"""
873890
@named y = foo(x)
874891
@named y[1:10] = foo(x)
875-
@named y 1:10 i -> foo(x*i)
876-
877-
Rewrite `@named y = foo(x)` to `y = foo(x; name=:y)`.
878-
879-
Rewrite `@named y[1:10] = foo(x)` to `y = map(i′->foo(x; name=Symbol(:y_, i′)), 1:10)`.
892+
@named y 1:10 i -> foo(x*i) # This is not recommended
880893
881-
Rewrite `@named y 1:10 i -> foo(x*i)` to `y = map(i->foo(x*i; name=Symbol(:y_, i)), 1:10)`.
894+
Pass the LHS name to the model. When it's calling anything that's not an
895+
AbstractSystem, it wraps all keyword arguments in `default_to_parentscope` so
896+
that namespacing works intuitively when passing a symbolic default into a
897+
component.
882898
883899
Examples:
884-
```julia
900+
```julia-repl
885901
julia> using ModelingToolkit
886902
887-
julia> foo(i; name) = i, name
903+
julia> foo(i; name) = (; i, name)
888904
foo (generic function with 1 method)
889905
890906
julia> x = 41
891907
41
892908
893909
julia> @named y = foo(x)
894-
(41, :y)
910+
(i = 41, name = :y)
895911
896912
julia> @named y[1:3] = foo(x)
897-
3-element Vector{Tuple{Int64, Symbol}}:
898-
(41, :y_1)
899-
(41, :y_2)
900-
(41, :y_3)
901-
902-
julia> @named y 1:3 i -> foo(x*i)
903-
3-element Vector{Tuple{Int64, Symbol}}:
904-
(41, :y_1)
905-
(82, :y_2)
906-
(123, :y_3)
913+
3-element Vector{NamedTuple{(:i, :name), Tuple{Int64, Symbol}}}:
914+
(i = 41, name = :y_1)
915+
(i = 41, name = :y_2)
916+
(i = 41, name = :y_3)
907917
```
908918
"""
909919
macro named(expr)
910920
name, call = split_assign(expr)
911921
if Meta.isexpr(name, :ref)
912922
name, idxs = name.args
913923
check_name(name)
914-
esc(_named_idxs(name, idxs, :($(gensym()) -> $call)))
924+
var = gensym(name)
925+
ex = quote
926+
$var = $(_named(name, call))
927+
$name = map(i -> $rename($var, Symbol($(Meta.quot(name)), :_, i)), $idxs)
928+
end
929+
esc(ex)
915930
else
916931
check_name(name)
917932
esc(:($name = $(_named(name, call))))
@@ -922,6 +937,11 @@ macro named(name::Symbol, idxs, call)
922937
esc(_named_idxs(name, idxs, call))
923938
end
924939

940+
function default_to_parentscope(v)
941+
uv = unwrap(v)
942+
uv isa Symbolic && !hasmetadata(uv, SymScope) ? ParentScope(v) : v
943+
end
944+
925945
function _config(expr, namespace)
926946
cn = Base.Fix2(_config, namespace)
927947
if Meta.isexpr(expr, :.)

src/systems/diffeqs/odesystem.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ eqs = [D(x) ~ σ*(y-x),
2323
```
2424
"""
2525
struct ODESystem <: AbstractODESystem
26+
"""
27+
tag: a tag for the system. If two system have the same tag, then they are
28+
structurally identical.
29+
"""
30+
tag::UInt
2631
"""The ODEs defining the system."""
2732
eqs::Vector{Equation}
2833
"""Independent variable."""
@@ -120,7 +125,7 @@ struct ODESystem <: AbstractODESystem
120125
"""
121126
complete::Bool
122127

123-
function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad,
128+
function ODESystem(tag, deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad,
124129
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
125130
torn_matching, connector_type, preface, cevents,
126131
devents, metadata = nothing, tearing_state = nothing,
@@ -135,7 +140,7 @@ struct ODESystem <: AbstractODESystem
135140
if checks == true || (checks & CheckUnits) > 0
136141
all_dimensionless([dvs; ps; iv]) || check_units(deqs)
137142
end
138-
new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
143+
new(tag, deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
139144
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
140145
connector_type, preface, cevents, devents, metadata, tearing_state,
141146
substitutions, complete)
@@ -189,7 +194,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
189194
end
190195
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
191196
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
192-
ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac,
197+
ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
198+
deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac,
193199
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing,
194200
connector_type, preface, cont_callbacks, disc_callbacks,
195201
metadata, checks = checks)

src/systems/diffeqs/sdesystem.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ noiseeqs = [0.1*x,
2727
```
2828
"""
2929
struct SDESystem <: AbstractODESystem
30+
"""
31+
tag: a tag for the system. If two system have the same tag, then they are
32+
structurally identical.
33+
"""
34+
tag::UInt
3035
"""The expressions defining the drift term."""
3136
eqs::Vector{Equation}
3237
"""The expressions defining the diffusion term."""
@@ -105,7 +110,8 @@ struct SDESystem <: AbstractODESystem
105110
"""
106111
complete::Bool
107112

108-
function SDESystem(deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
113+
function SDESystem(tag, deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad,
114+
jac,
109115
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
110116
cevents, devents, metadata = nothing, complete = false;
111117
checks::Union{Bool, Int} = true)
@@ -118,7 +124,8 @@ struct SDESystem <: AbstractODESystem
118124
if checks == true || (checks & CheckUnits) > 0
119125
all_dimensionless([dvs; ps; iv]) || check_units(deqs, neqs)
120126
end
121-
new(deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac,
127+
new(tag, deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac,
128+
ctrl_jac,
122129
Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents,
123130
metadata, complete)
124131
end
@@ -169,7 +176,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
169176
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
170177
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
171178

172-
SDESystem(deqs, neqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac,
179+
SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
180+
deqs, neqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac,
173181
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
174182
cont_callbacks, disc_callbacks, metadata; checks = checks)
175183
end

src/systems/discrete_system/discrete_system.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ eqs = [D(x) ~ σ*(y-x),
2424
```
2525
"""
2626
struct DiscreteSystem <: AbstractTimeDependentSystem
27+
"""
28+
tag: a tag for the system. If two system have the same tag, then they are
29+
structurally identical.
30+
"""
31+
tag::UInt
2732
"""The differential equations defining the discrete system."""
2833
eqs::Vector{Equation}
2934
"""Independent variable."""
@@ -76,7 +81,8 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
7681
"""
7782
complete::Bool
7883

79-
function DiscreteSystem(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name,
84+
function DiscreteSystem(tag, discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed,
85+
name,
8086
systems, defaults, preface, connector_type,
8187
metadata = nothing,
8288
tearing_state = nothing, substitutions = nothing,
@@ -88,7 +94,8 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
8894
if checks == true || (checks & CheckUnits) > 0
8995
all_dimensionless([dvs; ps; iv; ctrls]) || check_units(discreteEqs)
9096
end
91-
new(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, defaults,
97+
new(tag, discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems,
98+
defaults,
9299
preface, connector_type, metadata, tearing_state, substitutions, complete)
93100
end
94101
end
@@ -134,7 +141,8 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
134141
if length(unique(sysnames)) != length(sysnames)
135142
throw(ArgumentError("System names must be unique."))
136143
end
137-
DiscreteSystem(eqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, name, systems,
144+
DiscreteSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
145+
eqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, name, systems,
138146
defaults, preface, connector_type, metadata, kwargs...)
139147
end
140148

src/systems/jumps/jumpsystem.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ j₃ = MassActionJump(2*β+γ, [R => 1], [S => 1, R => -1])
4848
```
4949
"""
5050
struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
51+
"""
52+
tag: a tag for the system. If two system have the same tag, then they are
53+
structurally identical.
54+
"""
55+
tag::UInt
5156
"""
5257
The jumps of the system. Allowable types are `ConstantRateJump`,
5358
`VariableRateJump`, `MassActionJump`.
@@ -92,7 +97,7 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
9297
"""
9398
complete::Bool
9499

95-
function JumpSystem{U}(ap::U, iv, states, ps, var_to_name, observed, name, systems,
100+
function JumpSystem{U}(tag, ap::U, iv, states, ps, var_to_name, observed, name, systems,
96101
defaults, connector_type, devents,
97102
metadata = nothing, complete = false;
98103
checks::Union{Bool, Int} = true) where {U <: ArrayPartition}
@@ -103,7 +108,7 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
103108
if checks == true || (checks & CheckUnits) > 0
104109
all_dimensionless([states; ps; iv]) || check_units(ap, iv)
105110
end
106-
new{U}(ap, iv, states, ps, var_to_name, observed, name, systems, defaults,
111+
new{U}(tag, ap, iv, states, ps, var_to_name, observed, name, systems, defaults,
107112
connector_type, devents, metadata, complete)
108113
end
109114
end
@@ -156,7 +161,8 @@ function JumpSystem(eqs, iv, states, ps;
156161
error("JumpSystems currently only support discrete events.")
157162
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
158163

159-
JumpSystem{typeof(ap)}(ap, value(iv), states, ps, var_to_name, observed, name, systems,
164+
JumpSystem{typeof(ap)}(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
165+
ap, value(iv), states, ps, var_to_name, observed, name, systems,
160166
defaults, connector_type, disc_callbacks, metadata,
161167
checks = checks)
162168
end

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ eqs = [0 ~ σ*(y-x),
1919
```
2020
"""
2121
struct NonlinearSystem <: AbstractTimeIndependentSystem
22+
"""
23+
tag: a tag for the system. If two system have the same tag, then they are
24+
structurally identical.
25+
"""
26+
tag::UInt
2227
"""Vector of equations defining the system."""
2328
eqs::Vector{Equation}
2429
"""Unknown variables."""
@@ -67,14 +72,15 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem
6772
"""
6873
complete::Bool
6974

70-
function NonlinearSystem(eqs, states, ps, var_to_name, observed, jac, name, systems,
75+
function NonlinearSystem(tag, eqs, states, ps, var_to_name, observed, jac, name,
76+
systems,
7177
defaults, connector_type, metadata = nothing,
7278
tearing_state = nothing, substitutions = nothing,
7379
complete = false; checks::Union{Bool, Int} = true)
7480
if checks == true || (checks & CheckUnits) > 0
7581
all_dimensionless([states; ps]) || check_units(eqs)
7682
end
77-
new(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults,
83+
new(tag, eqs, states, ps, var_to_name, observed, jac, name, systems, defaults,
7884
connector_type, metadata, tearing_state, substitutions, complete)
7985
end
8086
end
@@ -124,7 +130,8 @@ function NonlinearSystem(eqs, states, ps;
124130
process_variables!(var_to_name, defaults, ps)
125131
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
126132

127-
NonlinearSystem(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults,
133+
NonlinearSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
134+
eqs, states, ps, var_to_name, observed, jac, name, systems, defaults,
128135
connector_type, metadata, checks = checks)
129136
end
130137

src/systems/optimization/optimizationsystem.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ op = a*(y-x) + x*(b-z)-y + x*y - c*z
1717
```
1818
"""
1919
struct OptimizationSystem <: AbstractTimeIndependentSystem
20+
"""
21+
tag: a tag for the system. If two system have the same tag, then they are
22+
structurally identical.
23+
"""
24+
tag::UInt
2025
"""Objective function of the system."""
2126
op::Any
2227
"""Unknown variables."""
@@ -46,15 +51,15 @@ struct OptimizationSystem <: AbstractTimeIndependentSystem
4651
"""
4752
complete::Bool
4853

49-
function OptimizationSystem(op, states, ps, var_to_name, observed,
54+
function OptimizationSystem(tag, op, states, ps, var_to_name, observed,
5055
constraints, name, systems, defaults, metadata = nothing,
5156
complete = false; checks::Union{Bool, Int} = true)
5257
if checks == true || (checks & CheckUnits) > 0
5358
unwrap(op) isa Symbolic && check_units(op)
5459
check_units(observed)
5560
all_dimensionless([states; ps]) || check_units(constraints)
5661
end
57-
new(op, states, ps, var_to_name, observed,
62+
new(tag, op, states, ps, var_to_name, observed,
5863
constraints, name, systems, defaults, metadata, complete)
5964
end
6065
end
@@ -92,7 +97,8 @@ function OptimizationSystem(op, states, ps;
9297
process_variables!(var_to_name, defaults, ps′)
9398
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
9499

95-
OptimizationSystem(value(op), states′, ps′, var_to_name,
100+
OptimizationSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
101+
value(op), states′, ps′, var_to_name,
96102
observed,
97103
constraints,
98104
name, systems, defaults, metadata; checks = checks)

0 commit comments

Comments
 (0)