Skip to content

Commit 7cb1379

Browse files
Merge pull request #809 from AayushSabharwal/as/symbolic-save-idxs
feat: support symbolic indexing of a subset of the system
2 parents 3c1211a + c219343 commit 7cb1379

File tree

9 files changed

+648
-159
lines changed

9 files changed

+648
-159
lines changed

src/SciMLBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,7 @@ include("problems/problem_interface.jl")
724724
include("problems/optimization_problems.jl")
725725

726726
include("clock.jl")
727+
include("solutions/save_idxs.jl")
727728
include("solutions/basic_solutions.jl")
728729
include("solutions/nonlinear_solutions.jl")
729730
include("solutions/ode_solutions.jl")

src/integrator_interface.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -331,19 +331,22 @@ Otherwise the integrator is allowed to skip recalculating the interpolation.
331331
332332
# Arguments
333333
334-
- `continuous_modification`: determines whether the modification is due to a continuous change (continuous callback)
335-
or a discrete callback. For a continuous change, this can include a change to time which requires a re-evaluation
336-
of the interpolations.
337-
- `callback_initializealg`: the initialization algorithm provided by the callback. For DAEs, this is the choice for the
338-
initialization that is done post callback. The default value of `nothing` means that the initialization choice
339-
used for the DAE should be performed post-callback.
334+
- `continuous_modification`: determines whether the modification is due to a continuous change (continuous callback)
335+
or a discrete callback. For a continuous change, this can include a change to time which requires a re-evaluation
336+
of the interpolations.
337+
- `callback_initializealg`: the initialization algorithm provided by the callback. For DAEs, this is the choice for the
338+
initialization that is done post callback. The default value of `nothing` means that the initialization choice
339+
used for the DAE should be performed post-callback.
340340
"""
341341
function reeval_internals_due_to_modification!(
342342
integrator::DEIntegrator, continuous_modification;
343343
callback_initializealg = nothing)
344344
reeval_internals_due_to_modification!(integrator::DEIntegrator)
345345
end
346-
reeval_internals_due_to_modification!(integrator::DEIntegrator; callback_initializealg = nothing) = nothing
346+
function reeval_internals_due_to_modification!(
347+
integrator::DEIntegrator; callback_initializealg = nothing)
348+
nothing
349+
end
347350

348351
"""
349352
set_t!(integrator::DEIntegrator, t)

