@@ -180,22 +180,45 @@ function SymbolicIndexingInterface.is_parameter_timeseries(::Type{S}) where {
180180 Timeseries ()
181181end
182182
183+ function _hold_discrete (disc_u, disc_t, t:: Number )
184+ idx = searchsortedlast (disc_t, t)
185+ if idx == firstindex (disc_t) - 1
186+ error (" Cannot access discrete variable at time $t before initial save $(first (disc_t)) " )
187+ end
188+ return disc_u[idx]
189+ end
190+
191+ function hold_discrete (disc_u, disc_t, t:: Number )
192+ val = _hold_discrete (disc_u, disc_t, t)
193+ return DiffEqArray ([val], [t])
194+ end
195+
196+ function hold_discrete (disc_u, disc_t, t:: AbstractVector{<:Number} )
197+ return DiffEqArray (_hold_discrete .((disc_u,), (disc_t,), t), t)
198+ end
199+
183200function get_interpolated_discretes (sol:: AbstractODESolution , t, deriv, continuity)
184201 is_parameter_timeseries (sol) == Timeseries () || return nothing
185202
186203 discs:: ParameterTimeseriesCollection = RecursiveArrayTools. get_discretes (sol)
187204 interp_discs = map (discs) do partition
188- ConstantInterpolation (partition. t , partition. u)( t, nothing , deriv, nothing , continuity )
205+ hold_discrete (partition. u , partition. t, t )
189206 end
190207 return ParameterTimeseriesCollection (interp_discs, parameter_values (discs))
191208end
192209
193210function (sol:: AbstractODESolution )(t, :: Type{deriv} = Val{0 }; idxs = nothing ,
194211 continuity = :left ) where {deriv}
212+ if t isa IndexedClock
213+ t = canonicalize_indexed_clock (t, sol)
214+ end
195215 sol (t, deriv, idxs, continuity)
196216end
197217function (sol:: AbstractODESolution )(v, t, :: Type{deriv} = Val{0 }; idxs = nothing ,
198218 continuity = :left ) where {deriv}
219+ if t isa IndexedClock
220+ t = canonicalize_indexed_clock (t, sol)
221+ end
199222 sol. interp (v, t, idxs, deriv, sol. prob. p, continuity)
200223end
201224
@@ -247,15 +270,13 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
247270 if is_parameter (sol, idxs) && ! is_timeseries_parameter (sol, idxs)
248271 return getp (sol, idxs)(ps)
249272 end
250- # NOTE: This is basically SII.parameter_values_at_time but that isn't public API
251- # and once we move interpolation to SII, there's no reason for it to be
252273 if is_parameter_timeseries (sol) == Timeseries ()
253274 discs:: ParameterTimeseriesCollection = RecursiveArrayTools. get_discretes (sol)
254275 ps = parameter_values (discs)
255276 for ts_idx in eachindex (discs)
256277 partition = discs[ts_idx]
257278 interp_val = ConstantInterpolation (partition. t, partition. u)(t, nothing , deriv, nothing , continuity)
258- ps = with_updated_parameter_timeseries_values (ps, ts_idx => interp_val)
279+ ps = with_updated_parameter_timeseries_values (sol, ps, ts_idx => interp_val)
259280 end
260281 end
261282 state = ProblemState (; u = sol. interp (t, nothing , deriv, ps, continuity), p = ps, t)
@@ -270,15 +291,13 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect
270291 end
271292 error_if_observed_derivative (sol, idxs, deriv)
272293 ps = parameter_values (sol)
273- # NOTE: This is basically SII.parameter_values_at_time but that isn't public API
274- # and once we move interpolation to SII, there's no reason for it to be
275294 if is_parameter_timeseries (sol) == Timeseries ()
276295 discs:: ParameterTimeseriesCollection = RecursiveArrayTools. get_discretes (sol)
277296 ps = parameter_values (discs)
278297 for ts_idx in eachindex (discs)
279298 partition = discs[ts_idx]
280299 interp_val = ConstantInterpolation (partition. t, partition. u)(t, nothing , deriv, nothing , continuity)
281- ps = with_updated_parameter_timeseries_values (ps, ts_idx => interp_val)
300+ ps = with_updated_parameter_timeseries_values (sol, ps, ts_idx => interp_val)
282301 end
283302 end
284303 state = ProblemState (; u = sol. interp (t, nothing , deriv, ps, continuity), p = ps, t)
@@ -290,9 +309,21 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
290309 symbolic_type (idxs) == NotSymbolic () && error (" Incorrect specification of `idxs`" )
291310 error_if_observed_derivative (sol, idxs, deriv)
292311 p = hasproperty (sol. prob, :p ) ? sol. prob. p : nothing
312+ getter = getu (sol, idxs)
313+ if is_parameter_timeseries (sol) == NotTimeseries ()
314+ interp_sol = augment (sol. interp (t, nothing , deriv, p, continuity), sol)
315+ return DiffEqArray (getter (interp_sol), t, p, sol)
316+ end
293317 discretes = get_interpolated_discretes (sol, t, deriv, continuity)
294- interp_sol = augment (sol. interp (t, nothing , deriv, p, continuity), sol; discretes)
295- return DiffEqArray (getu (interp_sol, idxs)(interp_sol), t, p, sol; discretes)
318+ interp_sol = sol. interp (t, nothing , deriv, p, continuity)
319+ u = map (eachindex (t)) do ti
320+ ps = parameter_values (discretes)
321+ for i in eachindex (discretes)
322+ ps = with_updated_parameter_timeseries_values (sol, ps, i => discretes[i, ti])
323+ end
324+ return getter (ProblemState (; u = interp_sol. u[ti], p = ps, t = t[ti]))
325+ end
326+ return DiffEqArray (u, t, p, sol; discretes)
296327end
297328
298329function (sol:: AbstractODESolution )(t:: AbstractVector{<:Number} , :: Type{deriv} ,
@@ -301,34 +332,51 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
301332 error (" Incorrect specification of `idxs`" )
302333 error_if_observed_derivative (sol, idxs, deriv)
303334 p = hasproperty (sol. prob, :p ) ? sol. prob. p : nothing
335+ getter = getu (sol, idxs)
336+ if is_parameter_timeseries (sol) == NotTimeseries ()
337+ interp_sol = augment (sol. interp (t, nothing , deriv, p, continuity), sol)
338+ return DiffEqArray (getter (interp_sol), t, p, sol)
339+ end
304340 discretes = get_interpolated_discretes (sol, t, deriv, continuity)
305- interp_sol = augment (sol. interp (t, nothing , deriv, p, continuity), sol; discretes)
306- return DiffEqArray (
307- getu (interp_sol, idxs)(interp_sol), t, p, sol; discretes)
341+ interp_sol = sol. interp (t, nothing , deriv, p, continuity)
342+ u = map (eachindex (t)) do ti
343+ ps = parameter_values (discretes)
344+ for i in eachindex (discretes)
345+ ps = with_updated_parameter_timeseries_values (sol, ps, i => discretes[i, ti])
346+ end
347+ return getter (ProblemState (; u = interp_sol. u[ti], p = ps, t = t[ti]))
348+ end
349+ return DiffEqArray (u, t, p, sol; discretes)
308350end
309351
310352# public API, used by MTK
311353"""
312- create_parameter_timeseries_collection(sys, ps)
354+ create_parameter_timeseries_collection(sys, ps, tspan )
313355
314356Create a `SymbolicIndexingInterface.ParameterTimeseriesCollection` for the given system
315357`sys` and parameter object `ps`. Return `nothing` if there are no timeseries parameters.
316- Defaults to `nothing`.
358+ Defaults to `nothing`. Falls back on the basis of `symbolic_container`.
317359"""
318360function create_parameter_timeseries_collection (sys, ps, tspan)
319- return nothing
361+ if hasmethod (symbolic_container, Tuple{typeof (sys)})
362+ return create_parameter_timeseries_collection (symbolic_container (sys), ps, tspan)
363+ else
364+ return nothing
365+ end
320366end
321367
322368const PeriodicDiffEqArray = DiffEqArray{T, N, A, B} where {T, N, A, B <: AbstractRange }
323369
324370# public API, used by MTK
325371"""
326- get_saveable_values(ps, timeseries_idx)
372+ get_saveable_values(sys, ps, timeseries_idx)
327373"""
328- function get_saveable_values end
374+ function get_saveable_values (sys, ps, timeseries_idx)
375+ return get_saveable_values (symbolic_container (sys), ps, timeseries_idx)
376+ end
329377
330378function save_discretes! (integ:: DEIntegrator , timeseries_idx)
331- save_discretes! (integ. sol, current_time (integ), get_saveable_values (parameter_values (integ), timeseries_idx), timeseries_idx)
379+ save_discretes! (integ. sol, current_time (integ), get_saveable_values (integ, parameter_values (integ), timeseries_idx), timeseries_idx)
332380end
333381
334382save_discretes! (args... ) = nothing
@@ -346,9 +394,8 @@ function _save_discretes_internal!(A::AbstractDiffEqArray, t, vals)
346394end
347395
348396function _save_discretes_internal! (A:: PeriodicDiffEqArray , t, vals)
349- # This is O(1) because A.t is a range
350- idx = searchsortedlast (A. t, t)
351- if idx == firstindex (A. t) - 1 || A. t[idx] ≉ t
397+ idx = length (A. u) + 1
398+ if A. t[idx] ≉ t
352399 error (" Tried to save periodic discrete value with timeseries $(A. t) at time $t " )
353400 end
354401 push! (A. u, vals)
@@ -393,7 +440,11 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
393440 end
394441
395442 ps = parameter_values (prob)
396- discretes = create_parameter_timeseries_collection (prob. f. sys, ps, prob. tspan)
443+ if has_sys (prob. f)
444+ discretes = create_parameter_timeseries_collection (prob. f. sys, ps, prob. tspan)
445+ else
446+ discretes = nothing
447+ end
397448 if has_analytic (f)
398449 u_analytic = Vector {typeof(prob.u0)} ()
399450 errors = Dict {Symbol, real(eltype(prob.u0))} ()
0 commit comments