Skip to content

Commit 7c90426

Browse files
Merge pull request #3112 from AayushSabharwal/as/better-scoping
feat: better discover scoped parameters in parent systems
2 parents c44b645 + 5d7d97d commit 7c90426

File tree

4 files changed

+127
-11
lines changed

4 files changed

+127
-11
lines changed

src/systems/abstractsystem.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -913,10 +913,24 @@ end
913913
"""
914914
$(TYPEDSIGNATURES)
915915
916-
Mark a system as completed. If a system is complete, the system will no longer
916+
Mark a system as completed. A completed system is a system which is done being
917+
defined/modified and is ready for structural analysis or other transformations.
918+
This allows for analyses and optimizations to be performed which require knowing
919+
the global structure of the system.
920+
921+
One property to note is that if a system is complete, the system will no longer
917922
namespace its subsystems or variables, i.e. `isequal(complete(sys).v.i, v.i)`.
918923
"""
919924
function complete(sys::AbstractSystem; split = true)
925+
if !(sys isa JumpSystem)
926+
newunknowns = OrderedSet()
927+
newparams = OrderedSet()
928+
iv = has_iv(sys) ? get_iv(sys) : nothing
929+
collect_scoped_vars!(newunknowns, newparams, sys, iv; depth = -1)
930+
# don't update unknowns to not disturb `structural_simplify` order
931+
# `GlobalScope`d unknowns will be picked up and added there
932+
@set! sys.ps = unique!(vcat(get_ps(sys), collect(newparams)))
933+
end
920934
if split && has_index_cache(sys)
921935
@set! sys.index_cache = IndexCache(sys)
922936
all_ps = parameters(sys)
@@ -3011,6 +3025,14 @@ function compose(sys::AbstractSystem, systems::AbstractArray; name = nameof(sys)
30113025
if has_is_dde(sys)
30123026
@set! sys.is_dde = _check_if_dde(equations(sys), get_iv(sys), get_systems(sys))
30133027
end
3028+
newunknowns = OrderedSet()
3029+
newparams = OrderedSet()
3030+
iv = has_iv(sys) ? get_iv(sys) : nothing
3031+
for ssys in systems
3032+
collect_scoped_vars!(newunknowns, newparams, ssys, iv)
3033+
end
3034+
@set! sys.unknowns = unique!(vcat(get_unknowns(sys), collect(newunknowns)))
3035+
@set! sys.ps = unique!(vcat(get_ps(sys), collect(newparams)))
30143036
return sys
30153037
end
30163038
function compose(syss...; name = nameof(first(syss)))

src/systems/diffeqs/odesystem.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -323,11 +323,13 @@ function ODESystem(eqs, iv; kwargs...)
323323
collect_vars!(allunknowns, ps, eq.rhs, iv)
324324
if isdiffeq(eq)
325325
diffvar, _ = var_from_nested_derivative(eq.lhs)
326-
isequal(iv, iv_from_nested_derivative(eq.lhs)) ||
327-
throw(ArgumentError("An ODESystem can only have one independent variable."))
328-
diffvar in diffvars &&
329-
throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
330-
push!(diffvars, diffvar)
326+
if check_scope_depth(getmetadata(diffvar, SymScope, LocalScope()), 0)
327+
isequal(iv, iv_from_nested_derivative(eq.lhs)) ||
328+
throw(ArgumentError("An ODESystem can only have one independent variable."))
329+
diffvar in diffvars &&
330+
throw(ArgumentError("The differential variable $diffvar is not unique in the system of equations."))
331+
push!(diffvars, diffvar)
332+
end
331333
push!(diffeq, eq)
332334
else
333335
push!(algeeq, eq)
@@ -342,6 +344,9 @@ function ODESystem(eqs, iv; kwargs...)
342344
collect_vars!(allunknowns, ps, eq.rhs, iv)
343345
end
344346
end
347+
for ssys in get(kwargs, :systems, ODESystem[])
348+
collect_scoped_vars!(allunknowns, ps, ssys, iv)
349+
end
345350
for v in allunknowns
346351
isdelay(v, iv) || continue
347352
collect_vars!(allunknowns, ps, arguments(v)[1], iv)

src/utils.jl

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -473,23 +473,56 @@ function find_derivatives!(vars, expr, f)
473473
return vars
474474
end
475475

476-
function collect_vars!(unknowns, parameters, expr, iv; op = Differential)
476+
"""
477+
$(TYPEDSIGNATURES)
478+
479+
Search through equations and parameter dependencies of `sys`, where sys is at a depth of
480+
`depth` from the root system, looking for variables scoped to the root system. Also
481+
recursively searches through all subsystems of `sys`, increasing the depth if it is not
482+
`-1`. A depth of `-1` indicates searching for variables with `GlobalScope`.
483+
"""
484+
function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Differential)
485+
if has_eqs(sys)
486+
for eq in get_eqs(sys)
487+
eq.lhs isa Union{Symbolic, Number} || continue
488+
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
489+
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
490+
end
491+
end
492+
if has_parameter_dependencies(sys)
493+
for eq in get_parameter_dependencies(sys)
494+
if eq isa Pair
495+
collect_vars!(unknowns, parameters, eq[1], iv; depth, op)
496+
collect_vars!(unknowns, parameters, eq[2], iv; depth, op)
497+
else
498+
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
499+
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
500+
end
501+
end
502+
end
503+
newdepth = depth == -1 ? depth : depth + 1
504+
for ssys in get_systems(sys)
505+
collect_scoped_vars!(unknowns, parameters, ssys, iv; depth = newdepth, op)
506+
end
507+
end
508+
509+
function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Differential)
477510
if issym(expr)
478-
collect_var!(unknowns, parameters, expr, iv)
511+
collect_var!(unknowns, parameters, expr, iv; depth)
479512
else
480513
for var in vars(expr; op)
481514
if iscall(var) && operation(var) isa Differential
482515
var, _ = var_from_nested_derivative(var)
483516
end
484-
collect_var!(unknowns, parameters, var, iv)
517+
collect_var!(unknowns, parameters, var, iv; depth)
485518
end
486519
end
487520
return nothing
488521
end
489522

