Skip to content

Commit dad05e5

Browse files
Merge pull request #3284 from AayushSabharwal/as/indexing-hotfix
fix: fix `timeseries_parameter_index` for array symbolics
2 parents 4792360 + 1835a56 commit dad05e5

File tree

4 files changed

+18
-3
lines changed

4 files changed

+18
-3
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1230,7 +1230,7 @@ function flatten_equations(eqs)
12301230
error("LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must either both be array expressions or both scalar")
12311231
size(eq.lhs) == size(eq.rhs) ||
12321232
error("Size of LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must match: got $(size(eq.lhs)) and $(size(eq.rhs))")
1233-
return collect(eq.lhs) .~ collect(eq.rhs)
1233+
return vec(collect(eq.lhs) .~ collect(eq.rhs))
12341234
else
12351235
eq
12361236
end

src/systems/index_cache.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,14 +404,16 @@ function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sy
404404
sym = get(ic.symbol_to_variable, sym, nothing)
405405
sym === nothing && return nothing
406406
end
407+
sym = unwrap(sym)
407408
idx = check_index_map(ic.discrete_idx, sym)
408409
idx === nothing ||
409410
return ParameterTimeseriesIndex(idx.clock_idx, (idx.buffer_idx, idx.idx_in_clock))
410411
iscall(sym) && operation(sym) == getindex || return nothing
411412
args = arguments(sym)
412413
idx = timeseries_parameter_index(ic, args[1])
413414
idx === nothing && return nothing
414-
ParameterIndex(idx.portion, (idx.idx..., args[2:end]...), idx.validate_size)
415+
return ParameterTimeseriesIndex(
416+
idx.timeseries_idx, (idx.parameter_idx..., args[2:end]...))
415417
end
416418

417419
function check_index_map(idxmap, sym)

test/if_lifting.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ using ModelingToolkit: t_nounits as t, D_nounits as D, IfLifting, no_if_lift
2121
@test operation(only(equations(ss2)).rhs) === ifelse
2222

2323
discvar = only(parameters(ss2))
24-
prob2 = ODEProblem(ss2, [x => 0.0], (0.0, 5.0))
24+
prob1 = ODEProblem(ss1, [ss1.x => 0.0], (0.0, 5.0))
25+
sol1 = solve(prob1, Tsit5())
26+
prob2 = ODEProblem(ss2, [ss2.x => 0.0], (0.0, 5.0))
2527
sol2 = solve(prob2, Tsit5())
2628
@test count(isapprox(pi), sol2.t) == 2
2729
@test any(isapprox(pi), sol2.discretes[1].t)

test/symbolic_indexing_interface.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,14 @@ end
224224
end
225225
@test isempty(get_all_timeseries_indexes(sys, a))
226226
end
227+
228+
@testset "`timeseries_parameter_index` on unwrapped scalarized timeseries parameter" begin
229+
@variables x(t)[1:2]
230+
@parameters p(t)[1:2, 1:2]
231+
ev = [x[1] ~ 2.0] => [p ~ -ones(2, 2)]
232+
@mtkbuild sys = ODESystem(D(x) ~ p * x, t; continuous_events = [ev])
233+
p = ModelingToolkit.unwrap(p)
234+
@test timeseries_parameter_index(sys, p) === ParameterTimeseriesIndex(1, (1, 1))
235+
@test timeseries_parameter_index(sys, p[1, 1]) ===
236+
ParameterTimeseriesIndex(1, (1, 1, 1, 1))
237+
end

0 commit comments

Comments
 (0)