Skip to content

Commit 4ee1d48

Browse files
Merge pull request #779 from AayushSabharwal/as/skip-discrete-interpolation
fix: don't interpolate discretes if not interpolating discrete symbolic
2 parents 1418e40 + b660818 commit 4ee1d48

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

src/solutions/ode_solutions.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,11 @@ function get_interpolated_discretes(sol::AbstractODESolution, t, deriv, continui
207207
return ParameterTimeseriesCollection(interp_discs, parameter_values(discs))
208208
end
209209

210+
function is_discrete_expression(indp, expr)
211+
ts_idxs = get_all_timeseries_indexes(indp, expr)
212+
length(ts_idxs) > 1 || length(ts_idxs) == 1 && only(ts_idxs) != ContinuousTimeseries()
213+
end
214+
210215
function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing,
211216
continuity = :left) where {deriv}
212217
if t isa IndexedClock
@@ -270,7 +275,7 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
270275
if is_parameter(sol, idxs) && !is_timeseries_parameter(sol, idxs)
271276
return getp(sol, idxs)(ps)
272277
end
273-
if is_parameter_timeseries(sol) == Timeseries()
278+
if is_parameter_timeseries(sol) == Timeseries() && is_discrete_expression(sol, idxs)
274279
discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol)
275280
ps = parameter_values(discs)
276281
for ts_idx in eachindex(discs)
@@ -292,7 +297,7 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect
292297
end
293298
error_if_observed_derivative(sol, idxs, deriv)
294299
ps = parameter_values(sol)
295-
if is_parameter_timeseries(sol) == Timeseries()
300+
if is_parameter_timeseries(sol) == Timeseries() && is_discrete_expression(sol, idxs)
296301
discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol)
297302
ps = parameter_values(discs)
298303
for ts_idx in eachindex(discs)
@@ -312,7 +317,7 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
312317
error_if_observed_derivative(sol, idxs, deriv)
313318
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
314319
getter = getu(sol, idxs)
315-
if is_parameter_timeseries(sol) == NotTimeseries()
320+
if is_parameter_timeseries(sol) == NotTimeseries() || !is_discrete_expression(sol, idxs)
316321
interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol)
317322
return DiffEqArray(getter(interp_sol), t, p, sol)
318323
end
@@ -333,7 +338,7 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
333338
error_if_observed_derivative(sol, idxs, deriv)
334339
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
335340
getter = getu(sol, idxs)
336-
if is_parameter_timeseries(sol) == NotTimeseries()
341+
if is_parameter_timeseries(sol) == NotTimeseries() || !is_discrete_expression(sol, idxs)
337342
interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol)
338343
return DiffEqArray(getter(interp_sol), t, p, sol)
339344
end

test/downstream/comprehensive_indexing.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,3 +908,17 @@ end
908908
end
909909
end
910910
end
911+
912+
# Issue https://github.com/SciML/ModelingToolkit.jl/issues/3004
913+
@testset "Continuous interpolation before discrete save" begin
914+
@variables x(t)
915+
@parameters c(t)
916+
@mtkbuild sys = ODESystem(
917+
D(x) ~ c * cos(x), t, [x], [c]; discrete_events = [1.0 => [c ~ c + 1]])
918+
prob = ODEProblem(sys, [x => 0.0], (0.0, 2pi), [c => 1.0])
919+
sol = solve(prob, Tsit5())
920+
@test_nowarn sol(-0.1; idxs = sys.x)
921+
@test_nowarn sol(-0.1; idxs = [sys.x, 2sys.x])
922+
@test_throws ErrorException sol(-0.1; idxs = sys.c)
923+
@test_throws ErrorException sol(-0.1; idxs = [sys.x, sys.x + sys.c])
924+
end

0 commit comments

Comments
 (0)