Skip to content

Commit 802ac1d

Browse files
fix: fix SymScope metadata for array variables
Co-authored-by: contradict <[email protected]>
1 parent 7bc758b commit 802ac1d

File tree

2 files changed

+48
-13
lines changed

2 files changed

+48
-13
lines changed

src/systems/abstractsystem.jl

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -746,38 +746,64 @@ end
746746
abstract type SymScope end
747747

748748
struct LocalScope <: SymScope end
749-
function LocalScope(sym::Union{Num, Symbolic})
749+
function LocalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}})
750750
apply_to_variables(sym) do sym
751-
setmetadata(sym, SymScope, LocalScope())
751+
if istree(sym) && operation(sym) === getindex
752+
args = arguments(sym)
753+
a1 = setmetadata(args[1], SymScope, LocalScope())
754+
similarterm(sym, operation(sym), [a1, args[2:end]...])
755+
else
756+
setmetadata(sym, SymScope, LocalScope())
757+
end
752758
end
753759
end
754760

755761
struct ParentScope <: SymScope
756762
parent::SymScope
757763
end
758-
function ParentScope(sym::Union{Num, Symbolic})
764+
function ParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}})
759765
apply_to_variables(sym) do sym
760-
setmetadata(sym, SymScope,
761-
ParentScope(getmetadata(value(sym), SymScope, LocalScope())))
766+
if istree(sym) && operation(sym) == getindex
767+
args = arguments(sym)
768+
a1 = setmetadata(args[1], SymScope,
769+
ParentScope(getmetadata(value(args[1]), SymScope, LocalScope())))
770+
similarterm(sym, operation(sym), [a1, args[2:end]...])
771+
else
772+
setmetadata(sym, SymScope,
773+
ParentScope(getmetadata(value(sym), SymScope, LocalScope())))
774+
end
762775
end
763776
end
764777

765778
struct DelayParentScope <: SymScope
766779
parent::SymScope
767780
N::Int
768781
end
769-
function DelayParentScope(sym::Union{Num, Symbolic}, N)
782+
function DelayParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}, N)
770783
apply_to_variables(sym) do sym
771-
setmetadata(sym, SymScope,
772-
DelayParentScope(getmetadata(value(sym), SymScope, LocalScope()), N))
784+
if istree(sym) && operation(sym) == getindex
785+
args = arguments(sym)
786+
a1 = setmetadata(args[1], SymScope,
787+
DelayParentScope(getmetadata(value(args[1]), SymScope, LocalScope()), N))
788+
similarterm(sym, operation(sym), [a1, args[2:end]...])
789+
else
790+
setmetadata(sym, SymScope,
791+
DelayParentScope(getmetadata(value(sym), SymScope, LocalScope()), N))
792+
end
773793
end
774794
end
775-
DelayParentScope(sym::Union{Num, Symbolic}) = DelayParentScope(sym, 1)
795+
DelayParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}) = DelayParentScope(sym, 1)
776796

777797
struct GlobalScope <: SymScope end
778-
function GlobalScope(sym::Union{Num, Symbolic})
798+
function GlobalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}})
779799
apply_to_variables(sym) do sym
780-
setmetadata(sym, SymScope, GlobalScope())
800+
if istree(sym) && operation(sym) == getindex
801+
args = arguments(sym)
802+
a1 = setmetadata(args[1], SymScope, GlobalScope())
803+
similarterm(sym, operation(sym), [a1, args[2:end]...])
804+
else
805+
setmetadata(sym, SymScope, GlobalScope())
806+
end
781807
end
782808
end
783809

@@ -1500,8 +1526,7 @@ function default_to_parentscope(v)
15001526
uv isa Symbolic || return v
15011527
apply_to_variables(v) do sym
15021528
if !hasmetadata(uv, SymScope)
1503-
setmetadata(sym, SymScope,
1504-
ParentScope(getmetadata(value(sym), SymScope, LocalScope())))
1529+
ParentScope(sym)
15051530
else
15061531
sym
15071532
end

test/variable_scope.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,13 @@ ps = ModelingToolkit.getname.(parameters(level3))
7373
@test isequal(ps[4], :level2₊level0₊d)
7474
@test isequal(ps[5], :level1₊level0₊e)
7575
@test isequal(ps[6], :f)
76+
77+
# Issue@2252
78+
# Tests from PR#2354
79+
@parameters xx[1:2]
80+
arr_p = [ParentScope(xx[1]), xx[2]]
81+
arr0 = ODESystem(Equation[], t, [], arr_p; name = :arr0)
82+
arr1 = ODESystem(Equation[], t, [], []; name = :arr1) arr0
83+
arr_ps = ModelingToolkit.getname.(parameters(arr1))
84+
@test isequal(arr_ps[1], Symbol("xx"))
85+
@test isequal(arr_ps[2], Symbol("arr0₊xx"))

0 commit comments

Comments
 (0)