Skip to content

Commit 61fd2ce

Browse files
committed
Many fixes
1 parent c190ed8 commit 61fd2ce

File tree

9 files changed

+61
-70
lines changed

9 files changed

+61
-70
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ using Reexport
4242
@reexport using Symbolics
4343
export @derivatives
4444
using Symbolics: _parse_vars, value, @derivatives, get_variables,
45-
exprs_occur_in, solve_for, build_expr
45+
exprs_occur_in, solve_for, build_expr, unwrap, wrap
4646
import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
4747
jacobian_sparsity, islinear, _iszero, _isone,
4848
tosymbol, lower_varname, diff2term, var_from_nested_derivative,

src/systems/abstractsystem.jl

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ for prop in [
152152
:iv
153153
:states
154154
:ps
155-
:array_vars
155+
:var_to_name
156156
:ctrls
157157
:defaults
158158
:observed
@@ -202,6 +202,10 @@ Setfield.get(obj::AbstractSystem, ::Setfield.PropertyLens{field}) where {field}
202202
end
203203

204204
rename(x::AbstractSystem, name) = @set x.name = name
205+
function rename(xx::Symbolics.ArrayOp, name)
206+
@set! xx.expr.f.arguments[1] = rename(xx.expr.f.arguments[1], name)
207+
@set! xx.term.arguments[2] = rename(xx.term.arguments[2], name)
208+
end
205209

206210
function Base.propertynames(sys::AbstractSystem; private=false)
207211
if private
@@ -234,29 +238,17 @@ function getvar(sys::AbstractSystem, name::Symbol; namespace=false)
234238
elseif !isempty(systems)
235239
i = findfirst(x->nameof(x)==name, systems)
236240
if i !== nothing
237-
return namespace ? rename(systems[i], renamespace(sysname,name)) : systems[i]
241+
return namespace ? rename(systems[i], renamespace(sysname, name)) : systems[i]
238242
end
239243
end
240244

241-
avs = get_array_vars(sys)
245+
avs = get_var_to_name(sys)
242246
v = get(avs, name, nothing)
243-
v === nothing || return namespace ? renamespace(sysname, v) : v
247+
v === nothing || return namespace ? renamespace(sysname, v, name) : v
244248

245249
sts = get_states(sys)
246250
i = findfirst(x->getname(x) == name, sts)
247251

248-
if i !== nothing
249-
return namespace ? renamespace(sysname,sts[i]) : sts[i]
250-
end
251-
252-
if has_ps(sys)
253-
ps = get_ps(sys)
254-
i = findfirst(x->getname(x) == name,ps)
255-
if i !== nothing
256-
return namespace ? renamespace(sysname,ps[i]) : ps[i]
257-
end
258-
end
259-
260252
if has_observed(sys)
261253
obs = get_observed(sys)
262254
i = findfirst(x->getname(x.lhs)==name,obs)
@@ -301,13 +293,12 @@ ParentScope(sym::Union{Num, Symbolic}) = setmetadata(sym, SymScope, ParentScope(
301293
struct GlobalScope <: SymScope end
302294
GlobalScope(sym::Union{Num, Symbolic}) = setmetadata(sym, SymScope, GlobalScope())
303295

304-
function renamespace(namespace, x)
305-
if x isa Num
306-
renamespace(namespace, value(x))
307-
elseif x isa Symbolic
296+
function renamespace(namespace, x, name=nothing)
297+
x = unwrap(x)
298+
if x isa Symbolic
308299
let scope = getmetadata(x, SymScope, LocalScope())
309300
if scope isa LocalScope
310-
rename(x, renamespace(namespace, getname(x)))
301+
rename(x, renamespace(namespace, name === nothing ? getname(x) : name))
311302
elseif scope isa ParentScope
312303
setmetadata(x, SymScope, scope.parent)
313304
else # GlobalScope

src/systems/diffeqs/odesystem.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ struct ODESystem <: AbstractODESystem
3232
"""Parameter variables. Must not contain the independent variable."""
3333
ps::Vector
3434
"""Array variables."""
35-
array_vars
35+
var_to_name
3636
"""Control parameters (some subset of `ps`)."""
3737
ctrls::Vector
3838
"""Observed states."""
@@ -84,11 +84,11 @@ struct ODESystem <: AbstractODESystem
8484
"""
8585
connection_type::Any
8686

87-
function ODESystem(deqs, iv, dvs, ps, array_vars, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
87+
function ODESystem(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
8888
check_variables(dvs,iv)
8989
check_parameters(ps,iv)
9090
check_equations(deqs,iv)
91-
new(deqs, iv, dvs, ps, array_vars, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
91+
new(deqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
9292
end
9393
end
9494

@@ -121,9 +121,9 @@ function ODESystem(
121121
dvs′ = value.(scalarize(dvs))
122122
ps′ = value.(scalarize(ps))
123123

124-
array_vars = Dict()
125-
process_variables!(array_vars, defaults, dvs′)
126-
process_variables!(array_vars, defaults, ps′)
124+
var_to_name = Dict()
125+
process_variables!(var_to_name, defaults, dvs′)
126+
process_variables!(var_to_name, defaults, ps′)
127127

128128
tgrad = RefValue(Vector{Num}(undef, 0))
129129
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
@@ -134,7 +134,7 @@ function ODESystem(
134134
if length(unique(sysnames)) != length(sysnames)
135135
throw(ArgumentError("System names must be unique."))
136136
end
137-
ODESystem(deqs, iv′, dvs′, ps′, array_vars, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type)
137+
ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type)
138138
end
139139

140140
vars(x::Sym) = Set([x])

src/systems/diffeqs/sdesystem.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ struct SDESystem <: AbstractODESystem
3838
"""Parameter variables. Must not contain the independent variable."""
3939
ps::Vector
4040
"""Array variables."""
41-
array_vars
41+
var_to_name
4242
"""Control parameters (some subset of `ps`)."""
4343
ctrls::Vector
4444
"""Observed states."""
@@ -86,11 +86,11 @@ struct SDESystem <: AbstractODESystem
8686
"""
8787
connection_type::Any
8888

89-
function SDESystem(deqs, neqs, iv, dvs, ps, array_vars, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
89+
function SDESystem(deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
9090
check_variables(dvs,iv)
9191
check_parameters(ps,iv)
9292
check_equations(deqs,iv)
93-
new(deqs, neqs, iv, dvs, ps, array_vars, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
93+
new(deqs, neqs, iv, dvs, ps, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
9494
end
9595
end
9696

@@ -120,16 +120,16 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
120120
defaults = todict(defaults)
121121
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
122122

123-
array_vars = Dict()
124-
process_variables!(array_vars, defaults, dvs′)
125-
process_variables!(array_vars, defaults, ps′)
123+
var_to_name = Dict()
124+
process_variables!(var_to_name, defaults, dvs′)
125+
process_variables!(var_to_name, defaults, ps′)
126126

127127
tgrad = RefValue(Vector{Num}(undef, 0))
128128
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
129129
ctrl_jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
130130
Wfact = RefValue(Matrix{Num}(undef, 0, 0))
131131
Wfact_t = RefValue(Matrix{Num}(undef, 0, 0))
132-
SDESystem(deqs, neqs, iv′, dvs′, ps′, array_vars, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
132+
SDESystem(deqs, neqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
133133
end
134134

135135
function generate_diffusion_function(sys::SDESystem, dvs = states(sys), ps = parameters(sys); kwargs...)

src/systems/discrete_system/discrete_system.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ struct DiscreteSystem <: AbstractSystem
3131
"""Parameter variables. Must not contain the independent variable."""
3232
ps::Vector
3333
"""Array variables."""
34-
array_vars
34+
var_to_name
3535
"""Control parameters (some subset of `ps`)."""
3636
ctrls::Vector
3737
"""Observed states."""
@@ -54,10 +54,10 @@ struct DiscreteSystem <: AbstractSystem
5454
in `DiscreteSystem`.
5555
"""
5656
default_p::Dict
57-
function DiscreteSystem(discreteEqs, iv, dvs, ps, array_vars, ctrls, observed, name, systems, default_u0, default_p)
57+
function DiscreteSystem(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, default_u0, default_p)
5858
check_variables(dvs,iv)
5959
check_parameters(ps,iv)
60-
new(discreteEqs, iv, dvs, ps, array_vars, ctrls, observed, name, systems, default_u0, default_p)
60+
new(discreteEqs, iv, dvs, ps, var_to_name, ctrls, observed, name, systems, default_u0, default_p)
6161
end
6262
end
6363

@@ -88,15 +88,15 @@ function DiscreteSystem(
8888
defaults = todict(defaults)
8989
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
9090

91-
array_vars = Dict()
92-
process_variables!(array_vars, defaults, dvs′)
93-
process_variables!(array_vars, defaults, ps′)
91+
var_to_name = Dict()
92+
process_variables!(var_to_name, defaults, dvs′)
93+
process_variables!(var_to_name, defaults, ps′)
9494

9595
sysnames = nameof.(systems)
9696
if length(unique(sysnames)) != length(sysnames)
9797
throw(ArgumentError("System names must be unique."))
9898
end
99-
DiscreteSystem(eqs, iv′, dvs′, ps′, array_vars, ctrl′, observed, name, systems, default_u0, default_p)
99+
DiscreteSystem(eqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, name, systems, default_u0, default_p)
100100
end
101101

102102
"""

src/systems/jumps/jumpsystem.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractSystem
3838
"""The parameters of the system. Must not contain the independent variable."""
3939
ps::Vector
4040
"""Array variables."""
41-
array_vars
41+
var_to_name
4242
observed::Vector{Equation}
4343
"""The name of the system. . These are required to have unique names."""
4444
name::Symbol
@@ -53,10 +53,10 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractSystem
5353
type: type of the system
5454
"""
5555
connection_type::Any
56-
function JumpSystem{U}(ap::U, iv, states, ps, array_vars, observed, name, systems, defaults, connection_type) where U <: ArrayPartition
56+
function JumpSystem{U}(ap::U, iv, states, ps, var_to_name, observed, name, systems, defaults, connection_type) where U <: ArrayPartition
5757
check_variables(states, iv)
5858
check_parameters(ps, iv)
59-
new{U}(ap, iv, states, ps, array_vars, observed, name, systems, defaults, connection_type)
59+
new{U}(ap, iv, states, ps, var_to_name, observed, name, systems, defaults, connection_type)
6060
end
6161
end
6262

@@ -94,11 +94,11 @@ function JumpSystem(eqs, iv, states, ps;
9494
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
9595

9696
states, ps = value.(states), value.(ps)
97-
array_vars = Dict()
98-
process_variables!(array_vars, defaults, dvs′)
99-
process_variables!(array_vars, defaults, ps′)
97+
var_to_name = Dict()
98+
process_variables!(var_to_name, defaults, dvs′)
99+
process_variables!(var_to_name, defaults, ps′)
100100

101-
JumpSystem{typeof(ap)}(ap, value(iv), states, ps, array_vars, observed, name, systems, defaults, connection_type)
101+
JumpSystem{typeof(ap)}(ap, value(iv), states, ps, var_to_name, observed, name, systems, defaults, connection_type)
102102
end
103103

104104
function generate_rate_function(js, rate)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ struct NonlinearSystem <: AbstractSystem
2626
"""Parameters."""
2727
ps::Vector
2828
"""Array variables."""
29-
array_vars
29+
var_to_name
3030
observed::Vector{Equation}
3131
"""
3232
Jacobian matrix. Note: this field will not be defined until
@@ -78,11 +78,11 @@ function NonlinearSystem(eqs, states, ps;
7878
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
7979

8080
states, ps = value.(states), value.(ps)
81-
array_vars = Dict()
82-
process_variables!(array_vars, defaults, dvs′)
83-
process_variables!(array_vars, defaults, ps)
81+
var_to_name = Dict()
82+
process_variables!(var_to_name, defaults, states)
83+
process_variables!(var_to_name, defaults, ps)
8484

85-
NonlinearSystem(eqs, states, ps, array_vars, observed, jac, name, systems, defaults, nothing, connection_type)
85+
NonlinearSystem(eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, nothing, connection_type)
8686
end
8787

8888
function calculate_jacobian(sys::NonlinearSystem; sparse=false, simplify=false)

src/systems/optimization/optimizationsystem.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ struct OptimizationSystem <: AbstractSystem
2424
"""Parameters."""
2525
ps::Vector
2626
"""Array variables."""
27-
array_vars
27+
var_to_name
2828
observed::Vector{Equation}
2929
equality_constraints::Vector{Equation}
3030
inequality_constraints::Vector
@@ -63,12 +63,12 @@ function OptimizationSystem(op, states, ps;
6363
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
6464

6565
states, ps = value.(states), value.(ps)
66-
array_vars = Dict()
67-
process_variables!(array_vars, defaults, dvs′)
68-
process_variables!(array_vars, defaults, ps)
66+
var_to_name = Dict()
67+
process_variables!(var_to_name, defaults, states)
68+
process_variables!(var_to_name, defaults, ps)
6969

7070
OptimizationSystem(
71-
value(op), states, ps, array_vars,
71+
value(op), states, ps, var_to_name,
7272
observed,
7373
equality_constraints, inequality_constraints,
7474
name, systems, defaults

src/utils.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,9 @@ hasdefault(v) = hasmetadata(v, Symbolics.VariableDefaultValue)
160160
getdefault(v) = value(getmetadata(v, Symbolics.VariableDefaultValue))
161161
setdefault(v, val) = val === nothing ? v : setmetadata(v, Symbolics.VariableDefaultValue, value(val))
162162

163-
function process_variables!(array_vars, defs, vars)
163+
function process_variables!(var_to_name, defs, vars)
164164
collect_defaults!(defs, vars)
165-
collect_array_vars!(array_vars, vars)
165+
collect_var_to_name!(var_to_name, vars)
166166
return nothing
167167
end
168168

@@ -173,26 +173,26 @@ function collect_defaults!(defs, vars)
173173
return defs
174174
end
175175

176-
function array_vars(x, vars=Dict{Symbol, Any}())
176+
function var_to_name(x, vars=Dict{Symbol, Any}())
177177
x = Symbolics.unwrap(x)
178178
if istree(x)
179179
if hasmetadata(x, Symbolics.GetindexParent)
180180
v = Dict{Symbol, Any}()
181-
foreach(a->array_vars(a, v), arguments(x))
182-
array_vars(operation(x), v)
181+
foreach(a->var_to_name(a, v), arguments(x))
182+
var_to_name(operation(x), v)
183183
name = first(only(v))
184184
vars[name] = getmetadata(x, Symbolics.GetindexParent)
185185
elseif x isa Symbolics.ArrayOp
186186
t = x.term
187187
if istree(t) && operation(t) === (map) && arguments(t)[1] isa Symbolics.CallWith
188188
vars[nameof(arguments(t)[2])] = x
189189
else
190-
array_vars(x.expr, vars)
190+
var_to_name(x.expr, vars)
191191
end
192192
else
193-
array_vars(operation(x), vars)
193+
var_to_name(operation(x), vars)
194194
for a in arguments(x)
195-
array_vars(a, vars)
195+
var_to_name(a, vars)
196196
end
197197
end
198198
elseif x isa Sym && symtype(x) <: AbstractArray
@@ -202,9 +202,9 @@ function array_vars(x, vars=Dict{Symbol, Any}())
202202
vars
203203
end
204204

205-
function collect_array_vars!(vars, xs)
205+
function collect_var_to_name!(vars, xs)
206206
for x in xs
207-
ax = array_vars(x)
207+
ax = var_to_name(x)
208208
if isempty(ax)
209209
vars[getname(x)] = x
210210
else

0 commit comments

Comments
 (0)