Skip to content

Commit c5bd315

Browse files
Merge pull request #974 from vyudu/rode
feat: add discretes to RODESolution
2 parents 3e87540 + e0e2554 commit c5bd315

File tree

4 files changed

+70
-13
lines changed

4 files changed

+70
-13
lines changed

src/remake.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ function remake(
191191
props = @delete props._func_cache
192192
props = @insert props._func_cache = forig._func_cache
193193
end
194-
194+
195195
args = (args..., f2)
196196
end
197197
end

src/scimlfunctions.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3067,7 +3067,7 @@ function ImplicitDiscreteFunction{iip, specialize}(f;
30673067
f.observed :
30683068
DEFAULT_OBSERVED,
30693069
resid_prototype = __has_resid_prototype(f) ?
3070-
f.resid_prototype :
3070+
f.resid_prototype :
30713071
nothing,
30723072
sys = __has_sys(f) ? f.sys : nothing,
30733073
initialization_data = __has_initialization_data(f) ? f.initialization_data :
@@ -3083,11 +3083,12 @@ function ImplicitDiscreteFunction{iip, specialize}(f;
30833083
analytic,
30843084
observed,
30853085
sys,
3086-
resid_prototype,
3086+
resid_prototype,
30873087
initialization_data)
30883088
else
30893089
ImplicitDiscreteFunction{
3090-
iip, specialize, typeof(_f), typeof(analytic), typeof(observed), typeof(sys), typeof(resid_prototype),
3090+
iip, specialize, typeof(_f), typeof(analytic),
3091+
typeof(observed), typeof(sys), typeof(resid_prototype),
30913092
typeof(initialization_data)}(
30923093
_f, analytic, observed, sys, resid_prototype, initialization_data)
30933094
end
@@ -3097,7 +3098,9 @@ function ImplicitDiscreteFunction{iip}(f; kwargs...) where {iip}
30973098
ImplicitDiscreteFunction{iip, FullSpecialize}(f; kwargs...)
30983099
end
30993100
ImplicitDiscreteFunction{iip}(f::ImplicitDiscreteFunction; kwargs...) where {iip} = f
3100-
function ImplicitDiscreteFunction(f; resid_prototype = __has_resid_prototype(f) ? f.resid_prototype : nothing, kwargs...)
3101+
function ImplicitDiscreteFunction(
3102+
f; resid_prototype = __has_resid_prototype(f) ? f.resid_prototype : nothing,
3103+
kwargs...)
31013104
ImplicitDiscreteFunction{isinplace(f, 5), FullSpecialize}(f; resid_prototype, kwargs...)
31023105
end
31033106
ImplicitDiscreteFunction(f::ImplicitDiscreteFunction; kwargs...) = f
@@ -3107,11 +3110,13 @@ function unwrapped_f(f::ImplicitDiscreteFunction, newf = unwrapped_f(f.f))
31073110

31083111
if specialize === NoSpecialize
31093112
ImplicitDiscreteFunction{isinplace(f, 5), specialize, Any, Any, Any,
3110-
Any, Any, Any}(newf, f.analytic, f.observed, f.sys, f.resid_prototype, f.initialization_data)
3113+
Any, Any, Any}(
3114+
newf, f.analytic, f.observed, f.sys, f.resid_prototype, f.initialization_data)
31113115
else
31123116
ImplicitDiscreteFunction{isinplace(f, 5), specialize, typeof(newf),
31133117
typeof(f.analytic),
3114-
typeof(f.observed), typeof(f.sys), typeof(resid_prototype), typeof(f.initialization_data)}(newf,
3118+
typeof(f.observed), typeof(f.sys), typeof(resid_prototype), typeof(f.initialization_data)}(
3119+
newf,
31153120
f.analytic, f.observed, f.sys, f.resid_prototype, f.initialization_data)
31163121
end
31173122
end

src/solutions/rode_solutions.jl

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,15 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/
3232
exited due to an error. For more details, see
3333
[the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
3434
"""
35-
struct RODESolution{T, N, uType, uType2, DType, tType, randType, P, A, IType, S,
35+
struct RODESolution{T, N, uType, uType2, DType, tType, randType, discType, P, A, IType, S,
3636
AC <: Union{Nothing, Vector{Int}}, V} <:
3737
AbstractRODESolution{T, N, uType}
3838
u::uType
3939
u_analytic::uType2
4040
errors::DType
4141
t::tType
4242
W::randType
43+
discretes::discType
4344
prob::P
4445
alg::A
4546
interp::IType
@@ -63,9 +64,10 @@ function ConstructionBase.setproperties(sol::RODESolution, patch::NamedTuple)
6364
patch = merge(getproperties(sol), patch)
6465
return RODESolution{
6566
T, N, typeof(patch.u), typeof(patch.u_analytic), typeof(patch.errors),
66-
typeof(patch.t), typeof(patch.W), typeof(patch.prob), typeof(patch.alg), typeof(patch.interp),
67+
typeof(patch.t), typeof(patch.W), typeof(patch.discretes),
68+
typeof(patch.prob), typeof(patch.alg), typeof(patch.interp),
6769
typeof(patch.stats), typeof(patch.alg_choice), typeof(patch.saved_subsystem)}(
68-
patch.u, patch.u_analytic, patch.errors, patch.t, patch.W,
70+
patch.u, patch.u_analytic, patch.errors, patch.t, patch.W, patch.discretes,
6971
patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats,
7072
patch.alg_choice, patch.retcode, patch.seed, patch.saved_subsystem)
7173
end
@@ -120,16 +122,28 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem},
120122
Base.depwarn(msg, :build_solution)
121123
end
122124

125+
ps = parameter_values(prob)
126+
if has_sys(prob.f)
127+
sswf = if saved_subsystem === nothing
128+
prob.f.sys
129+
else
130+
SavedSubsystemWithFallback(saved_subsystem, prob.f.sys)
131+
end
132+
discretes = create_parameter_timeseries_collection(sswf, ps, prob.tspan)
133+
else
134+
discretes = nothing
135+
end
123136
if has_analytic(f)
124137
u_analytic = Vector{typeof(prob.u0)}()
125138
errors = Dict{Symbol, real(eltype(prob.u0))}()
126139
sol = RODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t),
127-
typeof(W),
140+
typeof(W), typeof(discretes),
128141
typeof(prob), typeof(alg), typeof(interp), typeof(stats),
129142
typeof(alg_choice), typeof(saved_subsystem)}(u,
130143
u_analytic,
131144
errors,
132145
t, W,
146+
discretes,
133147
prob,
134148
alg,
135149
interp,
@@ -149,15 +163,37 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem},
149163
return sol
150164
else
151165
return RODESolution{T, N, typeof(u), Nothing, Nothing, typeof(t),
152-
typeof(W), typeof(prob), typeof(alg), typeof(interp),
166+
typeof(W), typeof(discretes), typeof(prob), typeof(alg), typeof(interp),
153167
typeof(stats), typeof(alg_choice), typeof(saved_subsystem)}(
154-
u, nothing, nothing, t, W,
168+
u, nothing, nothing, t, W, discretes,
155169
prob, alg, interp,
156170
dense, 0, stats,
157171
alg_choice, retcode, seed, saved_subsystem)
158172
end
159173
end
160174

175+
function save_discretes!(sol::AbstractRODESolution, t, vals, timeseries_idx)
176+
RecursiveArrayTools.has_discretes(sol) || return
177+
disc = RecursiveArrayTools.get_discretes(sol)
178+
_save_discretes_internal!(disc[timeseries_idx], t, vals)
179+
end
180+
181+
function get_interpolated_discretes(sol::AbstractRODESolution, t, deriv, continuity)
182+
is_parameter_timeseries(sol) == Timeseries() || return nothing
183+
184+
discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol)
185+
interp_discs = map(discs) do partition
186+
hold_discrete(partition.u, partition.t, t)
187+
end
188+
return ParameterTimeseriesCollection(interp_discs, parameter_values(discs))
189+
end
190+
191+
function SymbolicIndexingInterface.is_parameter_timeseries(::Type{S}) where {
192+
T1, T2, T3, T4, T5, T6, T7,
193+
S <: RODESolution{T1, T2, T3, T4, T5, T6, T7, <:ParameterTimeseriesCollection}}
194+
Timeseries()
195+
end
196+
161197
function calculate_solution_errors!(sol::AbstractRODESolution; fill_uanalytic = true,
162198
timeseries_errors = true, dense_errors = true)
163199
if sol.prob.f isa Tuple

test/downstream/comprehensive_indexing.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,3 +977,19 @@ end
977977
sol = solve(prob, ImplicitEM())
978978
@test sol[sym] sol(sol.t .- sol.ps[delay]; idxs = original)
979979
end
980+
981+
@testset "RODESolutions save discretes" begin
982+
@parameters k(t)
983+
@variables A(t)
984+
function affect2!(integ, u, p, ctx)
985+
integ.ps[p.k] += 1.0
986+
end
987+
db = 1.0 => (affect2!, [], [k], [k], nothing)
988+
989+
@named ssys = SDESystem(D(A) ~ k * A, [0.0], t, [A], [k], discrete_events = db)
990+
ssys = complete(ssys)
991+
prob = SDEProblem(ssys, [A => 1.0], (0.0, 4.0), [k => 1.0])
992+
sol = solve(prob, RI5())
993+
@test sol[k] isa AbstractVector
994+
@test sol[k] == [1.0, 2.0, 3.0, 4.0]
995+
end

0 commit comments

Comments
 (0)