Skip to content

Commit 51f554b

Browse files
refactor: refactor interpolation with indexed clocks
1 parent 5de5d1b commit 51f554b

File tree

1 file changed

+73
-22
lines changed

1 file changed

+73
-22
lines changed

src/solutions/ode_solutions.jl

Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -180,22 +180,45 @@ function SymbolicIndexingInterface.is_parameter_timeseries(::Type{S}) where {
180180
Timeseries()
181181
end
182182

183+
function _hold_discrete(disc_u, disc_t, t::Number)
184+
idx = searchsortedlast(disc_t, t)
185+
if idx == firstindex(disc_t) - 1
186+
error("Cannot access discrete variable at time $t before initial save $(first(disc_t))")
187+
end
188+
return disc_u[idx]
189+
end
190+
191+
function hold_discrete(disc_u, disc_t, t::Number)
192+
val = _hold_discrete(disc_u, disc_t, t)
193+
return DiffEqArray([val], [t])
194+
end
195+
196+
function hold_discrete(disc_u, disc_t, t::AbstractVector{<:Number})
197+
return DiffEqArray(_hold_discrete.((disc_u,), (disc_t,), t), t)
198+
end
199+
183200
function get_interpolated_discretes(sol::AbstractODESolution, t, deriv, continuity)
184201
is_parameter_timeseries(sol) == Timeseries() || return nothing
185202

186203
discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol)
187204
interp_discs = map(discs) do partition
188-
ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity)
205+
hold_discrete(partition.u, partition.t, t)
189206
end
190207
return ParameterTimeseriesCollection(interp_discs, parameter_values(discs))
191208
end
192209

193210
function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing,
194211
continuity = :left) where {deriv}
212+
if t isa IndexedClock
213+
t = canonicalize_indexed_clock(t, sol)
214+
end
195215
sol(t, deriv, idxs, continuity)
196216
end
197217
function (sol::AbstractODESolution)(v, t, ::Type{deriv} = Val{0}; idxs = nothing,
198218
continuity = :left) where {deriv}
219+
if t isa IndexedClock
220+
t = canonicalize_indexed_clock(t, sol)
221+
end
199222
sol.interp(v, t, idxs, deriv, sol.prob.p, continuity)
200223
end
201224

@@ -247,15 +270,13 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
247270
if is_parameter(sol, idxs) && !is_timeseries_parameter(sol, idxs)
248271
return getp(sol, idxs)(ps)
249272
end
250-
# NOTE: This is basically SII.parameter_values_at_time but that isn't public API
251-
# and once we move interpolation to SII, there's no reason for it to be
252273
if is_parameter_timeseries(sol) == Timeseries()
253274
discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol)
254275
ps = parameter_values(discs)
255276
for ts_idx in eachindex(discs)
256277
partition = discs[ts_idx]
257278
interp_val = ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity)
258-
ps = with_updated_parameter_timeseries_values(ps, ts_idx => interp_val)
279+
ps = with_updated_parameter_timeseries_values(sol, ps, ts_idx => interp_val)
259280
end
260281
end
261282
state = ProblemState(; u = sol.interp(t, nothing, deriv, ps, continuity), p = ps, t)
@@ -270,15 +291,13 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect
270291
end
271292
error_if_observed_derivative(sol, idxs, deriv)
272293
ps = parameter_values(sol)
273-
# NOTE: This is basically SII.parameter_values_at_time but that isn't public API
274-
# and once we move interpolation to SII, there's no reason for it to be
275294
if is_parameter_timeseries(sol) == Timeseries()
276295
discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol)
277296
ps = parameter_values(discs)
278297
for ts_idx in eachindex(discs)
279298
partition = discs[ts_idx]
280299
interp_val = ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity)
281-
ps = with_updated_parameter_timeseries_values(ps, ts_idx => interp_val)
300+
ps = with_updated_parameter_timeseries_values(sol, ps, ts_idx => interp_val)
282301
end
283302
end
284303
state = ProblemState(; u = sol.interp(t, nothing, deriv, ps, continuity), p = ps, t)
@@ -290,9 +309,21 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
290309
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
291310
error_if_observed_derivative(sol, idxs, deriv)
292311
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
312+
getter = getu(sol, idxs)
313+
if is_parameter_timeseries(sol) == NotTimeseries()
314+
interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol)
315+
return DiffEqArray(getter(interp_sol), t, p, sol)
316+
end
293317
discretes = get_interpolated_discretes(sol, t, deriv, continuity)
294-
interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol; discretes)
295-
return DiffEqArray(getu(interp_sol, idxs)(interp_sol), t, p, sol; discretes)
318+
interp_sol = sol.interp(t, nothing, deriv, p, continuity)
319+
u = map(eachindex(t)) do ti
320+
ps = parameter_values(discretes)
321+
for i in eachindex(discretes)
322+
ps = with_updated_parameter_timeseries_values(sol, ps, i => discretes[i, ti])
323+
end
324+
return getter(ProblemState(; u = interp_sol.u[ti], p = ps, t = t[ti]))
325+
end
326+
return DiffEqArray(u, t, p, sol; discretes)
296327
end
297328

