Skip to content

Commit d7edbe7

Browse files
Merge pull request #645 from AayushSabharwal/as/discrete-save
feat: add discrete saving feature to ODESolution
2 parents 71a578d + fe68aaa commit d7edbe7

File tree

10 files changed

+763
-102
lines changed

10 files changed

+763
-102
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1212
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1313
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1414
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
15+
Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636"
1516
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
1617
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
1718
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -62,6 +63,7 @@ DataFrames = "1.6"
6263
Distributed = "1.10"
6364
DocStringExtensions = "0.9"
6465
EnumX = "1"
66+
Expronicon = "0.8"
6567
ForwardDiff = "0.10.36"
6668
FunctionWrappersWrappers = "0.1.3"
6769
IteratorInterfaceExtensions = "^1"
@@ -78,7 +80,7 @@ PyCall = "1.96"
7880
PythonCall = "0.9.15"
7981
RCall = "0.14.0"
8082
RecipesBase = "1.3.4"
81-
RecursiveArrayTools = "3.22.0"
83+
RecursiveArrayTools = "3.26.0"
8284
Reexport = "1"
8385
RuntimeGeneratedFunctions = "0.5.12"
8486
SciMLOperators = "0.3.7"
@@ -87,7 +89,7 @@ StableRNGs = "1.0"
8789
StaticArrays = "1.7"
8890
StaticArraysCore = "1.4"
8991
Statistics = "1.10"
90-
SymbolicIndexingInterface = "0.3.20"
92+
SymbolicIndexingInterface = "0.3.26"
9193
Tables = "1.11"
9294
Zygote = "0.6.67"
9395
julia = "1.10"

ext/SciMLBaseZygoteExt.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import SciMLStructures
3333
N = length((size(dprob.u0)..., length(du)))
3434
end
3535
Δ′ = ODESolution{T, N}(du, nothing, nothing,
36-
VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
36+
VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
3737
VA.alg_choice, VA.retcode)
3838
(Δ′, nothing, nothing)
3939
end
@@ -60,7 +60,7 @@ end
6060
T = eltype(eltype(VA.u))
6161
N = ndims(VA)
6262
Δ′ = ODESolution{T, N}(du, nothing, nothing,
63-
VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
63+
VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
6464
VA.alg_choice, VA.retcode)
6565
(Δ′, nothing, nothing)
6666
end
@@ -117,9 +117,11 @@ end
117117
elseif i === nothing
118118
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
119119
else
120-
Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)]
121-
for (x, j) in zip(VA.u, 1:length(VA))]
122-
(Δ′, nothing)
120+
VA = recursivecopy(VA)
121+
recursivefill!(VA, zero(eltype(VA)))
122+
v = view(VA, i, ntuple(_ -> :, ndims(VA) - 1)...)
123+
copyto!(v, Δ)
124+
(VA, nothing)
123125
end
124126
end
125127
VA[sym], ODESolution_getindex_pullback
@@ -172,15 +174,15 @@ end
172174
VA[sym], ODESolution_getindex_pullback
173175
end
174176

175-
@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14
176-
}(u,
177+
@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13
178+
, T14, T15}(u,
177179
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
178-
T9, T10, T11, T12, T13, T14}
180+
T9, T10, T11, T12, T13, T14, T15}
179181
function ODESolutionAdjoint(ȳ)
180182
(ȳ, ntuple(_ -> nothing, length(args))...)
181183
end
182184

183-
ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14}(u, args...),
185+
ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}(u, args...),
184186
ODESolutionAdjoint
185187
end
186188

src/SciMLBase.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import RuntimeGeneratedFunctions
2323
import EnumX
2424
import ADTypes: AbstractADType
2525
import Accessors: @set, @reset
26+
using Expronicon.ADT: @match
2627

2728
using Reexport
2829
using SciMLOperators
@@ -717,6 +718,7 @@ include("problems/problem_traits.jl")
717718
include("problems/problem_interface.jl")
718719
include("problems/optimization_problems.jl")
719720

721+
include("clock.jl")
720722
include("solutions/basic_solutions.jl")
721723
include("solutions/nonlinear_solutions.jl")
722724
include("solutions/ode_solutions.jl")
@@ -835,4 +837,6 @@ export step!, deleteat!, addat!, get_tmp_cache,
835837