src/scimlfunctions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2693,8 +2693,8 @@ function SplitFunction{iip, specialize}(f1, f2;
26932693
f1.jac_prototype :
26942694
nothing,
26952695
W_prototype = __has_W_prototype(f1) ?
2696-
f1.W_prototype :
2697-
nothing,
2696+
f1.W_prototype :
2697+
nothing,
26982698
sparsity = __has_sparsity(f1) ? f1.sparsity :
26992699
jac_prototype,
27002700
Wfact = __has_Wfact(f1) ? f1.Wfact : nothing,

src/solutions/dae_solutions.jl

Lines changed: 46 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/
2727
exited due to an error. For more details, see
2828
[the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
2929
"""
30-
struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, S, rateType} <:
30+
struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, S, rateType, V} <:
3131
AbstractDAESolution{T, N, uType}
3232
u::uType
3333
du::duType
@@ -42,6 +42,31 @@ struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, S, rateT
4242
tslocation::Int
4343
stats::S
4444
retcode::ReturnCode.T
45+
saved_subsystem::V
46+
end
47+
48+
function DAESolution{T, N}(u, du, u_analytic, errors, t, k, prob, alg, interp, dense,
49+
tslocation, stats, retcode, saved_subsystem) where {T, N}
50+
return DAESolution{T, N, typeof(u), typeof(du), typeof(u_analytic), typeof(errors),
51+
typeof(t), typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k),
52+
typeof(saved_subsystem)}(
53+
u, du, u_analytic, errors, t, k, prob, alg, interp, dense, tslocation, stats,
54+
retcode, saved_subsystem
55+
)
56+
end
57+
58+
function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: DAESolution{T, N}}
59+
DAESolution{T, N}
60+
end
61+
62+
function ConstructionBase.setproperties(sol::DAESolution, patch::NamedTuple)
63+
u = get(patch, :u, sol.u)
64+
N = u === nothing ? 2 : ndims(eltype(u)) + 1
65+
T = eltype(eltype(u))
66+
patch = merge(getproperties(sol), patch)
67+
return DAESolution{T, N}(patch.u, patch.du, patch.u_analytic, patch.errors, patch.t,
68+
patch.k, patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation,
69+
patch.stats, patch.retcode, patch.saved_subsystem)
4570
end
4671

4772
Base.@propagate_inbounds function Base.getproperty(x::AbstractDAESolution, s::Symbol)
@@ -65,13 +90,14 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing;
6590
retcode = ReturnCode.Default,
6691
destats = missing,
6792
stats = nothing,
93+
saved_subsystem = nothing,
6894
kwargs...)
6995
T = eltype(eltype(u))
7096

7197
if prob.u0 === nothing
7298
N = 2
7399
else
74-
N = length((size(prob.u0)..., length(u)))
100+
N = ndims(eltype(u)) + 1
75101
end
76102

77103
if !ismissing(destats)
@@ -88,7 +114,8 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing;
88114
errors = Dict{Symbol, real(eltype(prob.u0))}()
89115

90116
sol = DAESolution{T, N, typeof(u), typeof(du), typeof(u_analytic), typeof(errors),
91-
typeof(t), typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k)}(
117+
typeof(t), typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k),
118+
typeof(saved_subsystem)}(
92119
u,
93120
du,
94121
u_analytic,
@@ -101,7 +128,8 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing;
101128
dense,
102129
0,
103130
stats,
104-
retcode)
131+
retcode,
132+
saved_subsystem)
105133

106134
if calculate_error
107135
calculate_solution_errors!(sol; timeseries_errors = timeseries_errors,
@@ -110,15 +138,17 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing;
110138
sol
111139
else
112140
DAESolution{T, N, typeof(u), typeof(du), Nothing, Nothing, typeof(t),
113-
typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k)}(
141+
typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(k),
142+
typeof(saved_subsystem)}(
114143
u, du,
115144
nothing,
116145
nothing, t, k,
117146
prob, alg,
118147
interp,
119148
dense, 0,
120149
stats,
121-
retcode)
150+
retcode,
151+
saved_subsystem)
122152
end
123153
end
124154

@@ -161,76 +191,23 @@ function calculate_solution_errors!(sol::AbstractDAESolution;
161191
end
162192

163193
function build_solution(sol::AbstractDAESolution{T, N}, u_analytic, errors) where {T, N}
164-
DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(u_analytic), typeof(errors),
165-
typeof(sol.t), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp),
166-
typeof(sol.stats), typeof(sol.k)}(sol.u,
167-
sol.du,
168-
u_analytic,
169-
errors,
170-
sol.t,
171-
sol.k,
172-
sol.prob,
173-
sol.alg,
174-
sol.interp,
175-
sol.dense,
176-
sol.tslocation,
177-
sol.stats,
178-
sol.retcode)
194+
@reset sol.u_analytic = u_analytic
195+
return @set sol.errors = errors
179196
end
180197

181198
function solution_new_retcode(sol::AbstractDAESolution{T, N}, retcode) where {T, N}
182-
DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic),
183-
typeof(sol.errors), typeof(sol.t), typeof(sol.prob), typeof(sol.alg),
184-
typeof(sol.interp), typeof(sol.stats), typeof(sol.k)}(sol.u,
185-
sol.du,
186-
sol.u_analytic,
187-
sol.errors,
188-
sol.t,
189-
sol.k,
190-
sol.prob,
191-
sol.alg,
192-
sol.interp,
193-
sol.dense,
194-
sol.tslocation,
195-
sol.stats,
196-
retcode)
199+
return @set sol.retcode = retcode
197200
end
198201

199202
function solution_new_tslocation(sol::AbstractDAESolution{T, N}, tslocation) where {T, N}
200-
DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic),
201-
typeof(sol.errors), typeof(sol.t), typeof(sol.prob), typeof(sol.alg),
202-
typeof(sol.interp), typeof(sol.stats), typeof(k)}(sol.u,
203-
sol.du,
204-
sol.u_analytic,
205-
sol.errors,
206-
sol.t,
207-
sol.k,
208-
sol.prob,
209-
sol.alg,
210-
sol.interp,
211-
sol.dense,
212-
tslocation,
213-
sol.stats,
214-
sol.retcode)
203+
return @set sol.tslocation = tslocation
215204
end
216205

217206
function solution_slice(sol::AbstractDAESolution{T, N}, I) where {T, N}
218-
DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic),
219-
typeof(sol.errors), typeof(sol.t), typeof(sol.prob), typeof(sol.alg),
220-
typeof(sol.interp), typeof(sol.stats), typeof(sol.k)}(sol.u[I],
221-
sol.du[I],
222-
sol.u_analytic ===
223-
nothing ?
224-
nothing :
225-
sol.u_analytic[I],
226-
sol.errors,
227-
sol.t[I],
228-
sol.k[I],
229-
sol.prob,
230-
sol.alg,
231-
sol.interp,
232-
false,
233-
sol.tslocation,
234-
sol.stats,
235-
sol.retcode)
207+
@reset sol.u = sol.u[I]
208+
@reset sol.du = sol.du[I]
209+
@reset sol.u_analytic = sol.u_analytic === nothing ? nothing : sol.u_analytic[I]
210+
@reset sol.t = sol.t[I]
211+
@reset sol.k = sol.dense ? sol.k[I] : sol.k
212+
return @set sol.dense = false
236213
end

src/solutions/ode_solutions.jl

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,12 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/
104104
successfully, whether it terminated early due to a user-defined callback, or whether it
105105
exited due to an error. For more details, see
106106
[the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
107+
- `saved_subsystem`: a [`SavedSubsystem`](@ref) representing the subset of variables saved
108+
in this solution, or `nothing` if all variables are saved. Here "variables" refers to
109+
both continuous-time state variables and timeseries parameters.
107110
"""
108111
struct ODESolution{T, N, uType, uType2, DType, tType, rateType, discType, P, A, IType, S,
109-
AC <: Union{Nothing, Vector{Int}}, R, O} <:
112+
AC <: Union{Nothing, Vector{Int}}, R, O, V} <:
110113
AbstractODESolution{T, N, uType}
111114
u::uType
112115
u_analytic::uType2
@@ -124,6 +127,7 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, discType, P, A,
124127
retcode::ReturnCode.T
125128
resid::R
126129
original::O
130+
saved_subsystem::V
127131
end
128132

129133
function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: ODESolution{T, N}}
@@ -137,7 +141,7 @@ function ConstructionBase.setproperties(sol::ODESolution, patch::NamedTuple)
137141
patch = merge(getproperties(sol), patch)
138142
return ODESolution{T, N}(patch.u, patch.u_analytic, patch.errors, patch.t, patch.k,
139143
patch.discretes, patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats,
140-
patch.alg_choice, patch.retcode, patch.resid, patch.original)
144+
patch.alg_choice, patch.retcode, patch.resid, patch.original, patch.saved_subsystem)
141145
end
142146

