Skip to content

Commit da44292

Browse files
feat: better discover scoped parameters in parent systems
1 parent 816fde7 commit da44292

File tree

4 files changed

+121
-10
lines changed

4 files changed

+121
-10
lines changed

src/systems/abstractsystem.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -917,6 +917,15 @@ Mark a system as completed. If a system is complete, the system will no longer
917917
namespace its subsystems or variables, i.e. `isequal(complete(sys).v.i, v.i)`.
918918
"""
919919
function complete(sys::AbstractSystem; split = true)
920+
if !(sys isa JumpSystem)
921+
newunknowns = OrderedSet()
922+
newparams = OrderedSet()
923+
iv = has_iv(sys) ? get_iv(sys) : nothing
924+
collect_scoped_vars!(newunknowns, newparams, sys, iv; depth = -1)
925+
# don't update unknowns to not disturb `structural_simplify` order
926+
# `GlobalScope`d unknowns will be picked up and added there
927+
@set! sys.ps = unique!(vcat(get_ps(sys), collect(newparams)))
928+
end
920929
if split && has_index_cache(sys)
921930
@set! sys.index_cache = IndexCache(sys)
922931
all_ps = parameters(sys)
@@ -3011,6 +3020,14 @@ function compose(sys::AbstractSystem, systems::AbstractArray; name = nameof(sys)
30113020
if has_is_dde(sys)
30123021
@set! sys.is_dde = _check_if_dde(equations(sys), get_iv(sys), get_systems(sys))
30133022
end
3023+
newunknowns = OrderedSet()
3024+
newparams = OrderedSet()
3025+
iv = has_iv(sys) ? get_iv(sys) : nothing
3026+
for ssys in systems
3027+
collect_scoped_vars!(newunknowns, newparams, ssys, iv)
3028+
end
3029+
@set! sys.unknowns = unique!(vcat(get_unknowns(sys), collect(newunknowns)))
3030+
@set! sys.ps = unique!(vcat(get_ps(sys), collect(newparams)))
30143031
return sys
30153032
end
30163033
function compose(syss...; name = nameof(first(syss)))

src/systems/diffeqs/odesystem.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -323,11 +323,13 @@ function ODESystem(eqs, iv; kwargs...)
323323
collect_vars!(allunknowns, ps, eq.rhs, iv)
324324
if isdiffeq(eq)
325325
diffvar, _ = var_from_nested_derivative(eq.lhs)
326-
isequal(iv, iv_from_nested_derivative(eq.lhs)) ||
327-
throw(ArgumentError("An ODESystem can only have one independent variable."))
328-
diffvar in diffvars &&
329-
throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
330-
push!(diffvars, diffvar)
326+
if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0)
327+
isequal(iv, iv_from_nested_derivative(eq.lhs)) ||
328+
throw(ArgumentError("An ODESystem can only have one independent variable."))
329+
diffvar in diffvars &&
330+
throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
331+
push!(diffvars, diffvar)
332+
end
331333
push!(diffeq, eq)
332334
else
333335
push!(algeeq, eq)
@@ -342,6 +344,9 @@ function ODESystem(eqs, iv; kwargs...)
342344
collect_vars!(allunknowns, ps, eq.rhs, iv)
343345
end
344346
end
347+
for ssys in get(kwargs, :systems, ODESystem[])
348+
collect_scoped_vars!(allunknowns, ps, ssys, iv)
349+
end
345350
for v in allunknowns
346351
isdelay(v, iv) || continue
347352
collect_vars!(allunknowns, ps, arguments(v)[1], iv)

src/utils.jl

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -467,23 +467,56 @@ function find_derivatives!(vars, expr, f)
467467
return vars
468468
end
469469

470-
function collect_vars!(unknowns, parameters, expr, iv; op = Differential)
470+
"""
471+
$(TYPEDSIGNATURES)
472+
473+
Search through equations and parameter dependencies of `sys`, where sys is at a depth of
474+
`depth` from the root system, looking for variables scoped to the root system. Also
475+
recursively searches through all subsystems of `sys`, increasing the depth if it is not
476+
`-1`. A depth of `-1` indicates searching for variables with `GlobalScope`.
477+
"""
478+
function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Differential)
479+
if has_eqs(sys)
480+
for eq in get_eqs(sys)
481+
eq.lhs isa Union{Symbolic, Number} || continue
482+
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
483+
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
484+
end
485+
end
486+
if has_parameter_dependencies(sys)
487+
for eq in get_parameter_dependencies(sys)
488+
if eq isa Pair
489+
collect_vars!(unknowns, parameters, eq[1], iv; depth, op)
490+
collect_vars!(unknowns, parameters, eq[2], iv; depth, op)
491+
else
492+
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
493+
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
494+
end
495+
end
496+
end
497+
newdepth = depth == -1 ? depth : depth + 1
498+
for ssys in get_systems(sys)
499+
collect_scoped_vars!(unknowns, parameters, ssys, iv; depth = newdepth, op)
500+
end
501+
end
502+
503+
function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Differential)
471504
if issym(expr)
472-
collect_var!(unknowns, parameters, expr, iv)
505+
collect_var!(unknowns, parameters, expr, iv; depth)
473506
else
474507
for var in vars(expr; op)
475508
if iscall(var) && operation(var) isa Differential
476509
var, _ = var_from_nested_derivative(var)
477510
end
478-
collect_var!(unknowns, parameters, var, iv)
511+
collect_var!(unknowns, parameters, var, iv; depth)
479512
end
480513
end
481514
return nothing
482515
end
483516

484-
function collect_var!(unknowns, parameters, var, iv)
517+
function collect_var!(unknowns, parameters, var, iv; depth = 0)
485518
isequal(var, iv) && return nothing
486-
getmetadata(var, SymScope, LocalScope()) == LocalScope() || return nothing
519+
check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing
487520
if iscalledparameter(var)
488521
callable = getcalledparameter(var)
489522
push!(parameters, callable)
@@ -500,6 +533,24 @@ function collect_var!(unknowns, parameters, var, iv)
500533
return nothing
501534
end
502535

536+
"""
537+
$(TYPEDSIGNATURES)
538+
539+
Check if the given `scope` is at a depth of `depth` from the root system. Only
540+
returns `true` for `scope::GlobalScope` if `depth == -1`.
541+
"""
542+
function check_scope_depth(scope, depth)
543+
if scope isa LocalScope
544+
return depth == 0
545+
elseif scope isa ParentScope
546+
return depth > 0 && check_scope_depth(scope.parent, depth - 1)
547+
elseif scope isa DelayParentScope
548+
return depth >= scope.N && check_scope_depth(scope.parent, depth - scope.N)
549+
elseif scope isa GlobalScope
550+
return depth == -1
551+
end
552+
end
553+
503554
"""
504555
Find all the symbolic constants of some equations or terms and return them as a vector.
505556
"""

test/variable_scope.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,41 @@ bar = complete(bar)
101101
defs = ModelingToolkit.defaults(bar)
102102
@test defs[bar.p] == 2
103103
@test isequal(defs[bar.foo.p], bar.p)
104+
105+
# Issue#3101
106+
@variables x1(t) x2(t) x3(t) x4(t) x5(t)
107+
x2 = ParentScope(x2)
108+
x3 = ParentScope(ParentScope(x3))
109+
x4 = DelayParentScope(x4, 2)
110+
x5 = GlobalScope(x5)
111+
@parameters p1 p2 p3 p4 p5
112+
p2 = ParentScope(p2)
113+
p3 = ParentScope(ParentScope(p3))
114+
p4 = DelayParentScope(p4, 2)
115+
p5 = GlobalScope(p5)
116+
117+
@named sys1 = ODESystem([D(x1) ~ p1, D(x2) ~ p2, D(x3) ~ p3, D(x4) ~ p4, D(x5) ~ p5], t)
118+
@test isequal(x1, only(unknowns(sys1)))
119+
@test isequal(p1, only(parameters(sys1)))
120+
@named sys2 = ODESystem(Equation[], t; systems = [sys1])
121+
@test length(unknowns(sys2)) == 2
122+
@test any(isequal(x2), unknowns(sys2))
123+
@test length(parameters(sys2)) == 2
124+
@test any(isequal(p2), parameters(sys2))
125+
@named sys3 = ODESystem(Equation[], t)
126+
sys3 = sys3 sys2
127+
@test length(unknowns(sys3)) == 4
128+
@test any(isequal(x3), unknowns(sys3))
129+
@test any(isequal(x4), unknowns(sys3))
130+
@test length(parameters(sys3)) == 4
131+
@test any(isequal(p3), parameters(sys3))
132+
@test any(isequal(p4), parameters(sys3))
133+
sys4 = complete(sys3)
134+
@test length(unknowns(sys3)) == 4
135+
@test length(parameters(sys4)) == 5
136+
@test any(isequal(p5), parameters(sys4))
137+
sys5 = structural_simplify(sys3)
138+
@test length(unknowns(sys5)) == 5
139+
@test any(isequal(x5), unknowns(sys5))
140+
@test length(parameters(sys5)) == 5
141+
@test any(isequal(p5), parameters(sys5))

0 commit comments

Comments
 (0)