298329
function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
@@ -301,34 +332,51 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
301332
error("Incorrect specification of `idxs`")
302333
error_if_observed_derivative(sol, idxs, deriv)
303334
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
335+
getter = getu(sol, idxs)
336+
if is_parameter_timeseries(sol) == NotTimeseries()
337+
interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol)
338+
return DiffEqArray(getter(interp_sol), t, p, sol)
339+
end
304340
discretes = get_interpolated_discretes(sol, t, deriv, continuity)
305-
interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol; discretes)
306-
return DiffEqArray(
307-
getu(interp_sol, idxs)(interp_sol), t, p, sol; discretes)
341+
interp_sol = sol.interp(t, nothing, deriv, p, continuity)
342+
u = map(eachindex(t)) do ti
343+
ps = parameter_values(discretes)
344+
for i in eachindex(discretes)
345+
ps = with_updated_parameter_timeseries_values(sol, ps, i => discretes[i, ti])
346+
end
347+
return getter(ProblemState(; u = interp_sol.u[ti], p = ps, t = t[ti]))
348+
end
349+
return DiffEqArray(u, t, p, sol; discretes)
308350
end
309351

310352
# public API, used by MTK
311353
"""
312-
create_parameter_timeseries_collection(sys, ps)
354+
create_parameter_timeseries_collection(sys, ps, tspan)
313355
314356
Create a `SymbolicIndexingInterface.ParameterTimeseriesCollection` for the given system
315357
`sys` and parameter object `ps`. Return `nothing` if there are no timeseries parameters.
316-
Defaults to `nothing`.
358+
Defaults to `nothing`. Falls back on the basis of `symbolic_container`.
317359
"""
318360
function create_parameter_timeseries_collection(sys, ps, tspan)
319-
return nothing
361+
if hasmethod(symbolic_container, Tuple{typeof(sys)})
362+
return create_parameter_timeseries_collection(symbolic_container(sys), ps, tspan)
363+
else
364+
return nothing
365+
end
320366
end
321367

322368
const PeriodicDiffEqArray = DiffEqArray{T, N, A, B} where {T, N, A, B <: AbstractRange}
323369

324370
# public API, used by MTK
325371
"""
326-
get_saveable_values(ps, timeseries_idx)
372+
get_saveable_values(sys, ps, timeseries_idx)
327373
"""
328-
function get_saveable_values end
374+
function get_saveable_values(sys, ps, timeseries_idx)
375+
return get_saveable_values(symbolic_container(sys), ps, timeseries_idx)
376+
end
329377

330378
function save_discretes!(integ::DEIntegrator, timeseries_idx)
331-
save_discretes!(integ.sol, current_time(integ), get_saveable_values(parameter_values(integ), timeseries_idx), timeseries_idx)
379+
save_discretes!(integ.sol, current_time(integ), get_saveable_values(integ, parameter_values(integ), timeseries_idx), timeseries_idx)
332380
end
333381

334382
save_discretes!(args...) = nothing
@@ -346,9 +394,8 @@ function _save_discretes_internal!(A::AbstractDiffEqArray, t, vals)
346394
end
347395

348396
function _save_discretes_internal!(A::PeriodicDiffEqArray, t, vals)
349-
# This is O(1) because A.t is a range
350-
idx = searchsortedlast(A.t, t)
351-
if idx == firstindex(A.t) - 1 || A.t[idx] t
397+
idx = length(A.u) + 1
398+
if A.t[idx] t
352399
error("Tried to save periodic discrete value with timeseries $(A.t) at time $t")
353400
end
354401
push!(A.u, vals)
@@ -393,7 +440,11 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
393440
end
394441

395442
ps = parameter_values(prob)
396-
discretes = create_parameter_timeseries_collection(prob.f.sys, ps, prob.tspan)
443+
if has_sys(prob.f)
444+
discretes = create_parameter_timeseries_collection(prob.f.sys, ps, prob.tspan)
445+
else
446+
discretes = nothing
447+
end
397448
if has_analytic(f)
398449
u_analytic = Vector{typeof(prob.u0)}()
399450
errors = Dict{Symbol, real(eltype(prob.u0))}()

0 commit comments

Comments
 (0)