Skip to content

Commit 755ebdf

Browse files
authored
Merge pull request #760 from SciML/myb/getvar
Refactoring and enable NonlinearSystem nesting
2 parents a340758 + d8a9ea3 commit 755ebdf

23 files changed

+384
-401
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2424
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2525
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
2626
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
27+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2728
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2829
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2930
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
@@ -51,10 +52,11 @@ RecursiveArrayTools = "2.3"
5152
Requires = "1.0"
5253
RuntimeGeneratedFunctions = "0.4, 0.5"
5354
SafeTestsets = "0.0.1"
55+
SciMLBase = "1.3"
5456
Setfield = "0.7"
5557
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0"
5658
StaticArrays = "0.10, 0.11, 0.12, 1.0"
57-
SymbolicUtils = "0.7.4"
59+
SymbolicUtils = "0.7.4, 0.8"
5860
TreeViews = "0.3"
5961
UnPack = "0.1, 1.0"
6062
Unitful = "1.1"

src/ModelingToolkit.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ RuntimeGeneratedFunctions.init(@__MODULE__)
2323
using RecursiveArrayTools
2424

2525
import SymbolicUtils
26-
import SymbolicUtils: Term, Add, Mul, Pow, Sym, to_symbolic, FnType,
26+
import SymbolicUtils: Term, Add, Mul, Pow, Sym, FnType,
2727
@rule, Rewriters, substitute, similarterm,
2828
promote_symtype
2929

@@ -55,8 +55,6 @@ value(x) = x
5555
value(x::Num) = x.val
5656

5757

58-
using SymbolicUtils: to_symbolic
59-
SymbolicUtils.to_symbolic(n::Num) = value(n)
6058
SymbolicUtils.@number_methods(Num,
6159
Num(f(value(a))),
6260
Num(f(value(a), value(b))))
@@ -266,7 +264,7 @@ export Differential, expand_derivatives, @derivatives
266264
export IntervalDomain, ProductDomain, , CircleDomain
267265
export Equation, ConstrainedEquation
268266
export Term, Sym
269-
export independent_variable, states, parameters, equations, controls, pins, observed
267+
export independent_variable, states, parameters, equations, controls, observed, structure
270268

271269
export calculate_jacobian, generate_jacobian, generate_function
272270
export calculate_tgrad, generate_tgrad
@@ -275,6 +273,7 @@ export calculate_factorized_W, generate_factorized_W
275273
export calculate_hessian, generate_hessian
276274
export calculate_massmatrix, generate_diffusion_function
277275
export stochastic_integral_transform
276+
export initialize_system_structure
278277

279278
export BipartiteGraph, equation_dependencies, variable_dependencies
280279
export eqeq_dependencies, varvar_dependencies

src/build_function.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ function unflatten_long_ops(op, N=4)
9393
rule2 = @rule((*)((~~x)) => length(~~x) > N ?
9494
*(*((~~x)[1:N]...) * (*)((~~x)[N+1:end]...)) : nothing)
9595

96-
op = to_symbolic(op)
96+
op = value(op)
9797
Rewriters.Fixpoint(Rewriters.Postwalk(Rewriters.Chain([rule1, rule2])))(op)
9898
end
9999

src/differentials.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Base.:*(D1::Differential, D2) = D1 ∘ D2
4242
Base.:*(D1::Differential, D2::Differential) = D1 D2
4343
Base.:^(D::Differential, n::Integer) = _repeat_apply(D, n)
4444

45-
Base.show(io::IO, D::Differential) = print(io, "(D'~", D.x, ")")
45+
Base.show(io::IO, D::Differential) = print(io, "Differential(", D.x, ")")
4646

4747
Base.:(==)(D1::Differential, D2::Differential) = isequal(D1.x, D2.x)
4848

src/systems/abstractsystem.jl