490-
function collect_var!(unknowns, parameters, var, iv)
523+
function collect_var!(unknowns, parameters, var, iv; depth = 0)
491524
isequal(var, iv) && return nothing
492-
getmetadata(var, SymScope, LocalScope()) == LocalScope() || return nothing
525+
check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing
493526
if iscalledparameter(var)
494527
callable = getcalledparameter(var)
495528
push!(parameters, callable)
@@ -506,6 +539,24 @@ function collect_var!(unknowns, parameters, var, iv)
506539
return nothing
507540
end
508541

542+
"""
543+
$(TYPEDSIGNATURES)
544+
545+
Check if the given `scope` is at a depth of `depth` from the root system. Only
546+
returns `true` for `scope::GlobalScope` if `depth == -1`.
547+
"""
548+
function check_scope_depth(scope, depth)
549+
if scope isa LocalScope
550+
return depth == 0
551+
elseif scope isa ParentScope
552+
return depth > 0 && check_scope_depth(scope.parent, depth - 1)
553+
elseif scope isa DelayParentScope
554+
return depth >= scope.N && check_scope_depth(scope.parent, depth - scope.N)
555+
elseif scope isa GlobalScope
556+
return depth == -1
557+
end
558+
end
559+
509560
"""
510561
Find all the symbolic constants of some equations or terms and return them as a vector.
511562
"""

test/variable_scope.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,41 @@ bar = complete(bar)
101101
defs = ModelingToolkit.defaults(bar)
102102
@test defs[bar.p] == 2
103103
@test isequal(defs[bar.foo.p], bar.p)
104+
105+
# Issue#3101
106+
@variables x1(t) x2(t) x3(t) x4(t) x5(t)
107+
x2 = ParentScope(x2)
108+
x3 = ParentScope(ParentScope(x3))
109+
x4 = DelayParentScope(x4, 2)
110+
x5 = GlobalScope(x5)
111+
@parameters p1 p2 p3 p4 p5
112+
p2 = ParentScope(p2)
113+
p3 = ParentScope(ParentScope(p3))
114+
p4 = DelayParentScope(p4, 2)
115+
p5 = GlobalScope(p5)
116+
117+
@named sys1 = ODESystem([D(x1) ~ p1, D(x2) ~ p2, D(x3) ~ p3, D(x4) ~ p4, D(x5) ~ p5], t)
118+
@test isequal(x1, only(unknowns(sys1)))
119+
@test isequal(p1, only(parameters(sys1)))
120+
@named sys2 = ODESystem(Equation[], t; systems = [sys1])
121+
@test length(unknowns(sys2)) == 2
122+
@test any(isequal(x2), unknowns(sys2))
123+
@test length(parameters(sys2)) == 2
124+
@test any(isequal(p2), parameters(sys2))
125+
@named sys3 = ODESystem(Equation[], t)
126+
sys3 = sys3 sys2
127+
@test length(unknowns(sys3)) == 4
128+
@test any(isequal(x3), unknowns(sys3))
129+
@test any(isequal(x4), unknowns(sys3))
130+
@test length(parameters(sys3)) == 4
131+
@test any(isequal(p3), parameters(sys3))
132+
@test any(isequal(p4), parameters(sys3))
133+
sys4 = complete(sys3)
134+
@test length(unknowns(sys3)) == 4
135+
@test length(parameters(sys4)) == 5
136+
@test any(isequal(p5), parameters(sys4))
137+
sys5 = structural_simplify(sys3)
138+
@test length(unknowns(sys5)) == 5
139+
@test any(isequal(x5), unknowns(sys5))
140+
@test length(parameters(sys5)) == 5
141+
@test any(isequal(p5), parameters(sys5))

0 commit comments

Comments
 (0)