Skip to content

Commit 28a1ebe

Browse files
committed
feat: add discretes to RODESolution
1 parent ba51e90 commit 28a1ebe

File tree

2 files changed

+57
-6
lines changed

2 files changed

+57
-6
lines changed

src/solutions/rode_solutions.jl

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,15 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/
3232
exited due to an error. For more details, see
3333
[the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
3434
"""
35-
struct RODESolution{T, N, uType, uType2, DType, tType, randType, P, A, IType, S,
35+
struct RODESolution{T, N, uType, uType2, DType, tType, randType, discType, P, A, IType, S,
3636
AC <: Union{Nothing, Vector{Int}}, V} <:
3737
AbstractRODESolution{T, N, uType}
3838
u::uType
3939
u_analytic::uType2
4040
errors::DType
4141
t::tType
4242
W::randType
43+
discretes::discType
4344
prob::P
4445
alg::A
4546
interp::IType
@@ -63,9 +64,9 @@ function ConstructionBase.setproperties(sol::RODESolution, patch::NamedTuple)
6364
patch = merge(getproperties(sol), patch)
6465
return RODESolution{
6566
T, N, typeof(patch.u), typeof(patch.u_analytic), typeof(patch.errors),
66-
typeof(patch.t), typeof(patch.W), typeof(patch.prob), typeof(patch.alg), typeof(patch.interp),
67+
typeof(patch.t), typeof(patch.W), typeof(patch.discretes), typeof(patch.prob), typeof(patch.alg), typeof(patch.interp),
6768
typeof(patch.stats), typeof(patch.alg_choice), typeof(patch.saved_subsystem)}(
68-
patch.u, patch.u_analytic, patch.errors, patch.t, patch.W,
69+
patch.u, patch.u_analytic, patch.errors, patch.t, patch.W, patch.discretes,
6970
patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats,
7071
patch.alg_choice, patch.retcode, patch.seed, patch.saved_subsystem)
7172
end
@@ -120,11 +121,23 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem},
120121
Base.depwarn(msg, :build_solution)
121122
end
122123

124+
ps = parameter_values(prob)
125+
if has_sys(prob.f)
126+
sswf = if saved_subsystem === nothing
127+
prob.f.sys
128+
else
129+
SavedSubsystemWithFallback(saved_subsystem, prob.f.sys)
130+
end
131+
discretes = create_parameter_timeseries_collection(sswf, ps, prob.tspan)
132+
else
133+
discretes = nothing
134+
end
135+
@show discretes
123136
if has_analytic(f)
124137
u_analytic = Vector{typeof(prob.u0)}()
125138
errors = Dict{Symbol, real(eltype(prob.u0))}()
126139
sol = RODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t),
127-
typeof(W),
140+
typeof(W), typeof(discretes),
128141
typeof(prob), typeof(alg), typeof(interp), typeof(stats),
129142
typeof(alg_choice), typeof(saved_subsystem)}(u,
130143
u_analytic,
@@ -149,15 +162,37 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem},
149162
return sol
150163
else
151164
return RODESolution{T, N, typeof(u), Nothing, Nothing, typeof(t),
152-
typeof(W), typeof(prob), typeof(alg), typeof(interp),
165+
typeof(W), typeof(discretes), typeof(prob), typeof(alg), typeof(interp),
153166
typeof(stats), typeof(alg_choice), typeof(saved_subsystem)}(
154-
u, nothing, nothing, t, W,
167+
u, nothing, nothing, t, W, discretes,
155168
prob, alg, interp,
156169
dense, 0, stats,
157170
alg_choice, retcode, seed, saved_subsystem)
158171
end
159172
end
160173

174+
function save_discretes!(sol::AbstractRODESolution, t, vals, timeseries_idx)
175+
RecursiveArrayTools.has_discretes(sol) || return
176+
disc = RecursiveArrayTools.get_discretes(sol)
177+
_save_discretes_internal!(disc[timeseries_idx], t, vals)
178+
end
179+
180+
function get_interpolated_discretes(sol::AbstractRODESolution, t, deriv, continuity)
181+
is_parameter_timeseries(sol) == Timeseries() || return nothing
182+
183+
discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol)
184+
interp_discs = map(discs) do partition
185+
hold_discrete(partition.u, partition.t, t)
186+
end
187+
return ParameterTimeseriesCollection(interp_discs, parameter_values(discs))
188+
end
189+
190+
function SymbolicIndexingInterface.is_parameter_timeseries(::Type{S}) where {
191+
T1, T2, T3, T4, T5, T6, T7,
192+
S <: RODESolution{T1, T2, T3, T4, T5, T6, T7, <: ParameterTimeseriesCollection}}
193+
Timeseries()
194+
end
195+
161196
function calculate_solution_errors!(sol::AbstractRODESolution; fill_uanalytic = true,
162197
timeseries_errors = true, dense_errors = true)
163198
if sol.prob.f isa Tuple

test/downstream/comprehensive_indexing.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,3 +977,19 @@ end
977977
sol = solve(prob, ImplicitEM())
978978
@test sol[sym] sol(sol.t .- sol.ps[delay]; idxs = original)
979979
end
980+
981+
@testset "RODESolutions save discretes" begin
982+
@parameters k(t)
983+
@variables A(t)
984+
function affect2!(integ, u, p, ctx)
985+
integ.ps[p.k] += 1.
986+
end
987+
db = 1. => (affect2!, [], [k], [k], nothing)
988+
989+
@named ssys = SDESystem(D(A) ~ k * A, [0.0], t, [A], [k], discrete_events = db)
990+
ssys = complete(ssys)
991+
prob = SDEProblem(ssys, [A => 1.], (0., 4.), [k => 1.])
992+
sol = solve(prob, RI5())
993+
@test sol[k] isa AbstractVector
994+
@test sol[k] == [1., 2., 3., 4.]
995+
end

0 commit comments

Comments
 (0)