Skip to content

Commit 745e889

Browse files
fix: fix incorrect indexes of array symbolics
1 parent 709148e commit 745e889

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

src/systems/index_cache.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ struct ParameterIndex{P, I}
1717
end
1818

1919
const ParamIndexMap = Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int}}
20-
const UnknownIndexMap = Dict{Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}}}
20+
const UnknownIndexMap = Dict{
21+
Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}, Array{Int}}}
2122

2223
struct IndexCache
2324
unknown_idx::UnknownIndexMap
@@ -46,11 +47,26 @@ function IndexCache(sys::AbstractSystem)
4647
end
4748
unk_idxs[usym] = sym_idx
4849

49-
if hasname(sym)
50+
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
5051
unk_idxs[getname(usym)] = sym_idx
5152
end
5253
idx += length(sym)
5354
end
55+
for sym in unks
56+
usym = unwrap(sym)
57+
istree(sym) && operation(sym) === getindex || continue
58+
arrsym = arguments(sym)[1]
59+
all(haskey(unk_idxs, arrsym[i]) for i in eachindex(arrsym)) || continue
60+
61+
idxs = [unk_idxs[arrsym[i]] for i in eachindex(arrsym)]
62+
if idxs == idxs[begin]:idxs[end]
63+
idxs = idxs[begin]:idxs[end]
64+
end
65+
unk_idxs[arrsym] = idxs
66+
if hasname(arrsym)
67+
unk_idxs[getname(arrsym)] = idxs
68+
end
69+
end
5470
end
5571

5672
disc_buffers = Dict{Any, Set{BasicSymbolic}}()

test/odesystem.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,3 +1120,16 @@ tearing_state = TearingState(expand_connections(sys))
11201120
ts_vars = tearing_state.fullvars
11211121
orig_vars = unknowns(sys)
11221122
@test isempty(setdiff(ts_vars, orig_vars))
1123+
1124+
# Ensure indexes of array symbolics are cached appropriately
1125+
@variables x(t)[1:2]
1126+
@named sys = ODESystem(Equation[], t, [x], [])
1127+
sys1 = complete(sys)
1128+
@named sys = ODESystem(Equation[], t, [x...], [])
1129+
sys2 = complete(sys)
1130+
for sys in [sys1, sys2]
1131+
for (sym, idx) in [(x, 1:2), (x[1], 1), (x[2], 2)]
1132+
@test is_variable(sys, sym)
1133+
@test variable_index(sys, sym) == idx
1134+
end
1135+
end

0 commit comments

Comments
 (0)