Skip to content

Commit 6286691

Browse files
authored
Merge pull request #1125 from SciML/myb/arrayvar
`getproperty` for array variables
2 parents 549885d + 66df11f commit 6286691

File tree

11 files changed

+140
-45
lines changed

11 files changed

+140
-45
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ using Reexport
4343
@reexport using Symbolics
4444
export @derivatives
4545
using Symbolics: _parse_vars, value, @derivatives, get_variables,
46-
exprs_occur_in, solve_for, build_expr
46+
exprs_occur_in, solve_for, build_expr, unwrap, wrap
4747
import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
4848
jacobian_sparsity, islinear, _iszero, _isone,
4949
tosymbol, lower_varname, diff2term, var_from_nested_derivative,

src/systems/abstractsystem.jl

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ for prop in [
152152
:iv
153153
:states
154154
:ps
155+
:var_to_name
155156
:ctrls
156157
:defaults
157158
:observed
@@ -201,6 +202,10 @@ Setfield.get(obj::AbstractSystem, ::Setfield.PropertyLens{field}) where {field}
201202
end
202203

203204
rename(x::AbstractSystem, name) = @set x.name = name
205+
function rename(xx::Symbolics.ArrayOp, name)
206+
@set! xx.expr.f.arguments[1] = rename(xx.expr.f.arguments[1], name)
207+
@set! xx.term.arguments[2] = rename(xx.term.arguments[2], name)
208+
end
204209

205210
function Base.propertynames(sys::AbstractSystem; private=false)
206211
if private
@@ -233,25 +238,34 @@ function getvar(sys::AbstractSystem, name::Symbol; namespace=false)
233238
elseif !isempty(systems)
234239
i = findfirst(x->nameof(x)==name, systems)
235240
if i !== nothing
236-
return namespace ? rename(systems[i], renamespace(sysname,name)) : systems[i]
241+
return namespace ? rename(systems[i], renamespace(sysname, name)) : systems[i]
237242
end
238243
end
239244

240-
sts = get_states(sys)
241-
i = findfirst(x->getname(x) == name, sts)
242-
243-
if i !== nothing
244-
return namespace ? renamespace(sysname,sts[i]) : sts[i]
245-
end
245+
if has_var_to_name(sys)
246+
avs = get_var_to_name(sys)
247+
v = get(avs, name, nothing)
248+
v === nothing || return namespace ? renamespace(sysname, v, name) : v
246249

247-
if has_ps(sys)
248-
ps = get_ps(sys)
249-
i = findfirst(x->getname(x) == name,ps)
250+
else
251+
sts = get_states(sys)
252+
i = findfirst(x->getname(x) == name, sts)
250253
if i !== nothing
251-
return namespace ? renamespace(sysname,ps[i]) : ps[i]
254+
return namespace ? renamespace(sysname,sts[i]) : sts[i]
255+
end
256+
257+
if has_ps(sys)
258+
ps = get_ps(sys)
259+
i = findfirst(x->getname(x) == name,ps)
260+
if i !== nothing
261+
return namespace ? renamespace(sysname,ps[i]) : ps[i]
262+
end
252263
end
253264
end
254265

266+
sts = get_states(sys)
267+
i = findfirst(x->getname(x) == name, sts)
268+
255269
if has_observed(sys)
256270
obs = get_observed(sys)
257271
i = findfirst(x->getname(x.lhs)==name,obs)
@@ -296,13 +310,12 @@ ParentScope(sym::Union{Num, Symbolic}) = setmetadata(sym, SymScope, ParentScope(
296310
struct GlobalScope <: SymScope end
297311
GlobalScope(sym::Union{Num, Symbolic}) = setmetadata(sym, SymScope, GlobalScope())
298312

299-
function renamespace(namespace, x)
300-
if x isa Num
301-
renamespace(namespace, value(x))
302-
elseif x isa Symbolic
313+
function renamespace(namespace, x, name=nothing)
314+
x = unwrap(x)
315+
if x isa Symbolic
303316
let scope = getmetadata(x, SymScope, LocalScope())
304317
if scope isa LocalScope
305-
rename(x, renamespace(namespace, getname(x)))
318+
rename(x, renamespace(namespace, name === nothing ? getname(x) : name))
306319
elseif scope isa ParentScope
307320
setmetadata(x, SymScope, scope.parent)
308321
else # GlobalScope

src/systems/diffeqs/odesystem.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ struct ODESystem <: AbstractODESystem
3131
states::Vector
3232
"""Parameter variables. Must not contain the independent variable."""
3333
ps::Vector
34+
"""Array variables."""
35+
var_to_name
3436
"""Control parameters (some subset of `ps`)."""
3537
ctrls::Vector
3638
"""Observed states."""
@@ -82,11 +84,11 @@ struct ODESystem <: AbstractODESystem
8284
"""
8385
connection_type::Any
8486

85-
function ODESystem(deqs, iv, dvs, ps, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
87+
function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
8688
check_variables(dvs,iv)
8789
check_parameters(ps,iv)
8890
check_equations(deqs,iv)
89-
new(deqs, iv, dvs, ps, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
91+
new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
9092
end
9193
end
9294

@@ -119,8 +121,9 @@ function ODESystem(
119121
dvs′ = value.(scalarize(dvs))
120122
ps′ = value.(scalarize(ps))
121123

122-
collect_defaults!(defaults, dvs′)
123-
collect_defaults!(defaults, ps′)
124+
var_to_name = Dict()
125+
process_variables!(var_to_name, defaults, dvs′)
126+
process_variables!(var_to_name, defaults, ps′)
124127

125128
tgrad = RefValue(Vector{Num}(undef, 0))
126129
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
@@ -131,7 +134,7 @@ function ODESystem(
131134
if length(unique(sysnames)) != length(sysnames)
132135
throw(ArgumentError("System names must be unique."))
133136
end
134-
ODESystem(deqs, iv′, dvs′, ps′, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type)
137+
ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type)
135138
end
136139

137140
function ODESystem(eqs, iv=nothing; kwargs...)

src/systems/diffeqs/sdesystem.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ struct SDESystem <: AbstractODESystem
3737
states::Vector
3838
"""Parameter variables. Must not contain the independent variable."""
3939
ps::Vector
40+
"""Array variables."""
41+
var_to_name
4042
"""Control parameters (some subset of `ps`)."""
4143
ctrls::Vector
4244
"""Observed states."""
@@ -84,11 +86,11 @@ struct SDESystem <: AbstractODESystem
8486
"""
8587
connection_type::Any
8688

87-
function SDESystem(deqs, neqs, iv, dvs, ps, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
89+
function SDESystem(deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
8890
check_variables(dvs,iv)
8991
check_parameters(ps,iv)
9092
check_equations(deqs,iv)
91-
new(deqs, neqs, iv, dvs, ps, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
93+
new(deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
9294
end
9395
end
9496

@@ -118,15 +120,16 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
118120
defaults = todict(defaults)
119121
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
120122

121-
collect_defaults!(defaults, dvs′)
122-
collect_defaults!(defaults, ps′)
123+
var_to_name = Dict()
124+
process_variables!(var_to_name, defaults, dvs′)
125+
process_variables!(var_to_name, defaults, ps′)
123126

124127
tgrad = RefValue(Vector{Num}(undef, 0))
125128
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
126129
ctrl_jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
127130
Wfact = RefValue(Matrix{Num}(undef, 0, 0))
128131
Wfact_t = RefValue(Matrix{Num}(undef, 0, 0))
129-
SDESystem(deqs, neqs, iv′, dvs′, ps′, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
132+
SDESystem(deqs, neqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
130133
end
131134

132135
function generate_diffusion_function(sys::SDESystem, dvs = states(sys), ps = parameters(sys); kwargs...)

src/systems/discrete_system/discrete_system.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ struct DiscreteSystem <: AbstractSystem
3030
states::Vector
3131
"""Parameter variables. Must not contain the independent variable."""
3232
ps::Vector
33+
"""Array variables."""
34+
var_to_name
3335
"""Control parameters (some subset of `ps`)."""
3436
ctrls::Vector
3537
"""Observed states."""
@@ -52,10 +54,10 @@ struct DiscreteSystem <: AbstractSystem
5254
in `DiscreteSystem`.
5355
"""
5456
default_p::Dict
55-
function DiscreteSystem(discreteEqs, iv, dvs, ps, ctrls, observed, name, systems, default_u0, default_p)
57+
function DiscreteSystem(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, default_u0, default_p)
5658
check_variables(dvs,iv)
5759
check_parameters(ps,iv)
58-
new(discreteEqs, iv, dvs, ps, ctrls, observed, name, systems, default_u0, default_p)
60+
new(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, default_u0, default_p)
5961
end
6062
end
6163

@@ -86,14 +88,15 @@ function DiscreteSystem(
8688
defaults = todict(defaults)
8789
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
8890

89-
collect_defaults!(defaults, dvs′)
90-
collect_defaults!(defaults, ps′)
91+
var_to_name = Dict()
92+
process_variables!(var_to_name, defaults, dvs′)
93+
process_variables!(var_to_name, defaults, ps′)
9194

9295
sysnames = nameof.(systems)
9396
if length(unique(sysnames)) != length(sysnames)
9497
throw(ArgumentError("System names must be unique."))
9598
end
96-
DiscreteSystem(eqs, iv′, dvs′, ps′, ctrl′, observed, name, systems, default_u0, default_p)
99+
DiscreteSystem(eqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, name, systems, default_u0, default_p)
97100
end
98101

99102
"""

src/systems/jumps/jumpsystem.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractSystem
3737
states::Vector
3838
"""The parameters of the system. Must not contain the independent variable."""
3939
ps::Vector
40+
"""Array variables."""
41+
var_to_name
4042
observed::Vector{Equation}
4143
"""The name of the system. . These are required to have unique names."""
4244
name::Symbol
@@ -51,10 +53,10 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractSystem
5153
type: type of the system
5254
"""
5355
connection_type::Any
54-
function JumpSystem{U}(ap::U, iv, states, ps, observed, name, systems, defaults, connection_type) where U <: ArrayPartition
56+
function JumpSystem{U}(ap::U, iv, states, ps, var_to_name, observed, name, systems, defaults, connection_type) where U <: ArrayPartition
5557
check_variables(states, iv)
5658
check_parameters(ps, iv)
57-
new{U}(ap, iv, states, ps, observed, name, systems, defaults, connection_type)
59+
new{U}(ap, iv, states, ps, var_to_name, observed, name, systems, defaults, connection_type)
5860
end
5961
end
6062

@@ -92,10 +94,11 @@ function JumpSystem(eqs, iv, states, ps;
9294
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
9395

9496
states, ps = value.(states), value.(ps)
95-
collect_defaults!(defaults, states)
96-
collect_defaults!(defaults, ps)
97+
var_to_name = Dict()
98+
process_variables!(var_to_name, defaults, states)
99+
process_variables!(var_to_name, defaults, ps)
97100

98-
JumpSystem{typeof(ap)}(ap, value(iv), states, ps, observed, name, systems, defaults, connection_type)
101+
JumpSystem{typeof(ap)}(ap, value(iv), states, ps, var_to_name, observed, name, systems, defaults, connection_type)
99102
end
100103

101104
function generate_rate_function(js, rate)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ struct NonlinearSystem <: AbstractSystem
2525
states::Vector
2626
"""Parameters."""
2727
ps::Vector
28+
"""Array variables."""
29+
var_to_name
2830
observed::Vector{Equation}
2931
"""
3032
Jacobian matrix. Note: this field will not be defined until
@@ -76,10 +78,11 @@ function NonlinearSystem(eqs, states, ps;
7678
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
7779

7880
states, ps = value.(states), value.(ps)
79-
collect_defaults!(defaults, states)
80-
collect_defaults!(defaults, ps)
81+
var_to_name = Dict()
82+
process_variables!(var_to_name, defaults, states)
83+
process_variables!(var_to_name, defaults, ps)
8184

82-
NonlinearSystem(eqs, states, ps, observed, jac, name, systems, defaults, nothing, connection_type)
85+
NonlinearSystem(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, nothing, connection_type)
8386
end
8487

8588
function calculate_jacobian(sys::NonlinearSystem; sparse=false, simplify=false)

src/systems/optimization/optimizationsystem.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ struct OptimizationSystem <: AbstractSystem
2323
states::Vector
2424
"""Parameters."""
2525
ps::Vector
26+
"""Array variables."""
27+
var_to_name
2628
observed::Vector{Equation}
2729
equality_constraints::Vector{Equation}
2830
inequality_constraints::Vector
@@ -61,11 +63,12 @@ function OptimizationSystem(op, states, ps;
6163
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
6264

6365
states, ps = value.(states), value.(ps)
64-
collect_defaults!(defaults, states)
65-
collect_defaults!(defaults, ps)
66+
var_to_name = Dict()
67+
process_variables!(var_to_name, defaults, states)
68+
process_variables!(var_to_name, defaults, ps)
6669

6770
OptimizationSystem(
68-
value(op), states, ps,
71+
value(op), states, ps, var_to_name,
6972
observed,
7073
equality_constraints, inequality_constraints,
7174
name, systems, defaults

src/utils.jl

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,19 +160,66 @@ hasdefault(v) = hasmetadata(v, Symbolics.VariableDefaultValue)
160160
getdefault(v) = value(getmetadata(v, Symbolics.VariableDefaultValue))
161161
setdefault(v, val) = val === nothing ? v : setmetadata(v, Symbolics.VariableDefaultValue, value(val))
162162

163+
function process_variables!(var_to_name, defs, vars)
164+
collect_defaults!(defs, vars)
165+
collect_var_to_name!(var_to_name, vars)
166+
return nothing
167+
end
168+
163169
function collect_defaults!(defs, vars)
164170
for v in vars; (haskey(defs, v) || !hasdefault(v)) && continue
165171
defs[v] = getdefault(v)
166172
end
167173
return defs
168174
end
169175

176+
function var_to_name(x, vars=Dict{Symbol, Any}())
177+
x = Symbolics.unwrap(x)
178+
if istree(x)
179+
if hasmetadata(x, Symbolics.GetindexParent)
180+
v = Dict{Symbol, Any}()
181+
foreach(a->var_to_name(a, v), arguments(x))
182+
var_to_name(operation(x), v)
183+
name = first(only(v))
184+
vars[name] = getmetadata(x, Symbolics.GetindexParent)
185+
elseif x isa Symbolics.ArrayOp
186+
t = x.term
187+
if istree(t) && operation(t) === (map) && arguments(t)[1] isa Symbolics.CallWith
188+
vars[nameof(arguments(t)[2])] = x
189+
else
190+
var_to_name(x.expr, vars)
191+
end
192+
else
193+
var_to_name(operation(x), vars)
194+
for a in arguments(x)
195+
var_to_name(a, vars)
196+
end
197+
end
198+
elseif x isa Sym && symtype(x) <: AbstractArray
199+
vars[nameof(x)] = x
200+
end
201+
202+
vars
203+
end
204+
205+
function collect_var_to_name!(vars, xs)
206+
for x in xs
207+
ax = var_to_name(x)
208+
if isempty(ax)
209+
vars[getname(x)] = x
210+
else
211+
merge!(vars, ax)
212+
end
213+
end
214+
return vars
215+
end
216+
170217
"Throw error when difference/derivative operation occurs in the R.H.S."
171218
@noinline function throw_invalid_operator(opvar, eq, op::Type)
172219
if op === Difference
173220
optext = "difference"
174221
elseif op === Differential
175-
optext="derivative"
222+
optext="derivative"
176223
end
177224
msg = "The $optext variable must be isolated to the left-hand " *
178225
"side of the equation like `$opvar ~ ...`.\n Got $eq."

test/controlsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,4 @@ sol = solve(prob,BFGS())
2828
sys1 = ControlSystem(loss,eqs_short, t, [x, v], [u], p, name = :sys1)
2929
sys2 = ControlSystem(loss,eqs_short, t, [x, v], [u], p, name = :sys1)
3030
@test_throws ArgumentError ControlSystem(loss, [sys2.v ~ sys1.v], t, [], [], [], systems = [sys1, sys2])
31-
end
31+
end

0 commit comments

Comments
 (0)