Skip to content

Commit 0144991

Browse files
Merge pull request #935 from AayushSabharwal/as/fix-disc-interp
feat: allow interpolating discretes past the final time point
2 parents d819c31 + 8d2aa53 commit 0144991

File tree

4 files changed

+17
-10
lines changed

4 files changed

+17
-10
lines changed

src/remake.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,6 +1083,10 @@ calling `SymbolicIndexingInterface.symbolic_container`, provided for dispatch. R
10831083
the updated `newu0` and `newp`.
10841084
"""
10851085
function late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)
1086+
if hasmethod(symbolic_container, Tuple{typeof(root_indp)}) &&
1087+
(sc = symbolic_container(root_indp)) !== root_indp
1088+
return late_binding_update_u0_p(prob, sc, u0, p, t0, newu0, newp)
1089+
end
10861090
return newu0, newp
10871091
end
10881092

@@ -1094,10 +1098,6 @@ Calls `late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)` after
10941098
"""
10951099
function late_binding_update_u0_p(prob, u0, p, t0, newu0, newp)
10961100
root_indp = prob
1097-
while hasmethod(symbolic_container, Tuple{typeof(root_indp)}) &&
1098-
(sc = symbolic_container(root_indp)) !== root_indp
1099-
root_indp = sc
1100-
end
11011101
return late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)
11021102
end
11031103

src/solutions/ode_solutions.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,7 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
287287
ps = parameter_values(discs)
288288
for ts_idx in eachindex(discs)
289289
partition = discs[ts_idx]
290-
interp_val = ConstantInterpolation(partition.t, partition.u)(
291-
t, nothing, deriv, nothing, continuity)
290+
interp_val = _hold_discrete(partition.u, partition.t, t)
292291
ps = with_updated_parameter_timeseries_values(sol, ps, ts_idx => interp_val)
293292
end
294293
end
@@ -312,8 +311,7 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect
312311
ps = parameter_values(discs)
313312
for ts_idx in eachindex(discs)
314313
partition = discs[ts_idx]
315-
interp_val = ConstantInterpolation(partition.t, partition.u)(
316-
t, nothing, deriv, nothing, continuity)
314+
interp_val = _hold_discrete(partition.u, partition.t, t)
317315
ps = with_updated_parameter_timeseries_values(sol, ps, ts_idx => interp_val)
318316
end
319317
end

test/downstream/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ DelayDiffEq = "5"
3434
DiffEqCallbacks = "3, 4"
3535
ForwardDiff = "0.10"
3636
JumpProcesses = "9.10"
37-
ModelingToolkit = "9.56"
37+
ModelingToolkit = "9.64.1"
3838
ModelingToolkitStandardLibrary = "2.7"
3939
NonlinearSolve = "2, 3, 4"
4040
Optimization = "4"

test/downstream/solution_interface.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ end
215215
@test sol[x] == xvals
216216
@test is_parameter(sol, p)
217217
@test parameter_index(sol, p) == parameter_index(sys, p)
218-
@test isequal(only(parameter_symbols(sol)), p)
218+
@test any(isequal(p), parameter_symbols(sol))
219219
@test is_independent_variable(sol, t)
220220

221221
tmp = copy(prob.u0)
@@ -341,3 +341,12 @@ end
341341
@test _ss isa SciMLBase.SavedSubsystem
342342
end
343343
end
344+
345+
@testset "Interpolation after final discrete save" begin
346+
@variables x(t) y(t)
347+
@parameters start
348+
@mtkbuild sys=ODESystem([D(x) ~ y, y ~ ifelse(t < start, 1.0, 2.0)], t) additional_passes=[ModelingToolkit.IfLifting]
349+
prob = ODEProblem(sys, [x => 0.0], (0.0, 1.0), [start => 0.5])
350+
sol = solve(prob)
351+
@test sol(0.6, idxs = y) 2.0
352+
end

0 commit comments

Comments
 (0)