@@ -105,14 +105,15 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/
105105 exited due to an error. For more details, see
106106 [the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
107107"""
108- struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, S,
108+ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, discType, P, A, IType, S,
109109 AC <: Union{Nothing, Vector{Int}} , R, O} < :
110110 AbstractODESolution{T, N, uType}
111111 u:: uType
112112 u_analytic:: uType2
113113 errors:: DType
114114 t:: tType
115115 k:: rateType
116+ discretes:: discType
116117 prob:: P
117118 alg:: A
118119 interp:: IType
@@ -135,7 +136,7 @@ function ConstructionBase.setproperties(sol::ODESolution, patch::NamedTuple)
135136 T = eltype (eltype (u))
136137 patch = merge (getproperties (sol), patch)
137138 return ODESolution {T, N} (patch. u, patch. u_analytic, patch. errors, patch. t, patch. k,
138- patch. prob, patch. alg, patch. interp, patch. dense, patch. tslocation, patch. stats,
139+ patch. discretes, patch . prob, patch. alg, patch. interp, patch. dense, patch. tslocation, patch. stats,
139140 patch. alg_choice, patch. retcode, patch. resid, patch. original)
140141end
141142
@@ -150,13 +151,14 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Sy
150151end
151152
152153# FIXME : Remove the defaults for resid and original on a breaking release
153- function ODESolution {T, N} (u, u_analytic, errors, t, k, prob, alg, interp, dense,
154+ function ODESolution {T, N} (
155+ u, u_analytic, errors, t, k, discretes, prob, alg, interp, dense,
154156 tslocation, stats, alg_choice, retcode, resid = nothing ,
155157 original = nothing ) where {T, N}
156158 return ODESolution{T, N, typeof (u), typeof (u_analytic), typeof (errors), typeof (t),
157- typeof (k), typeof (prob), typeof (alg), typeof (interp),
159+ typeof (k), typeof (discretes), typeof ( prob), typeof (alg), typeof (interp),
158160 typeof (stats), typeof (alg_choice), typeof (resid),
159- typeof (original)}(u, u_analytic, errors, t, k, prob, alg, interp,
161+ typeof (original)}(u, u_analytic, errors, t, k, discretes, prob, alg, interp,
160162 dense, tslocation, stats, alg_choice, retcode, resid, original)
161163end
162164
@@ -172,6 +174,22 @@ function error_if_observed_derivative(sys, idx, ::Type)
172174 end
173175end
174176
177+ function SymbolicIndexingInterface. is_parameter_timeseries (:: Type{S} ) where {
178+ T1, T2, T3, T4, T5, T6, T7,
179+ S <: ODESolution{T1, T2, T3, T4, T5, T6, T7, <:ParameterTimeseriesCollection} }
180+ Timeseries ()
181+ end
182+
183+ function get_interpolated_discretes (sol:: AbstractODESolution , t, deriv, continuity)
184+ is_parameter_timeseries (sol) == Timeseries () || return nothing
185+
186+ discs:: ParameterTimeseriesCollection = RecursiveArrayTools. get_discretes (sol)
187+ interp_discs = map (discs) do partition
188+ ConstantInterpolation (partition. t, partition. u)(t, nothing , deriv, nothing , continuity)
189+ end
190+ return ParameterTimeseriesCollection (interp_discs, parameter_values (discs))
191+ end
192+
175193function (sol:: AbstractODESolution )(t, :: Type{deriv} = Val{0 }; idxs = nothing ,
176194 continuity = :left ) where {deriv}
177195 sol (t, deriv, idxs, continuity)
188206
189207function (sol:: AbstractODESolution )(t:: AbstractVector{<:Number} , :: Type{deriv} ,
190208 idxs:: Nothing , continuity) where {deriv}
191- augment (sol. interp (t, idxs, deriv, sol. prob. p, continuity), sol)
209+ discretes = get_interpolated_discretes (sol, t, deriv, continuity)
210+ augment (sol. interp (t, idxs, deriv, sol. prob. p, continuity), sol; discretes)
192211end
193212
194213function (sol:: AbstractODESolution )(t:: Number , :: Type{deriv} , idxs:: Integer ,
@@ -224,11 +243,23 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
224243 continuity) where {deriv}
225244 symbolic_type (idxs) == NotSymbolic () && error (" Incorrect specification of `idxs`" )
226245 error_if_observed_derivative (sol, idxs, deriv)
227- if is_parameter (sol, idxs)
228- return getp (sol, idxs)(sol)
229- else
230- return augment (sol. interp ([t], nothing , deriv, sol. prob. p, continuity), sol)[idxs][1 ]
246+ ps = parameter_values (sol)
247+ if is_parameter (sol, idxs) && ! is_timeseries_parameter (sol, idxs)
248+ return getp (sol, idxs)(ps)
249+ end
250+ # NOTE: This is basically SII.parameter_values_at_time but that isn't public API
251+ # and once we move interpolation to SII, there's no reason for it to be
252+ if is_parameter_timeseries (sol) == Timeseries ()
253+ discs:: ParameterTimeseriesCollection = RecursiveArrayTools. get_discretes (sol)
254+ ps = parameter_values (discs)
255+ for ts_idx in eachindex (discs)
256+ partition = discs[ts_idx]
257+ interp_val = ConstantInterpolation (partition. t, partition. u)(t, nothing , deriv, nothing , continuity)
258+ ps = with_updated_parameter_timeseries_values (ps, ts_idx => interp_val)
259+ end
231260 end
261+ state = ProblemState (; u = sol. interp (t, nothing , deriv, ps, continuity), p = ps, t)
262+ return getu (sol, idxs)(state)
232263end
233264
234265function (sol:: AbstractODESolution )(t:: Number , :: Type{deriv} , idxs:: AbstractVector ,
@@ -238,33 +269,89 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect
238269 error (" Incorrect specification of `idxs`" )
239270 end
240271 error_if_observed_derivative (sol, idxs, deriv)
241- interp_sol = augment (sol. interp ([t], nothing , deriv, sol. prob. p, continuity), sol)
242- first (interp_sol[idxs])
272+ ps = parameter_values (sol)
273+ # NOTE: This is basically SII.parameter_values_at_time but that isn't public API
274+ # and once we move interpolation to SII, there's no reason for it to be
275+ if is_parameter_timeseries (sol) == Timeseries ()
276+ discs:: ParameterTimeseriesCollection = RecursiveArrayTools. get_discretes (sol)
277+ ps = parameter_values (discs)
278+ for ts_idx in eachindex (discs)
279+ partition = discs[ts_idx]
280+ interp_val = ConstantInterpolation (partition. t, partition. u)(t, nothing , deriv, nothing , continuity)
281+ ps = with_updated_parameter_timeseries_values (ps, ts_idx => interp_val)
282+ end
283+ end
284+ state = ProblemState (; u = sol. interp (t, nothing , deriv, ps, continuity), p = ps, t)
285+ return getu (sol, idxs)(state)
243286end
244287
245288function (sol:: AbstractODESolution )(t:: AbstractVector{<:Number} , :: Type{deriv} , idxs,
246289 continuity) where {deriv}
247290 symbolic_type (idxs) == NotSymbolic () && error (" Incorrect specification of `idxs`" )
248291 error_if_observed_derivative (sol, idxs, deriv)
249- if is_parameter (sol, idxs)
250- return getp (sol, idxs)(sol)
251- else
252- interp_sol = augment (sol. interp (t, nothing , deriv, sol. prob. p, continuity), sol)
253- p = hasproperty (sol. prob, :p ) ? sol. prob. p : nothing
254- return DiffEqArray (interp_sol[idxs], t, p, sol)
255- end
292+ p = hasproperty (sol. prob, :p ) ? sol. prob. p : nothing
293+ discretes = get_interpolated_discretes (sol, t, deriv, continuity)
294+ interp_sol = augment (sol. interp (t, nothing , deriv, p, continuity), sol; discretes)
295+ return DiffEqArray (getu (interp_sol, idxs)(interp_sol), t, p, sol; discretes)
256296end
257297
258298function (sol:: AbstractODESolution )(t:: AbstractVector{<:Number} , :: Type{deriv} ,
259299 idxs:: AbstractVector , continuity) where {deriv}
260300 all (! isequal (NotSymbolic ()), symbolic_type .(idxs)) ||
261301 error (" Incorrect specification of `idxs`" )
262302 error_if_observed_derivative (sol, idxs, deriv)
263- interp_sol = augment (sol. interp (t, nothing , deriv, sol. prob. p, continuity), sol)
264303 p = hasproperty (sol. prob, :p ) ? sol. prob. p : nothing
265- indexed_sol = interp_sol[idxs]
304+ discretes = get_interpolated_discretes (sol, t, deriv, continuity)
305+ interp_sol = augment (sol. interp (t, nothing , deriv, p, continuity), sol; discretes)
266306 return DiffEqArray (
267- [indexed_sol[i] for i in 1 : length (t)], t, p, sol)
307+ getu (interp_sol, idxs)(interp_sol), t, p, sol; discretes)
308+ end
309+
310+ # public API, used by MTK
311+ """
312+ create_parameter_timeseries_collection(sys, ps)
313+
314+ Create a `SymbolicIndexingInterface.ParameterTimeseriesCollection` for the given system
315+ `sys` and parameter object `ps`. Return `nothing` if there are no timeseries parameters.
316+ Defaults to `nothing`.
317+ """
318+ function create_parameter_timeseries_collection (sys, ps, tspan)
319+ return nothing
320+ end
321+
322+ const PeriodicDiffEqArray = DiffEqArray{T, N, A, B} where {T, N, A, B <: AbstractRange }
323+
324+ # public API, used by MTK
325+ """
326+ get_saveable_values(ps, timeseries_idx)
327+ """
328+ function get_saveable_values end
329+
330+ function save_discretes! (integ:: DEIntegrator , timeseries_idx)
331+ save_discretes! (integ. sol, current_time (integ), get_saveable_values (parameter_values (integ), timeseries_idx), timeseries_idx)
332+ end
333+
334+ save_discretes! (args... ) = nothing
335+
336+ # public API, used by MTK
337+ function save_discretes! (sol:: AbstractODESolution , t, vals, timeseries_idx)
338+ RecursiveArrayTools. has_discretes (sol) || return
339+ disc = RecursiveArrayTools. get_discretes (sol)
340+ _save_discretes_internal! (disc[timeseries_idx], t, vals)
341+ end
342+
343+ function _save_discretes_internal! (A:: AbstractDiffEqArray , t, vals)
344+ push! (A. t, t)
345+ push! (A. u, vals)
346+ end
347+
348+ function _save_discretes_internal! (A:: PeriodicDiffEqArray , t, vals)
349+ # This is O(1) because A.t is a range
350+ idx = searchsortedlast (A. t, t)
351+ if idx == firstindex (A. t) - 1 || A. t[idx] ≉ t
352+ error (" Tried to save periodic discrete value with timeseries $(A. t) at time $t " )
353+ end
354+ push! (A. u, vals)
268355end
269356
270357function build_solution (prob:: Union{AbstractODEProblem, AbstractDDEProblem} ,
@@ -305,13 +392,16 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
305392 Base. depwarn (msg, :build_solution )
306393 end
307394
395+ ps = parameter_values (prob)
396+ discretes = create_parameter_timeseries_collection (prob. f. sys, ps, prob. tspan)
308397 if has_analytic (f)
309398 u_analytic = Vector {typeof(prob.u0)} ()
310399 errors = Dict {Symbol, real(eltype(prob.u0))} ()
311400 sol = ODESolution {T, N} (u,
312401 u_analytic,
313402 errors,
314403 t, k,
404+ discretes,
315405 prob,
316406 alg,
317407 interp,
@@ -332,6 +422,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
332422 nothing ,
333423 nothing ,
334424 t, k,
425+ discretes,
335426 prob,
336427 alg,
337428 interp,
@@ -413,6 +504,36 @@ function solution_slice(sol::ODESolution{T, N}, I) where {T, N}
413504 return @set sol. alg = false
414505end
415506
507+ mask_discretes (:: Nothing , _, _... ) = nothing
508+
509+ function mask_discretes (discretes:: ParameterTimeseriesCollection , new_t, :: Union{Int, CartesianIndex} )
510+ masked_discretes = map (discretes) do disc
511+ i = searchsortedlast (disc. t, new_t)
512+ disc[i: i]
513+ end
514+ return ParameterTimeseriesCollection (masked_discretes, parameter_values (discretes))
515+ end
516+
517+ function mask_discretes (discretes:: ParameterTimeseriesCollection , new_t, :: AbstractRange )
518+ mint, maxt = extrema (new_t)
519+ masked_discretes = map (discretes) do disc
520+ mini = searchsortedfirst (disc. t, mint)
521+ maxi = searchsortedlast (disc. t, maxt)
522+ disc[mini: maxi]
523+ end
524+ return ParameterTimeseriesCollection (masked_discretes, parameter_values (discretes))
525+ end
526+
527+ function mask_discretes (discretes:: ParameterTimeseriesCollection , new_t, _)
528+ masked_discretes = map (discretes) do disc
529+ idxs = map (new_t) do t
530+ searchsortedlast (disc. t, t)
531+ end
532+ disc[idxs]
533+ end
534+ return ParameterTimeseriesCollection (masked_discretes, parameter_values (discretes))
535+ end
536+
416537function sensitivity_solution (sol:: ODESolution , u, t)
417538 T = eltype (eltype (u))
418539
0 commit comments