Skip to content

Commit c3e0faa

Browse files
feat: support symbolic indexing of a subset of the system
1 parent 3abb733 commit c3e0faa

File tree

4 files changed

+526
-12
lines changed

4 files changed

+526
-12
lines changed

src/SciMLBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,7 @@ include("problems/problem_interface.jl")
724724
include("problems/optimization_problems.jl")
725725

726726
include("clock.jl")
727+
include("solutions/save_idxs.jl")
727728
include("solutions/basic_solutions.jl")
728729
include("solutions/nonlinear_solutions.jl")
729730
include("solutions/ode_solutions.jl")

src/solutions/ode_solutions.jl

Lines changed: 81 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,12 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/
104104
successfully, whether it terminated early due to a user-defined callback, or whether it
105105
exited due to an error. For more details, see
106106
[the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
107+
- `saved_subsystem`: a [`SavedSubsystem`](@ref) representing the subset of variables saved
108+
in this solution, or `nothing` if all variables are saved. Here "variables" refers to
109+
both continuous-time state variables and timeseries parameters.
107110
"""
108111
struct ODESolution{T, N, uType, uType2, DType, tType, rateType, discType, P, A, IType, S,
109-
AC <: Union{Nothing, Vector{Int}}, R, O} <:
112+
AC <: Union{Nothing, Vector{Int}}, R, O, V} <:
110113
AbstractODESolution{T, N, uType}
111114
u::uType
112115
u_analytic::uType2
@@ -124,6 +127,7 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, discType, P, A,
124127
retcode::ReturnCode.T
125128
resid::R
126129
original::O
130+
saved_subsystem::V
127131
end
128132

129133
function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: ODESolution{T, N}}
@@ -137,7 +141,7 @@ function ConstructionBase.setproperties(sol::ODESolution, patch::NamedTuple)
137141
patch = merge(getproperties(sol), patch)
138142
return ODESolution{T, N}(patch.u, patch.u_analytic, patch.errors, patch.t, patch.k,
139143
patch.discretes, patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats,
140-
patch.alg_choice, patch.retcode, patch.resid, patch.original)
144+
patch.alg_choice, patch.retcode, patch.resid, patch.original, patch.saved_subsystem)
141145
end
142146

143147
Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Symbol)
@@ -154,12 +158,12 @@ end
154158
function ODESolution{T, N}(
155159
u, u_analytic, errors, t, k, discretes, prob, alg, interp, dense,
156160
tslocation, stats, alg_choice, retcode, resid = nothing,
157-
original = nothing) where {T, N}
161+
original = nothing, saved_subsystem = nothing) where {T, N}
158162
return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t),
159163
typeof(k), typeof(discretes), typeof(prob), typeof(alg), typeof(interp),
160-
typeof(stats), typeof(alg_choice), typeof(resid),
161-
typeof(original)}(u, u_analytic, errors, t, k, discretes, prob, alg, interp,
162-
dense, tslocation, stats, alg_choice, retcode, resid, original)
164+
typeof(stats), typeof(alg_choice), typeof(resid), typeof(original),
165+
typeof(saved_subsystem)}(u, u_analytic, errors, t, k, discretes, prob, alg, interp,
166+
dense, tslocation, stats, alg_choice, retcode, resid, original, saved_subsystem)
163167
end
164168

165169
error_if_observed_derivative(_, _, ::Type{Val{0}}) = nothing
@@ -180,6 +184,53 @@ function SymbolicIndexingInterface.is_parameter_timeseries(::Type{S}) where {
180184
Timeseries()
181185
end
182186

187+
const SolutionWithSavedSubsystem = ODESolution{T1,
188+
T2,
189+
T3,
190+
T4,
191+
T5,
192+
T6,
193+
T7,
194+
T8,
195+
T9,
196+
T10,
197+
T11,
198+
T12,
199+
T13,
200+
T14,
201+
T15,
202+
T16} where {
203+
T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16 <: SavedSubsystem}
204+
205+
for method in [is_timeseries_parameter, timeseries_parameter_index,
206+
with_updated_parameter_timeseries_values, get_saveable_values]
207+
fname = nameof(method)
208+
mod = parentmodule(method)
209+
@eval function $(mod).$(fname)(sol::SolutionWithSavedSubsystem, args...)
210+
$(method)(SavedSubsystemWithFallback(sol.saved_subsystem, symbolic_container(sol)),
211+
args...)
212+
end
213+
end
214+
215+
function SymbolicIndexingInterface.state_values(sol::SolutionWithSavedSubsystem, i)
216+
original = state_values(sol.prob)
217+
saved = sol.u[i]
218+
if !(saved isa AbstractArray)
219+
saved = [saved]
220+
end
221+
ss = sol.saved_subsystem
222+
idxs = similar(saved, eltype(keys(ss.state_map)))
223+
for (k, v) in ss.state_map
224+
idxs[v] = k
225+
end
226+
replaced = remake_buffer(sol, original, idxs, saved)
227+
return replaced
228+
end
229+
230+
function SymbolicIndexingInterface.state_values(sol::SolutionWithSavedSubsystem)
231+
return map(Base.Fix1(state_values, sol), eachindex(sol.u))
232+
end
233+
183234
function _hold_discrete(disc_u, disc_t, t::Number)
184235
idx = searchsortedlast(disc_t, t)
185236
if idx == firstindex(disc_t) - 1
@@ -409,15 +460,25 @@ const PeriodicDiffEqArray = DiffEqArray{T, N, A, B} where {T, N, A, B <: Abstrac
409460
# public API, used by MTK
410461
"""
411462
get_saveable_values(sys, ps, timeseries_idx)
463+
464+
Return the values to be saved in parameter object `ps` for timeseries index `timeseries_idx`. Called by
465+
`save_discretes!`. If this returns `nothing`, `save_discretes!` will not save anything.
412466
"""
413467
function get_saveable_values(sys, ps, timeseries_idx)
414468
return get_saveable_values(symbolic_container(sys), ps, timeseries_idx)
415469
end
416470

471+
"""
472+
save_discretes!(integ::DEIntegrator, timeseries_idx)
473+
474+
Save the parameter timeseries with index `timeseries_idx`. Calls `get_saveable_values` to
475+
get the values to save. If it returns `nothing`, then the save does not happen.
476+
"""
417477
function save_discretes!(integ::DEIntegrator, timeseries_idx)
418-
save_discretes!(integ.sol, current_time(integ),
419-
get_saveable_values(integ, parameter_values(integ), timeseries_idx),
420-
timeseries_idx)
478+
inner_sol = get_sol(integ)
479+
vals = get_saveable_values(inner_sol, parameter_values(integ), timeseries_idx)
480+
vals === nothing && return
481+
save_discretes!(integ.sol, current_time(integ), vals, timeseries_idx)
421482
end
422483

423484
save_discretes!(args...) = nothing
@@ -451,6 +512,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
451512
interp = LinearInterpolation(t, u),
452513
retcode = ReturnCode.Default, destats = missing, stats = nothing,
453514
resid = nothing, original = nothing,
515+
saved_subsystem = nothing,
454516
kwargs...)
455517
T = eltype(eltype(u))
456518

@@ -482,7 +544,12 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
482544

483545
ps = parameter_values(prob)
484546
if has_sys(prob.f)
485-
discretes = create_parameter_timeseries_collection(prob.f.sys, ps, prob.tspan)
547+
sswf = if saved_subsystem === nothing
548+
prob.f.sys
549+
else
550+
SavedSubsystemWithFallback(saved_subsystem, prob.f.sys)
551+
end
552+
discretes = create_parameter_timeseries_collection(sswf, ps, prob.tspan)
486553
else
487554
discretes = nothing
488555
end
@@ -503,7 +570,8 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
503570
alg_choice,
504571
retcode,
505572
resid,
506-
original)
573+
original,
574+
saved_subsystem)
507575
if calculate_error
508576
calculate_solution_errors!(sol; timeseries_errors = timeseries_errors,
509577
dense_errors = dense_errors)
@@ -524,7 +592,8 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
524592
alg_choice,
525593
retcode,
526594
resid,
527-
original)
595+
original,
596+
saved_subsystem)
528597
end
529598
end
530599

0 commit comments

Comments
 (0)