836838
export ContinuousCallback, DiscreteCallback, CallbackSet, VectorContinuousCallback
837839

840+
export Clocks, TimeDomain, is_discrete_time_domain, isclock, issolverstepclock, iscontinuous
841+
838842
end

src/clock.jl

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
module Clocks
2+
3+
export TimeDomain
4+
5+
using Expronicon.ADT: @adt, @match
6+
7+
@adt TimeDomain begin
8+
Continuous
9+
struct PeriodicClock
10+
dt::Union{Nothing, Float64, Rational{Int}}
11+
phase::Float64 = 0.0
12+
end
13+
SolverStepClock
14+
end
15+
16+
Base.Broadcast.broadcastable(d::TimeDomain) = Ref(d)
17+
18+
end
19+
20+
using .Clocks
21+
22+
"""
23+
Clock(dt)
24+
Clock()
25+
26+
The default periodic clock with tick interval `dt`. If `dt` is left unspecified, it will
27+
be inferred (if possible).
28+
"""
29+
Clock(dt::Union{<:Rational, Float64}; phase = 0.0) = PeriodicClock(dt, phase)
30+
Clock(dt; phase = 0.0) = PeriodicClock(convert(Float64, dt), phase)
31+
Clock(; phase = 0.0) = PeriodicClock(nothing, phase)
32+
33+
@doc """
34+
SolverStepClock
35+
36+
A clock that ticks at each solver step (sometimes referred to as "continuous sample time").
37+
This clock **does generally not have equidistant tick intervals**, instead, the tick
38+
interval depends on the adaptive step-size selection of the continuous solver, as well as
39+
any continuous event handling. If adaptivity of the solver is turned off and there are no
40+
continuous events, the tick interval will be given by the fixed solver time step `dt`.
41+
42+
Due to possibly non-equidistant tick intervals, this clock should typically not be used with
43+
discrete-time systems that assume a fixed sample time, such as PID controllers and digital
44+
filters.
45+
""" SolverStepClock
46+
47+
isclock(c) = @match c begin
48+
PeriodicClock(_...) => true
49+
_ => false
50+
end
51+
52+
issolverstepclock(c) = @match c begin
53+
&SolverStepClock => true
54+
_ => false
55+
end
56+
57+
iscontinuous(c) = @match c begin
58+
&Continuous => true
59+
_ => false
60+
end
61+
62+
is_discrete_time_domain(c) = !iscontinuous(c)
63+
64+
function first_clock_tick_time(c, t0)
65+
@match c begin
66+
PeriodicClock(dt, _...) => ceil(t0 / dt) * dt
67+
&SolverStepClock => t0
68+
&Continuous => error("Continuous is not a discrete clock")
69+
end
70+
end
71+
72+
struct IndexedClock{I}
73+
clock::TimeDomain
74+
idx::I
75+
end
76+
77+
Base.getindex(c::TimeDomain, idx) = IndexedClock(c, idx)
78+
79+
function canonicalize_indexed_clock(ic::IndexedClock, sol::AbstractTimeseriesSolution)
80+
c = ic.clock
81+
82+
return @match c begin
83+
PeriodicClock(dt, _...) => ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt
84+
&SolverStepClock => begin
85+
ssc_idx = findfirst(eachindex(sol.discretes)) do i
86+
!isa(sol.discretes[i].t, AbstractRange)
87+
end
88+
sol.discretes[ssc_idx].t[ic.idx]
89+
end
90+
&Continuous => sol.t[ic.idx]
91+
end
92+
end

src/remake.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,10 @@ end
475475
anydict(d) = Dict{Any, Any}(d)
476476
anydict() = Dict{Any, Any}()
477477

478+
function _updated_u0_p_internal(
479+
prob, ::Missing, ::Missing; interpret_symbolicmap = true, use_defaults = false)
480+
return state_values(prob), parameter_values(prob)
481+
end
478482
function _updated_u0_p_internal(
479483
prob, ::Missing, p; interpret_symbolicmap = true, use_defaults = false)
480484
u0 = state_values(prob)

0 commit comments

Comments
 (0)