Skip to content

Commit cea0536

Browse files
feat: implement new SII discrete saving interface
1 parent 71a578d commit cea0536

File tree

2 files changed

+146
-24
lines changed

2 files changed

+146
-24
lines changed

src/solutions/ode_solutions.jl

Lines changed: 143 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,15 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/
105105
exited due to an error. For more details, see
106106
[the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
107107
"""
108-
struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, S,
108+
struct ODESolution{T, N, uType, uType2, DType, tType, rateType, discType, P, A, IType, S,
109109
AC <: Union{Nothing, Vector{Int}}, R, O} <:
110110
AbstractODESolution{T, N, uType}
111111
u::uType
112112
u_analytic::uType2
113113
errors::DType
114114
t::tType
115115
k::rateType
116+
discretes::discType
116117
prob::P
117118
alg::A
118119
interp::IType
@@ -135,7 +136,7 @@ function ConstructionBase.setproperties(sol::ODESolution, patch::NamedTuple)
135136
T = eltype(eltype(u))
136137
patch = merge(getproperties(sol), patch)
137138
return ODESolution{T, N}(patch.u, patch.u_analytic, patch.errors, patch.t, patch.k,
138-
patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats,
139+
patch.discretes, patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats,
139140
patch.alg_choice, patch.retcode, patch.resid, patch.original)
140141
end
141142

@@ -150,13 +151,14 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Sy
150151
end
151152

152153
# FIXME: Remove the defaults for resid and original on a breaking release
153-
function ODESolution{T, N}(u, u_analytic, errors, t, k, prob, alg, interp, dense,
154+
function ODESolution{T, N}(
155+
u, u_analytic, errors, t, k, discretes, prob, alg, interp, dense,
154156
tslocation, stats, alg_choice, retcode, resid = nothing,
155157
original = nothing) where {T, N}
156158
return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t),
157-
typeof(k), typeof(prob), typeof(alg), typeof(interp),
159+
typeof(k), typeof(discretes), typeof(prob), typeof(alg), typeof(interp),
158160
typeof(stats), typeof(alg_choice), typeof(resid),
159-
typeof(original)}(u, u_analytic, errors, t, k, prob, alg, interp,
161+
typeof(original)}(u, u_analytic, errors, t, k, discretes, prob, alg, interp,
160162
dense, tslocation, stats, alg_choice, retcode, resid, original)
161163
end
162164

@@ -172,6 +174,22 @@ function error_if_observed_derivative(sys, idx, ::Type)
172174
end
173175
end
174176

177+
function SymbolicIndexingInterface.is_parameter_timeseries(::Type{S}) where {
178+
T1, T2, T3, T4, T5, T6, T7,
179+
S <: ODESolution{T1, T2, T3, T4, T5, T6, T7, <:ParameterTimeseriesCollection}}
180+
Timeseries()
181+
end
182+
183+
function get_interpolated_discretes(sol::AbstractODESolution, t, deriv, continuity)
184+
is_parameter_timeseries(sol) == Timeseries() || return nothing
185+
186+
discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol)
187+
interp_discs = map(discs) do partition
188+
ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity)
189+
end
190+
return ParameterTimeseriesCollection(interp_discs, parameter_values(discs))
191+
end
192+
175193
function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing,
176194
continuity = :left) where {deriv}
177195
sol(t, deriv, idxs, continuity)
@@ -188,7 +206,8 @@ end
188206

189207
function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
190208
idxs::Nothing, continuity) where {deriv}
191-
augment(sol.interp(t, idxs, deriv, sol.prob.p, continuity), sol)
209+
discretes = get_interpolated_discretes(sol, t, deriv, continuity)
210+
augment(sol.interp(t, idxs, deriv, sol.prob.p, continuity), sol; discretes)
192211
end
193212

194213
function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::Integer,
@@ -224,11 +243,23 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
224243
continuity) where {deriv}
225244
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
226245
error_if_observed_derivative(sol, idxs, deriv)
227-
if is_parameter(sol, idxs)
228-
return getp(sol, idxs)(sol)
229-
else
230-
return augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)[idxs][1]
246+
ps = parameter_values(sol)
247+
if is_parameter(sol, idxs) && !is_timeseries_parameter(sol, idxs)
248+
return getp(sol, idxs)(ps)
249+
end
250+
# NOTE: This is basically SII.parameter_values_at_time but that isn't public API
251+
# and once we move interpolation to SII, there's no reason for it to be
252+
if is_parameter_timeseries(sol) == Timeseries()
253+
discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol)
254+
ps = parameter_values(discs)
255+
for ts_idx in eachindex(discs)
256+
partition = discs[ts_idx]
257+
interp_val = ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity)
258+
ps = with_updated_parameter_timeseries_values(ps, ts_idx => interp_val)
259+
end
231260
end
261+
state = ProblemState(; u = sol.interp(t, nothing, deriv, ps, continuity), p = ps, t)
262+
return getu(sol, idxs)(state)
232263
end
233264

234265
function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVector,
@@ -238,33 +269,89 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect
238269
error("Incorrect specification of `idxs`")
239270
end
240271
error_if_observed_derivative(sol, idxs, deriv)
241-
interp_sol = augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)
242-
first(interp_sol[idxs])
272+
ps = parameter_values(sol)
273+
# NOTE: This is basically SII.parameter_values_at_time but that isn't public API
274+
# and once we move interpolation to SII, there's no reason for it to be
275+
if is_parameter_timeseries(sol) == Timeseries()
276+
discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol)
277+
ps = parameter_values(discs)
278+
for ts_idx in eachindex(discs)
279+
partition = discs[ts_idx]
280+
interp_val = ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity)
281+
ps = with_updated_parameter_timeseries_values(ps, ts_idx => interp_val)
282+
end
283+
end
284+
state = ProblemState(; u = sol.interp(t, nothing, deriv, ps, continuity), p = ps, t)
285+
return getu(sol, idxs)(state)
243286
end
244287

245288
function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs,
246289
continuity) where {deriv}
247290
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
248291
error_if_observed_derivative(sol, idxs, deriv)
249-
if is_parameter(sol, idxs)
250-
return getp(sol, idxs)(sol)
251-
else
252-
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
253-
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
254-
return DiffEqArray(interp_sol[idxs], t, p, sol)
255-
end
292+
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
293+
discretes = get_interpolated_discretes(sol, t, deriv, continuity)
294+
interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol; discretes)
295+
return DiffEqArray(getu(interp_sol, idxs)(interp_sol), t, p, sol; discretes)
256296
end
257297

258298
function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
259299
idxs::AbstractVector, continuity) where {deriv}
260300
all(!isequal(NotSymbolic()), symbolic_type.(idxs)) ||
261301
error("Incorrect specification of `idxs`")
262302
error_if_observed_derivative(sol, idxs, deriv)
263-
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
264303
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
265-
indexed_sol = interp_sol[idxs]
304+
discretes = get_interpolated_discretes(sol, t, deriv, continuity)
305+
interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol; discretes)
266306
return DiffEqArray(
267-
[indexed_sol[i] for i in 1:length(t)], t, p, sol)
307+
getu(interp_sol, idxs)(interp_sol), t, p, sol; discretes)
308+
end
309+
310+
# public API, used by MTK
311+
"""
312+
create_parameter_timeseries_collection(sys, ps)
313+
314+
Create a `SymbolicIndexingInterface.ParameterTimeseriesCollection` for the given system
315+
`sys` and parameter object `ps`. Return `nothing` if there are no timeseries parameters.
316+
Defaults to `nothing`.
317+
"""
318+
function create_parameter_timeseries_collection(sys, ps, tspan)
319+
return nothing
320+
end
321+
322+
const PeriodicDiffEqArray = DiffEqArray{T, N, A, B} where {T, N, A, B <: AbstractRange}
323+
324+
# public API, used by MTK
325+
"""
326+
get_saveable_values(ps, timeseries_idx)
327+
"""
328+
function get_saveable_values end
329+
330+
function save_discretes!(integ::DEIntegrator, timeseries_idx)
331+
save_discretes!(integ.sol, current_time(integ), get_saveable_values(parameter_values(integ), timeseries_idx), timeseries_idx)
332+
end
333+
334+
save_discretes!(args...) = nothing
335+
336+
# public API, used by MTK
337+
function save_discretes!(sol::AbstractODESolution, t, vals, timeseries_idx)
338+
RecursiveArrayTools.has_discretes(sol) || return
339+
disc = RecursiveArrayTools.get_discretes(sol)
340+
_save_discretes_internal!(disc[timeseries_idx], t, vals)
341+
end
342+
343+
function _save_discretes_internal!(A::AbstractDiffEqArray, t, vals)
344+
push!(A.t, t)
345+
push!(A.u, vals)
346+
end
347+
348+
function _save_discretes_internal!(A::PeriodicDiffEqArray, t, vals)
349+
# This is O(1) because A.t is a range
350+
idx = searchsortedlast(A.t, t)
351+
if idx == firstindex(A.t) - 1 || A.t[idx] t
352+
error("Tried to save periodic discrete value with timeseries $(A.t) at time $t")
353+
end
354+
push!(A.u, vals)
268355
end
269356

270357
function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
@@ -305,13 +392,16 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
305392
Base.depwarn(msg, :build_solution)
306393
end
307394

395+
ps = parameter_values(prob)
396+
discretes = create_parameter_timeseries_collection(prob.f.sys, ps, prob.tspan)
308397
if has_analytic(f)
309398
u_analytic = Vector{typeof(prob.u0)}()
310399
errors = Dict{Symbol, real(eltype(prob.u0))}()
311400
sol = ODESolution{T, N}(u,
312401
u_analytic,
313402
errors,
314403
t, k,
404+
discretes,
315405
prob,
316406
alg,
317407
interp,
@@ -332,6 +422,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
332422
nothing,
333423
nothing,
334424
t, k,
425+
discretes,
335426
prob,
336427
alg,
337428
interp,
@@ -413,6 +504,36 @@ function solution_slice(sol::ODESolution{T, N}, I) where {T, N}
413504
return @set sol.alg = false
414505
end
415506

507+
mask_discretes(::Nothing, _, _...) = nothing
508+
509+
function mask_discretes(discretes::ParameterTimeseriesCollection, new_t, ::Union{Int, CartesianIndex})
510+
masked_discretes = map(discretes) do disc
511+
i = searchsortedlast(disc.t, new_t)
512+
disc[i:i]
513+
end
514+
return ParameterTimeseriesCollection(masked_discretes, parameter_values(discretes))
515+
end
516+
517+
function mask_discretes(discretes::ParameterTimeseriesCollection, new_t, ::AbstractRange)
518+
mint, maxt = extrema(new_t)
519+
masked_discretes = map(discretes) do disc
520+
mini = searchsortedfirst(disc.t, mint)
521+
maxi = searchsortedlast(disc.t, maxt)
522+
disc[mini:maxi]
523+
end
524+
return ParameterTimeseriesCollection(masked_discretes, parameter_values(discretes))
525+
end
526+
527+
function mask_discretes(discretes::ParameterTimeseriesCollection, new_t, _)
528+
masked_discretes = map(discretes) do disc
529+
idxs = map(new_t) do t
530+
searchsortedlast(disc.t, t)
531+
end
532+
disc[idxs]
533+
end
534+
return ParameterTimeseriesCollection(masked_discretes, parameter_values(discretes))
535+
end
536+
416537
function sensitivity_solution(sol::ODESolution, u, t)
417538
T = eltype(eltype(u))
418539

src/solutions/solution_interface.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ function Base.show(io::IO, m::MIME"text/plain", A::AbstractNoTimeSolution)
2222
end
2323

2424
# For augmenting system information to enable symbol based indexing of interpolated solutions
25-
function augment(A::DiffEqArray{T, N, Q, B}, sol::AbstractODESolution) where {T, N, Q, B}
25+
function augment(A::DiffEqArray{T, N, Q, B}, sol::AbstractODESolution;
26+
discretes = nothing) where {T, N, Q, B}
2627
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
27-
return DiffEqArray(A.u, A.t, p, sol)
28+
return DiffEqArray(A.u, A.t, p, sol; discretes)
2829
end
2930

3031
# SymbolicIndexingInterface.jl

0 commit comments

Comments
 (0)