Skip to content

Commit a626b86

Browse files
feat: support symbolic indexing of a subset of the system
1 parent 3c1211a commit a626b86

File tree

5 files changed

+534
-12
lines changed

5 files changed

+534
-12
lines changed

src/SciMLBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,7 @@ include("problems/problem_interface.jl")
724724
include("problems/optimization_problems.jl")
725725

726726
include("clock.jl")
727+
include("solutions/save_idxs.jl")
727728
include("solutions/basic_solutions.jl")
728729
include("solutions/nonlinear_solutions.jl")
729730
include("solutions/ode_solutions.jl")

src/solutions/ode_solutions.jl

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,12 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/
104104
successfully, whether it terminated early due to a user-defined callback, or whether it
105105
exited due to an error. For more details, see
106106
[the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
107+
- `saved_subsystem`: a [`SavedSubsystem`](@ref) representing the subset of variables saved
108+
in this solution, or `nothing` if all variables are saved. Here "variables" refers to
109+
both continuous-time state variables and timeseries parameters.
107110
"""
108111
struct ODESolution{T, N, uType, uType2, DType, tType, rateType, discType, P, A, IType, S,
109-
AC <: Union{Nothing, Vector{Int}}, R, O} <:
112+
AC <: Union{Nothing, Vector{Int}}, R, O, V} <:
110113
AbstractODESolution{T, N, uType}
111114
u::uType
112115
u_analytic::uType2
@@ -124,6 +127,7 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, discType, P, A,
124127
retcode::ReturnCode.T
125128
resid::R
126129
original::O
130+
saved_subsystem::V
127131
end
128132

129133
function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: ODESolution{T, N}}
@@ -137,7 +141,7 @@ function ConstructionBase.setproperties(sol::ODESolution, patch::NamedTuple)
137141
patch = merge(getproperties(sol), patch)
138142
return ODESolution{T, N}(patch.u, patch.u_analytic, patch.errors, patch.t, patch.k,
139143
patch.discretes, patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats,
140-
patch.alg_choice, patch.retcode, patch.resid, patch.original)
144+
patch.alg_choice, patch.retcode, patch.resid, patch.original, patch.saved_subsystem)
141145
end
142146

143147
Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Symbol)
@@ -154,12 +158,12 @@ end
154158
function ODESolution{T, N}(
155159
u, u_analytic, errors, t, k, discretes, prob, alg, interp, dense,
156160
tslocation, stats, alg_choice, retcode, resid = nothing,
157-
original = nothing) where {T, N}
161+
original = nothing, saved_subsystem = nothing) where {T, N}
158162
return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t),
159163
typeof(k), typeof(discretes), typeof(prob), typeof(alg), typeof(interp),
160-
typeof(stats), typeof(alg_choice), typeof(resid),
161-
typeof(original)}(u, u_analytic, errors, t, k, discretes, prob, alg, interp,
162-
dense, tslocation, stats, alg_choice, retcode, resid, original)
164+
typeof(stats), typeof(alg_choice), typeof(resid), typeof(original),
165+
typeof(saved_subsystem)}(u, u_analytic, errors, t, k, discretes, prob, alg, interp,
166+
dense, tslocation, stats, alg_choice, retcode, resid, original, saved_subsystem)
163167
end
164168

165169
error_if_observed_derivative(_, _, ::Type{Val{0}}) = nothing
@@ -409,15 +413,25 @@ const PeriodicDiffEqArray = DiffEqArray{T, N, A, B} where {T, N, A, B <: Abstrac
409413
# public API, used by MTK
410414
"""
411415
get_saveable_values(sys, ps, timeseries_idx)
416+
417+
Return the values to be saved in parameter object `ps` for timeseries index `timeseries_idx`. Called by
418+
`save_discretes!`. If this returns `nothing`, `save_discretes!` will not save anything.
412419
"""
413420
function get_saveable_values(sys, ps, timeseries_idx)
414421
return get_saveable_values(symbolic_container(sys), ps, timeseries_idx)
415422
end
416423

424+
"""
425+
save_discretes!(integ::DEIntegrator, timeseries_idx)
426+
427+
Save the parameter timeseries with index `timeseries_idx`. Calls `get_saveable_values` to
428+
get the values to save. If it returns `nothing`, then the save does not happen.
429+
"""
417430
function save_discretes!(integ::DEIntegrator, timeseries_idx)
418-
save_discretes!(integ.sol, current_time(integ),
419-
get_saveable_values(integ, parameter_values(integ), timeseries_idx),
420-
timeseries_idx)
431+
inner_sol = get_sol(integ)
432+
vals = get_saveable_values(inner_sol, parameter_values(integ), timeseries_idx)
433+
vals === nothing && return
434+
save_discretes!(integ.sol, current_time(integ), vals, timeseries_idx)
421435
end
422436

423437
save_discretes!(args...) = nothing
@@ -451,6 +465,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
451465
interp = LinearInterpolation(t, u),
452466
retcode = ReturnCode.Default, destats = missing, stats = nothing,
453467
resid = nothing, original = nothing,
468+
saved_subsystem = nothing,
454469
kwargs...)
455470
T = eltype(eltype(u))
456471

@@ -482,7 +497,12 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
482497

483498
ps = parameter_values(prob)
484499
if has_sys(prob.f)
485-
discretes = create_parameter_timeseries_collection(prob.f.sys, ps, prob.tspan)
500+
sswf = if saved_subsystem === nothing
501+
prob.f.sys
502+
else
503+
SavedSubsystemWithFallback(saved_subsystem, prob.f.sys)
504+
end
505+
discretes = create_parameter_timeseries_collection(sswf, ps, prob.tspan)
486506
else
487507
discretes = nothing
488508
end
@@ -503,7 +523,8 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
503523
alg_choice,
504524
retcode,
505525
resid,
506-
original)
526+
original,
527+
saved_subsystem)
507528
if calculate_error
508529
calculate_solution_errors!(sol; timeseries_errors = timeseries_errors,
509530
dense_errors = dense_errors)
@@ -524,7 +545,8 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
524545
alg_choice,
525546
retcode,
526547
resid,
527-
original)
548+
original,
549+
saved_subsystem)
528550
end
529551
end
530552

0 commit comments

Comments
 (0)