Skip to content

Commit 0db9cd2

Browse files
Merge pull request #2647 from AayushSabharwal/as/fix-index-caching
fix: fix incorrect indexes of array symbolics
2 parents 709148e + c9498e2 commit 0db9cd2

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

src/systems/index_cache.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ struct ParameterIndex{P, I}
1717
end
1818

1919
const ParamIndexMap = Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int}}
20-
const UnknownIndexMap = Dict{Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}}}
20+
const UnknownIndexMap = Dict{
21+
Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
2122

2223
struct IndexCache
2324
unknown_idx::UnknownIndexMap
@@ -40,17 +41,32 @@ function IndexCache(sys::AbstractSystem)
4041
for sym in unks
4142
usym = unwrap(sym)
4243
sym_idx = if Symbolics.isarraysymbolic(sym)
43-
idx:(idx + length(sym) - 1)
44+
reshape(idx:(idx + length(sym) - 1), size(sym))
4445
else
4546
idx
4647
end
4748
unk_idxs[usym] = sym_idx
4849

49-
if hasname(sym)
50+
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
5051
unk_idxs[getname(usym)] = sym_idx
5152
end
5253
idx += length(sym)
5354
end
55+
for sym in unks
56+
usym = unwrap(sym)
57+
istree(sym) && operation(sym) === getindex || continue
58+
arrsym = arguments(sym)[1]
59+
all(haskey(unk_idxs, arrsym[i]) for i in eachindex(arrsym)) || continue
60+
61+
idxs = [unk_idxs[arrsym[i]] for i in eachindex(arrsym)]
62+
if idxs == idxs[begin]:idxs[end]
63+
idxs = reshape(idxs[begin]:idxs[end], size(idxs))
64+
end
65+
unk_idxs[arrsym] = idxs
66+
if hasname(arrsym)
67+
unk_idxs[getname(arrsym)] = idxs
68+
end
69+
end
5470
end
5571

5672
disc_buffers = Dict{Any, Set{BasicSymbolic}}()
@@ -124,7 +140,7 @@ function IndexCache(sys::AbstractSystem)
124140
for (j, p) in enumerate(buf)
125141
idxs[p] = (i, j)
126142
idxs[default_toterm(p)] = (i, j)
127-
if hasname(p)
143+
if hasname(p) && (!istree(p) || operation(p) !== getindex)
128144
idxs[getname(p)] = (i, j)
129145
idxs[getname(default_toterm(p))] = (i, j)
130146
end

test/odesystem.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,3 +1120,30 @@ tearing_state = TearingState(expand_connections(sys))
11201120
ts_vars = tearing_state.fullvars
11211121
orig_vars = unknowns(sys)
11221122
@test isempty(setdiff(ts_vars, orig_vars))
1123+
1124+
# Ensure indexes of array symbolics are cached appropriately
1125+
@variables x(t)[1:2]
1126+
@named sys = ODESystem(Equation[], t, [x], [])
1127+
sys1 = complete(sys)
1128+
@named sys = ODESystem(Equation[], t, [x...], [])
1129+
sys2 = complete(sys)
1130+
for sys in [sys1, sys2]
1131+
for (sym, idx) in [(x, 1:2), (x[1], 1), (x[2], 2)]
1132+
@test is_variable(sys, sym)
1133+
@test variable_index(sys, sym) == idx
1134+
end
1135+
end
1136+
1137+
@variables x(t)[1:2, 1:2]
1138+
@named sys = ODESystem(Equation[], t, [x], [])
1139+
sys1 = complete(sys)
1140+
@named sys = ODESystem(Equation[], t, [x...], [])
1141+
sys2 = complete(sys)
1142+
for sys in [sys1, sys2]
1143+
@test is_variable(sys, x)
1144+
@test variable_index(sys, x) == [1 3; 2 4]
1145+
for i in eachindex(x)
1146+
@test is_variable(sys, x[i])
1147+
@test variable_index(sys, x[i]) == variable_index(sys, x)[i]
1148+
end
1149+
end

0 commit comments

Comments
 (0)