Lines changed: 91 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ Generate a function to evaluate the system's equations.
117117
"""
118118
function generate_function end
119119

120+
Base.nameof(sys::AbstractSystem) = getfield(sys, :name)
121+
120122
function getname(t)
121123
if istree(t)
122124
operation(t) isa Sym ? getname(operation(t)) : error("Cannot get name of $t")
@@ -125,34 +127,54 @@ function getname(t)
125127
end
126128
end
127129

128-
function Base.getproperty(sys::AbstractSystem, name::Symbol)
130+
independent_variable(sys::AbstractSystem) = getfield(sys, :iv)
129131

130-
if name fieldnames(typeof(sys))
131-
return getfield(sys,name)
132-
elseif !isempty(sys.systems)
133-
i = findfirst(x->x.name==name,sys.systems)
132+
function structure(sys::AbstractSystem)
133+
s = get_structure(sys)
134+
s isa SystemStructure || throw(ArgumentError("SystemStructure is not yet initialized, please run `sys = initialize_system_structure(sys)` or `sys = alias_elimination(sys)`."))
135+
return s
136+
end
137+
138+
for prop in [:eqs, :iv, :states, :ps, :default_p, :default_u0, :observed, :tgrad, :jac, :Wfact, :Wfact_t, :systems, :structure]
139+
fname = Symbol(:get_, prop)
140+
@eval begin
141+
$fname(sys::AbstractSystem) = getfield(sys, $(QuoteNode(prop)))
142+
end
143+
end
144+
145+
function Base.getproperty(sys::AbstractSystem, name::Symbol)
146+
sysname = nameof(sys)
147+
systems = get_systems(sys)
148+
if isdefined(sys, name)
149+
Base.depwarn("`sys.name` like `sys.$name` is deprecated. Use getters like `get_$name` instead.", "sys.$name")
150+
return getfield(sys, name)
151+
elseif !isempty(systems)
152+
i = findfirst(x->nameof(x)==name,systems)
134153
if i !== nothing
135-
return rename(sys.systems[i],renamespace(sys.name,name))
154+
return rename(systems[i],renamespace(sysname,name))
136155
end
137156
end
138157

139-
i = findfirst(x->getname(x) == name, sys.states)
158+
sts = get_states(sys)
159+
i = findfirst(x->getname(x) == name, sts)
140160

141161
if i !== nothing
142-
return rename(sys.states[i],renamespace(sys.name,name))
162+
return rename(sts[i],renamespace(sysname,name))
143163
end
144164

145-
if :ps fieldnames(typeof(sys))
146-
i = findfirst(x->getname(x) == name,sys.ps)
165+
if isdefined(sys, :ps)
166+
ps = get_ps(sys)
167+
i = findfirst(x->getname(x) == name,ps)
147168
if i !== nothing
148-
return rename(sys.ps[i],renamespace(sys.name,name))
169+
return rename(ps[i],renamespace(sysname,name))
149170
end
150171
end
151172

152-
if :observed fieldnames(typeof(sys))
153-
i = findfirst(x->getname(x.lhs)==name,sys.observed)
173+
if isdefined(sys, :observed)
174+
obs = get_observed(sys)
175+
i = findfirst(x->getname(x.lhs)==name,obs)
154176
if i !== nothing
155-
return rename(sys.observed[i].lhs,renamespace(sys.name,name))
177+
return rename(obs[i].lhs,renamespace(sysname,name))
156178
end
157179
end
158180

@@ -171,75 +193,89 @@ function renamespace(namespace, x)
171193
end
172194
end
173195

174-
function namespace_variables(sys::AbstractSystem)
175-
[renamespace(sys.name,x) for x in states(sys)]
176-
end
196+
namespace_variables(sys::AbstractSystem) = states(sys, states(sys))
197+
namespace_parameters(sys::AbstractSystem) = parameters(sys, parameters(sys))
177198

178-
function namespace_parameters(sys::AbstractSystem)
179-
[toparam(renamespace(sys.name,x)) for x in parameters(sys)]
199+
function namespace_default_u0(sys)
200+
d_u0 = default_u0(sys)
201+
Dict(states(sys, k) => d_u0[k] for k in keys(d_u0))
180202
end
181203

182-
function namespace_pins(sys::AbstractSystem)
183-
[renamespace(sys.name,x) for x in pins(sys)]
204+
function namespace_default_p(sys)
205+
d_p = default_p(sys)
206+
Dict(parameters(sys, k) => d_p[k] for k in keys(d_p))
184207
end
185208

186209
function namespace_equations(sys::AbstractSystem)
187210
eqs = equations(sys)
188211
isempty(eqs) && return Equation[]
189-
map(eq->namespace_equation(eq,sys.name,sys.iv.name), eqs)
212+
iv = independent_variable(sys)
213+
map(eq->namespace_equation(eq,nameof(sys),iv), eqs)
190214
end
191215

192-
function namespace_equation(eq::Equation,name,ivname)
193-
_lhs = namespace_expr(eq.lhs,name,ivname)
194-
_rhs = namespace_expr(eq.rhs,name,ivname)
216+
function namespace_equation(eq::Equation,name,iv)
217+
_lhs = namespace_expr(eq.lhs,name,iv)
218+
_rhs = namespace_expr(eq.rhs,name,iv)
195219
_lhs ~ _rhs
196220
end
197221

198-
function namespace_expr(O::Sym,name,ivname)
199-
O.name == ivname ? O : rename(O,renamespace(name,O.name))
222+
function namespace_expr(O::Sym,name,iv)
223+
isequal(O, iv) ? O : rename(O,renamespace(name,nameof(O)))
200224
end
201225

202226
_symparam(s::Symbolic{T}) where {T} = T
203-
function namespace_expr(O,name,ivname) where {T}
227+
function namespace_expr(O,name,iv) where {T}
204228
if istree(O)
229+
renamed = map(a->namespace_expr(a,name,iv), arguments(O))
205230
if operation(O) isa Sym
206-
Term{_symparam(O)}(rename(operation(O),renamespace(name,operation(O).name)),namespace_expr.(arguments(O),name,ivname))
231+
renamed_op = rename(operation(O),renamespace(name,nameof(operation(O))))
232+
Term{_symparam(O)}(renamed_op,renamed)
207233
else
208-
similarterm(O,operation(O),namespace_expr.(arguments(O),name,ivname))
234+
similarterm(O,operation(O),renamed)
209235
end
210236
else
211237
O
212238
end
213239
end
214240

215-
independent_variable(sys::AbstractSystem) = sys.iv
216241
function states(sys::AbstractSystem)
217-
unique(isempty(sys.systems) ?
218-
sys.states :
219-
[sys.states;reduce(vcat,namespace_variables.(sys.systems))])
242+
sts = get_states(sys)
243+
systems = get_systems(sys)
244+
unique(isempty(systems) ?
245+
sts :
246+
[sts;reduce(vcat,namespace_variables.(systems))])
247+
end
248+
function parameters(sys::AbstractSystem)
249+
ps = get_ps(sys)
250+
systems = get_systems(sys)
251+
isempty(systems) ? ps : [ps;reduce(vcat,namespace_parameters.(systems))]
220252
end
221-
parameters(sys::AbstractSystem) = isempty(sys.systems) ? sys.ps : [sys.ps;reduce(vcat,namespace_parameters.(sys.systems))]
222-
pins(sys::AbstractSystem) = isempty(sys.systems) ? sys.pins : [sys.pins;reduce(vcat,namespace_pins.(sys.systems))]
223253
function observed(sys::AbstractSystem)
224-
[sys.observed;
254+
iv = independent_variable(sys)
255+
obs = get_observed(sys)
256+
systems = get_systems(sys)
257+
[obs;
225258
reduce(vcat,
226-
(namespace_equation.(observed(s), s.name, s.iv.name) for s in sys.systems),
259+
(map(o->namespace_equation(o, nameof(s), iv), observed(s)) for s in systems),
227260
init=Equation[])]
228261
end
229262

230-
function states(sys::AbstractSystem,name::Symbol)
231-
x = sys.states[findfirst(x->x.name==name,sys.states)]
232-
rename(x,renamespace(sys.name,x.name))(sys.iv)
263+
function default_u0(sys::AbstractSystem)
264+
systems = get_systems(sys)
265+
d_u0 = get_default_u0(sys)
266+
isempty(systems) ? d_u0 : mapreduce(namespace_default_u0, merge, systems; init=d_u0)
233267
end
234268

235-
function parameters(sys::AbstractSystem,name::Symbol)
236-
x = sys.ps[findfirst(x->x.name==name,sys.ps)]
237-
rename(x,renamespace(sys.name,x.name))()
269+
function default_p(sys::AbstractSystem)
270+
systems = get_systems(sys)
271+
d_p = get_default_p(sys)
272+
isempty(systems) ? d_p : mapreduce(namespace_default_p, merge, systems; init=d_p)
238273
end
239274

240-
function pins(sys::AbstractSystem,name::Symbol)
241-
x = sys.pins[findfirst(x->x.name==name,sys.ps)]
242-
rename(x,renamespace(sys.name,x.name))(sys.iv())
275+
states(sys::AbstractSystem, v) = renamespace(nameof(sys), v)
276+
parameters(sys::AbstractSystem, v) = toparam(states(sys, v))
277+
for f in [:states, :parameters]
278+
@eval $f(sys::AbstractSystem, vs::AbstractArray) = map(v->$f(sys, v), vs)
243279
end
244280

245281
lhss(xs) = map(x->x.lhs, xs)
@@ -248,57 +284,35 @@ rhss(xs) = map(x->x.rhs, xs)
248284
flatten(sys::AbstractSystem) = sys
249285

250286
function equations(sys::ModelingToolkit.AbstractSystem)
251-
if isempty(sys.systems)
252-
return sys.eqs
287+
eqs = get_eqs(sys)
288+
systems = get_systems(sys)
289+
if isempty(systems)
290+
return eqs
253291
else
254-
eqs = Equation[sys.eqs;
292+
eqs = Equation[eqs;
255293
reduce(vcat,
256-
namespace_equations.(sys.systems);
294+
namespace_equations.(get_systems(sys));
257295
init=Equation[])]
258296
return eqs
259297
end
260298
end
261299

262-
function states(sys::AbstractSystem,args...)
263-
name = last(args)
264-
extra_names = reduce(Symbol,[Symbol(:₊,x.name) for x in args[1:end-1]])
265-
newname = renamespace(extra_names,name)
266-
rename(x,renamespace(sys.name,newname))(sys.iv)
267-
end
268-
269-
function parameters(sys::AbstractSystem,args...)
270-
name = last(args)
271-
extra_names = reduce(Symbol,[Symbol(:₊,x.name) for x in args[1:end-1]])
272-
newname = renamespace(extra_names,name)
273-
rename(x,renamespace(sys.name,newname))()
274-
end
275-
276300
function islinear(sys::AbstractSystem)
277301
rhs = [eq.rhs for eq equations(sys)]
278302

279303
all(islinear(r, states(sys)) for r in rhs)
280304
end
281305

282-
function pins(sys::AbstractSystem,args...)
283-
name = last(args)
284-
extra_names = reduce(Symbol,[Symbol(:₊,x.name) for x in args[1:end-1]])
285-
newname = renamespace(extra_names,name)
286-
rename(x,renamespace(sys.name,newname))(sys.iv())
287-
end
288-
289306
struct AbstractSysToExpr
290307
sys::AbstractSystem
291308
states::Vector
292309
end
293310
AbstractSysToExpr(sys) = AbstractSysToExpr(sys,states(sys))
294311
function (f::AbstractSysToExpr)(O)
295312
!istree(O) && return toexpr(O)
296-
any(isequal(O), f.states) && return operation(O).name # variables
313+
any(isequal(O), f.states) && return nameof(operation(O)) # variables
297314
if isa(operation(O), Sym)
298-
return build_expr(:call, Any[operation(O).name; f.(arguments(O))])
315+
return build_expr(:call, Any[nameof(operation(O)); f.(arguments(O))])
299316
end
300317
return build_expr(:call, Any[operation(O); f.(arguments(O))])
301318
end
302-
303-
get_default_p(sys) = sys.default_p
304-
get_default_u0(sys) = sys.default_u0

src/systems/control/controlsystem.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ struct ControlSystem <: AbstractControlSystem
5858
controls::Vector
5959
"""Parameter variables."""
6060
ps::Vector
61-
pins::Vector
6261
observed::Vector{Equation}
6362
"""
6463
Name: the name of the system
@@ -81,7 +80,6 @@ struct ControlSystem <: AbstractControlSystem
8180
end
8281

8382
function ControlSystem(loss, deqs::AbstractVector{<:Equation}, iv, dvs, controls, ps;
84-
pins = [],
8583
observed = [],
8684
systems = ODESystem[],
8785
default_u0=Dict(),
@@ -94,7 +92,7 @@ function ControlSystem(loss, deqs::AbstractVector{<:Equation}, iv, dvs, controls
9492
default_u0 isa Dict || (default_u0 = Dict(default_u0))
9593
default_p isa Dict || (default_p = Dict(default_p))
9694
ControlSystem(value(loss), deqs, iv′, dvs′, controls′,
97-
ps′, pins, observed, name, systems, default_u0, default_p)
95+
ps′, observed, name, systems, default_u0, default_p)
9896
end
9997

10098
struct ControlToExpr

0 commit comments

Comments
 (0)