Skip to content

Commit 9974304

Browse files
feat: better discover scoped parameters in parent systems
1 parent 816fde7 commit 9974304

File tree

4 files changed

+116
-10
lines changed

4 files changed

+116
-10
lines changed

src/systems/abstractsystem.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -917,6 +917,13 @@ Mark a system as completed. If a system is complete, the system will no longer
917917
namespace its subsystems or variables, i.e. `isequal(complete(sys).v.i, v.i)`.
918918
"""
919919
function complete(sys::AbstractSystem; split = true)
920+
newunknowns = OrderedSet()
921+
newparams = OrderedSet()
922+
iv = has_iv(sys) ? get_iv(sys) : nothing
923+
collect_scoped_vars!(newunknowns, newparams, sys, iv; depth = -1)
924+
# don't update unknowns to not disturb `structural_simplify` order
925+
# `GlobalScope`d unknowns will be picked up and added there
926+
@set! sys.ps = unique!(vcat(get_ps(sys), collect(newparams)))
920927
if split && has_index_cache(sys)
921928
@set! sys.index_cache = IndexCache(sys)
922929
all_ps = parameters(sys)
@@ -3011,6 +3018,14 @@ function compose(sys::AbstractSystem, systems::AbstractArray; name = nameof(sys)
30113018
if has_is_dde(sys)
30123019
@set! sys.is_dde = _check_if_dde(equations(sys), get_iv(sys), get_systems(sys))
30133020
end
3021+
newunknowns = OrderedSet()
3022+
newparams = OrderedSet()
3023+
iv = has_iv(sys) ? get_iv(sys) : nothing
3024+
for ssys in systems
3025+
collect_scoped_vars!(newunknowns, newparams, ssys, iv)
3026+
end
3027+
@set! sys.unknowns = unique!(vcat(get_unknowns(sys), collect(newunknowns)))
3028+
@set! sys.ps = unique!(vcat(get_ps(sys), collect(newparams)))
30143029
return sys
30153030
end
30163031
function compose(syss...; name = nameof(first(syss)))

src/systems/diffeqs/odesystem.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -323,11 +323,13 @@ function ODESystem(eqs, iv; kwargs...)
323323
collect_vars!(allunknowns, ps, eq.rhs, iv)
324324
if isdiffeq(eq)
325325
diffvar, _ = var_from_nested_derivative(eq.lhs)
326-
isequal(iv, iv_from_nested_derivative(eq.lhs)) ||
327-
throw(ArgumentError("An ODESystem can only have one independent variable."))
328-
diffvar in diffvars &&
329-
throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
330-
push!(diffvars, diffvar)
326+
if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0)
327+
isequal(iv, iv_from_nested_derivative(eq.lhs)) ||
328+
throw(ArgumentError("An ODESystem can only have one independent variable."))
329+
diffvar in diffvars &&
330+
throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
331+
push!(diffvars, diffvar)
332+
end
331333
push!(diffeq, eq)
332334
else
333335
push!(algeeq, eq)
@@ -342,6 +344,9 @@ function ODESystem(eqs, iv; kwargs...)
342344
collect_vars!(allunknowns, ps, eq.rhs, iv)
343345
end
344346
end
347+
for ssys in get(kwargs, :systems, ODESystem[])
348+
collect_scoped_vars!(allunknowns, ps, ssys, iv)
349+
end
345350
for v in allunknowns
346351
isdelay(v, iv) || continue
347352
collect_vars!(allunknowns, ps, arguments(v)[1], iv)

src/utils.jl

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -467,23 +467,53 @@ function find_derivatives!(vars, expr, f)
467467
return vars
468468
end
469469

470-
function collect_vars!(unknowns, parameters, expr, iv; op = Differential)
470+
"""
471+
$(TYPEDSIGNATURES)
472+
473+
Search through equations and parameter dependencies of `sys`, where sys is at a depth of
474+
`depth` from the root system, looking for variables scoped to the root system. Also
475+
recursively searches through all subsystems of `sys`, increasing the depth if it is not
476+
`-1`. A depth of `-1` indicates searching for variables with `GlobalScope`.
477+
"""
478+
function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Differential)
479+
for eq in get_eqs(sys)
480+
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
481+
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
482+
end
483+
if has_parameter_dependencies(sys)
484+
for eq in get_parameter_dependencies(sys)
485+
if eq isa Pair
486+
collect_vars!(unknowns, parameters, eq[1], iv; depth, op)
487+
collect_vars!(unknowns, parameters, eq[2], iv; depth, op)
488+
else
489+
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
490+
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
491+
end
492+
end
493+
end
494+
newdepth = depth == -1 ? depth : depth + 1
495+
for ssys in get_systems(sys)
496+
collect_scoped_vars!(unknowns, parameters, ssys, iv; depth = newdepth, op)
497+
end
498+
end
499+
500+
function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Differential)
471501
if issym(expr)
472-
collect_var!(unknowns, parameters, expr, iv)
502+
collect_var!(unknowns, parameters, expr, iv; depth)
473503
else
474504
for var in vars(expr; op)
475505
if iscall(var) && operation(var) isa Differential
476506
var, _ = var_from_nested_derivative(var)
477507
end
478-
collect_var!(unknowns, parameters, var, iv)
508+
collect_var!(unknowns, parameters, var, iv; depth)
479509
end
480510
end
481511
return nothing
482512
end
483513

484-
function collect_var!(unknowns, parameters, var, iv)
514+
function collect_var!(unknowns, parameters, var, iv; depth = 0)
485515
isequal(var, iv) && return nothing
486-
getmetadata(var, SymScope, LocalScope()) == LocalScope() || return nothing
516+
check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing
487517
if iscalledparameter(var)
488518
callable = getcalledparameter(var)
489519
push!(parameters, callable)
@@ -500,6 +530,24 @@ function collect_var!(unknowns, parameters, var, iv)
500530
return nothing
501531
end
502532

533+
"""
534+
$(TYPEDSIGNATURES)
535+
536+
Check if the given `scope` is at a depth of `depth` from the root system. Only
537+
returns `true` for `scope::GlobalScope` if `depth == -1`.
538+
"""
539+
function check_scope_depth(scope, depth)
540+
if scope isa LocalScope
541+
return depth == 0
542+
elseif scope isa ParentScope
543+
return depth > 0 && check_scope_depth(scope.parent, depth - 1)
544+
elseif scope isa DelayParentScope
545+
return depth >= scope.N && check_scope_depth(scope.parent, depth - scope.N)
546+
elseif scope isa GlobalScope
547+
return depth == -1
548+
end
549+
end
550+
503551
"""
504552
Find all the symbolic constants of some equations or terms and return them as a vector.
505553
"""

test/variable_scope.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,41 @@ bar = complete(bar)
101101
defs = ModelingToolkit.defaults(bar)
102102
@test defs[bar.p] == 2
103103
@test isequal(defs[bar.foo.p], bar.p)
104+
105+
# Issue#3101
106+
@variables x1(t) x2(t) x3(t) x4(t) x5(t)
107+
x2 = ParentScope(x2)
108+
x3 = ParentScope(ParentScope(x3))
109+
x4 = DelayParentScope(x4, 2)
110+
x5 = GlobalScope(x5)
111+
@parameters p1 p2 p3 p4 p5
112+
p2 = ParentScope(p2)
113+
p3 = ParentScope(ParentScope(p3))
114+
p4 = DelayParentScope(p4, 2)
115+
p5 = GlobalScope(p5)
116+
117+
@named sys1 = ODESystem([D(x1) ~ p1, D(x2) ~ p2, D(x3) ~ p3, D(x4) ~ p4, D(x5) ~ p5], t)
118+
@test isequal(x1, only(unknowns(sys1)))
119+
@test isequal(p1, only(parameters(sys1)))
120+
@named sys2 = ODESystem(Equation[], t; systems = [sys1])
121+
@test length(unknowns(sys2)) == 2
122+
@test any(isequal(x2), unknowns(sys2))
123+
@test length(parameters(sys2)) == 2
124+
@test any(isequal(p2), parameters(sys2))
125+
@named sys3 = ODESystem(Equation[], t)
126+
sys3 = sys3 sys2
127+
@test length(unknowns(sys3)) == 4
128+
@test any(isequal(x3), unknowns(sys3))
129+
@test any(isequal(x4), unknowns(sys3))
130+
@test length(parameters(sys3)) == 4
131+
@test any(isequal(p3), parameters(sys3))
132+
@test any(isequal(p4), parameters(sys3))
133+
sys4 = complete(sys3)
134+
@test length(unknowns(sys3)) == 4
135+
@test length(parameters(sys4)) == 5
136+
@test any(isequal(p5), parameters(sys4))
137+
sys5 = structural_simplify(sys3)
138+
@test length(unknowns(sys5)) == 5
139+
@test any(isequal(x5), unknowns(sys5))
140+
@test length(parameters(sys5)) == 5
141+
@test any(isequal(p5), parameters(sys5))

0 commit comments

Comments
 (0)