143147
Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Symbol)
@@ -154,12 +158,12 @@ end
154158
function ODESolution{T, N}(
155159
u, u_analytic, errors, t, k, discretes, prob, alg, interp, dense,
156160
tslocation, stats, alg_choice, retcode, resid = nothing,
157-
original = nothing) where {T, N}
161+
original = nothing, saved_subsystem = nothing) where {T, N}
158162
return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t),
159163
typeof(k), typeof(discretes), typeof(prob), typeof(alg), typeof(interp),
160-
typeof(stats), typeof(alg_choice), typeof(resid),
161-
typeof(original)}(u, u_analytic, errors, t, k, discretes, prob, alg, interp,
162-
dense, tslocation, stats, alg_choice, retcode, resid, original)
164+
typeof(stats), typeof(alg_choice), typeof(resid), typeof(original),
165+
typeof(saved_subsystem)}(u, u_analytic, errors, t, k, discretes, prob, alg, interp,
166+
dense, tslocation, stats, alg_choice, retcode, resid, original, saved_subsystem)
163167
end
164168

165169
error_if_observed_derivative(_, _, ::Type{Val{0}}) = nothing
@@ -409,15 +413,25 @@ const PeriodicDiffEqArray = DiffEqArray{T, N, A, B} where {T, N, A, B <: Abstrac
409413
# public API, used by MTK
410414
"""
411415
get_saveable_values(sys, ps, timeseries_idx)
416+
417+
Return the values to be saved in parameter object `ps` for timeseries index `timeseries_idx`. Called by
418+
`save_discretes!`. If this returns `nothing`, `save_discretes!` will not save anything.
412419
"""
413420
function get_saveable_values(sys, ps, timeseries_idx)
414421
return get_saveable_values(symbolic_container(sys), ps, timeseries_idx)
415422
end
416423

424+
"""
425+
save_discretes!(integ::DEIntegrator, timeseries_idx)
426+
427+
Save the parameter timeseries with index `timeseries_idx`. Calls `get_saveable_values` to
428+
get the values to save. If it returns `nothing`, then the save does not happen.
429+
"""
417430
function save_discretes!(integ::DEIntegrator, timeseries_idx)
418-
save_discretes!(integ.sol, current_time(integ),
419-
get_saveable_values(integ, parameter_values(integ), timeseries_idx),
420-
timeseries_idx)
431+
inner_sol = get_sol(integ)
432+
vals = get_saveable_values(inner_sol, parameter_values(integ), timeseries_idx)
433+
vals === nothing && return
434+
save_discretes!(integ.sol, current_time(integ), vals, timeseries_idx)
421435
end
422436

423437
save_discretes!(args...) = nothing
@@ -451,6 +465,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
451465
interp = LinearInterpolation(t, u),
452466
retcode = ReturnCode.Default, destats = missing, stats = nothing,
453467
resid = nothing, original = nothing,
468+
saved_subsystem = nothing,
454469
kwargs...)
455470
T = eltype(eltype(u))
456471

