Skip to content

Commit c9498e2

Browse files
fix: fix IndexCache stored indices for array unknowns
1 parent 745e889 commit c9498e2

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

src/systems/index_cache.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ end
1818

1919
const ParamIndexMap = Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int}}
2020
const UnknownIndexMap = Dict{
21-
Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}, Array{Int}}}
21+
Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
2222

2323
struct IndexCache
2424
unknown_idx::UnknownIndexMap
@@ -41,7 +41,7 @@ function IndexCache(sys::AbstractSystem)
4141
for sym in unks
4242
usym = unwrap(sym)
4343
sym_idx = if Symbolics.isarraysymbolic(sym)
44-
idx:(idx + length(sym) - 1)
44+
reshape(idx:(idx + length(sym) - 1), size(sym))
4545
else
4646
idx
4747
end
@@ -60,7 +60,7 @@ function IndexCache(sys::AbstractSystem)
6060

6161
idxs = [unk_idxs[arrsym[i]] for i in eachindex(arrsym)]
6262
if idxs == idxs[begin]:idxs[end]
63-
idxs = idxs[begin]:idxs[end]
63+
idxs = reshape(idxs[begin]:idxs[end], size(idxs))
6464
end
6565
unk_idxs[arrsym] = idxs
6666
if hasname(arrsym)
@@ -140,7 +140,7 @@ function IndexCache(sys::AbstractSystem)
140140
for (j, p) in enumerate(buf)
141141
idxs[p] = (i, j)
142142
idxs[default_toterm(p)] = (i, j)
143-
if hasname(p)
143+
if hasname(p) && (!istree(p) || operation(p) !== getindex)
144144
idxs[getname(p)] = (i, j)
145145
idxs[getname(default_toterm(p))] = (i, j)
146146
end

test/odesystem.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,3 +1133,17 @@ for sys in [sys1, sys2]
11331133
@test variable_index(sys, sym) == idx
11341134
end
11351135
end
1136+
1137+
@variables x(t)[1:2, 1:2]
1138+
@named sys = ODESystem(Equation[], t, [x], [])
1139+
sys1 = complete(sys)
1140+
@named sys = ODESystem(Equation[], t, [x...], [])
1141+
sys2 = complete(sys)
1142+
for sys in [sys1, sys2]
1143+
@test is_variable(sys, x)
1144+
@test variable_index(sys, x) == [1 3; 2 4]
1145+
for i in eachindex(x)
1146+
@test is_variable(sys, x[i])
1147+
@test variable_index(sys, x[i]) == variable_index(sys, x)[i]
1148+
end
1149+
end

0 commit comments

Comments
 (0)