Skip to content

Commit d516d81

Browse files
feat: support symbolic save_idxs in RODESolution
1 parent ad5a5a8 commit d516d81

File tree

1 file changed

+31
-68
lines changed

1 file changed

+31
-68
lines changed

src/solutions/rode_solutions.jl

Lines changed: 31 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/
3333
[the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
3434
"""
3535
struct RODESolution{T, N, uType, uType2, DType, tType, randType, P, A, IType, S,
36-
AC <: Union{Nothing, Vector{Int}}} <:
36+
AC <: Union{Nothing, Vector{Int}}, V} <:
3737
AbstractRODESolution{T, N, uType}
3838
u::uType
3939
u_analytic::uType2
@@ -49,6 +49,7 @@ struct RODESolution{T, N, uType, uType2, DType, tType, randType, P, A, IType, S,
4949
alg_choice::AC
5050
retcode::ReturnCode.T
5151
seed::UInt64
52+
saved_subsystem::V
5253
end
5354

5455
function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: RODESolution{T, N}}
@@ -63,10 +64,10 @@ function ConstructionBase.setproperties(sol::RODESolution, patch::NamedTuple)
6364
return RODESolution{
6465
T, N, typeof(patch.u), typeof(patch.u_analytic), typeof(patch.errors),
6566
typeof(patch.t), typeof(patch.W), typeof(patch.prob), typeof(patch.alg), typeof(patch.interp),
66-
typeof(patch.stats), typeof(patch.alg_choice)}(
67+
typeof(patch.stats), typeof(patch.alg_choice), typeof(patch.saved_subsystem)}(
6768
patch.u, patch.u_analytic, patch.errors, patch.t, patch.W,
6869
patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats,
69-
patch.alg_choice, patch.retcode, patch.seed)
70+
patch.alg_choice, patch.retcode, patch.seed, patch.saved_subsystem)
7071
end
7172

7273
Base.@propagate_inbounds function Base.getproperty(x::AbstractRODESolution, s::Symbol)
@@ -94,9 +95,14 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem},
9495
interp = LinearInterpolation(t, u),
9596
retcode = ReturnCode.Default,
9697
alg_choice = nothing,
97-
seed = UInt64(0), destats = missing, stats = nothing, kwargs...)
98+
seed = UInt64(0), destats = missing, stats = nothing,
99+
saved_subsystem = nothing, kwargs...)
98100
T = eltype(eltype(u))
99-
N = length((size(prob.u0)..., length(u)))
101+
if prob.u0 === nothing
102+
N = 2
103+
else
104+
N = ndims(eltype(u)) + 1
105+
end
100106

101107
if prob.f isa Tuple
102108
f = prob.f[1]
@@ -120,7 +126,7 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem},
120126
sol = RODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t),
121127
typeof(W),
122128
typeof(prob), typeof(alg), typeof(interp), typeof(stats),
123-
typeof(alg_choice)}(u,
129+
typeof(alg_choice), typeof(saved_subsystem)}(u,
124130
u_analytic,
125131
errors,
126132
t, W,
@@ -132,7 +138,8 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem},
132138
stats,
133139
alg_choice,
134140
retcode,
135-
seed)
141+
seed,
142+
saved_subsystem)
136143

137144
if calculate_error
138145
calculate_solution_errors!(sol; timeseries_errors = timeseries_errors,
@@ -143,10 +150,11 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem},
143150
else
144151
return RODESolution{T, N, typeof(u), Nothing, Nothing, typeof(t),
145152
typeof(W), typeof(prob), typeof(alg), typeof(interp),
146-
typeof(stats), typeof(alg_choice)}(u, nothing, nothing, t, W,
153+
typeof(stats), typeof(alg_choice), typeof(saved_subsystem)}(
154+
u, nothing, nothing, t, W,
147155
prob, alg, interp,
148156
dense, 0, stats,
149-
alg_choice, retcode, seed)
157+
alg_choice, retcode, seed, saved_subsystem)
150158
end
151159
end
152160

@@ -197,54 +205,24 @@ function calculate_solution_errors!(sol::AbstractRODESolution; fill_uanalytic =
197205
end
198206
end
199207

200-
function build_solution(sol::AbstractRODESolution{T, N}, u_analytic, errors) where {T, N}
201-
RODESolution{T, N, typeof(sol.u), typeof(u_analytic), typeof(errors), typeof(sol.t),
202-
typeof(sol.W), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp),
203-
typeof(sol.stats), typeof(sol.alg_choice)}(sol.u, u_analytic, errors,
204-
sol.t, sol.W, sol.prob,
205-
sol.alg, sol.interp,
206-
sol.dense, sol.tslocation,
207-
sol.stats, sol.alg_choice,
208-
sol.retcode, sol.seed)
208+
function build_solution(sol::AbstractRODESolution, u_analytic, errors)
209+
@reset sol.u_analytic = u_analytic
210+
return @set sol.errors = errors
209211
end
210212

211-
function solution_new_retcode(sol::AbstractRODESolution{T, N}, retcode) where {T, N}
212-
RODESolution{T, N, typeof(sol.u), typeof(sol.u_analytic), typeof(sol.errors),
213-
typeof(sol.t),
214-
typeof(sol.W), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp),
215-
typeof(sol.stats), typeof(sol.alg_choice)}(sol.u, sol.u_analytic,
216-
sol.errors, sol.t, sol.W,
217-
sol.prob, sol.alg, sol.interp,
218-
sol.dense, sol.tslocation,
219-
sol.stats, sol.alg_choice,
220-
retcode, sol.seed)
213+
function solution_new_retcode(sol::AbstractRODESolution, retcode)
214+
return @set sol.retcode = retcode
221215
end
222216

223-
function solution_new_tslocation(sol::AbstractRODESolution{T, N}, tslocation) where {T, N}
224-
RODESolution{T, N, typeof(sol.u), typeof(sol.u_analytic), typeof(sol.errors),
225-
typeof(sol.t),
226-
typeof(sol.W), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp),
227-
typeof(sol.stats), typeof(sol.alg_choice)}(sol.u, sol.u_analytic,
228-
sol.errors, sol.t, sol.W,
229-
sol.prob, sol.alg, sol.interp,
230-
sol.dense, tslocation,
231-
sol.stats, sol.alg_choice,
232-
sol.retcode, sol.seed)
217+
function solution_new_tslocation(sol::AbstractRODESolution, tslocation)
218+
return @set sol.tslocation = tslocation
233219
end
234220

235221
function solution_slice(sol::AbstractRODESolution{T, N}, I) where {T, N}
236-
RODESolution{T, N, typeof(sol.u), typeof(sol.u_analytic), typeof(sol.errors),
237-
typeof(sol.t),
238-
typeof(sol.W), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp),
239-
typeof(sol.stats), typeof(sol.alg_choice)}(sol.u[I],
240-
sol.u_analytic === nothing ?
241-
nothing : sol.u_analytic,
242-
sol.errors, sol.t[I],
243-
sol.W, sol.prob,
244-
sol.alg, sol.interp,
245-
false, sol.tslocation,
246-
sol.stats, sol.alg_choice,
247-
sol.retcode, sol.seed)
222+
@reset sol.u = sol.u[I]
223+
@reset sol.u_analytic = sol.u_analytic === nothing ? nothing : sol.u_analytic[I]
224+
@reset sol.t = sol.t[I]
225+
return @set sol.dense = false
248226
end
249227

250228
function sensitivity_solution(sol::AbstractRODESolution, u, t)
@@ -259,22 +237,7 @@ function sensitivity_solution(sol::AbstractRODESolution, u, t)
259237
end
260238

261239
interp = enable_interpolation_sensitivitymode(sol.interp)
262-
263-
RODESolution{T, N, typeof(u), typeof(sol.u_analytic),
264-
typeof(sol.errors), typeof(t),
265-
typeof(nothing), typeof(sol.prob), typeof(sol.alg),
266-
typeof(sol.interp), typeof(sol.stats), typeof(sol.alg_choice)}(u,
267-
sol.u_analytic,
268-
sol.errors,
269-
t,
270-
nothing,
271-
sol.prob,
272-
sol.alg,
273-
sol.interp,
274-
sol.dense,
275-
sol.tslocation,
276-
sol.stats,
277-
sol.alg_choice,
278-
sol.retcode,
279-
sol.seed)
240+
@reset sol.u = u
241+
@reset sol.t = t isa Vector ? t : collect(t)
242+
return @set sol.interp = interp
280243
end

0 commit comments

Comments
 (0)