@@ -482,7 +497,12 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
482497

483498
ps = parameter_values(prob)
484499
if has_sys(prob.f)
485-
discretes = create_parameter_timeseries_collection(prob.f.sys, ps, prob.tspan)
500+
sswf = if saved_subsystem === nothing
501+
prob.f.sys
502+
else
503+
SavedSubsystemWithFallback(saved_subsystem, prob.f.sys)
504+
end
505+
discretes = create_parameter_timeseries_collection(sswf, ps, prob.tspan)
486506
else
487507
discretes = nothing
488508
end
@@ -503,7 +523,8 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
503523
alg_choice,
504524
retcode,
505525
resid,
506-
original)
526+
original,
527+
saved_subsystem)
507528
if calculate_error
508529
calculate_solution_errors!(sol; timeseries_errors = timeseries_errors,
509530
dense_errors = dense_errors)
@@ -524,7 +545,8 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
524545
alg_choice,
525546
retcode,
526547
resid,
527-
original)
548+
original,
549+
saved_subsystem)
528550
end
529551
end
530552

@@ -593,7 +615,7 @@ function solution_slice(sol::ODESolution{T, N}, I) where {T, N}
593615
@reset sol.u_analytic = sol.u_analytic === nothing ? nothing : sol.u_analytic[I]
594616
@reset sol.t = sol.t[I]
595617
@reset sol.k = sol.dense ? sol.k[I] : sol.k
596-
return @set sol.alg = false
618+
return @set sol.dense = false
597619
end
598620

599621
mask_discretes(::Nothing, _, _...) = nothing

0 commit comments

Comments
 (0)