Skip to content

Commit 67e285b

Browse files
Merge pull request #99 from SciML/as/ddes
feat: support observed functions for history-dependent systems
2 parents 8cf2450 + 5d5df12 commit 67e285b

File tree

6 files changed

+113
-7
lines changed

6 files changed

+113
-7
lines changed

docs/src/api.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ observed
3636
parameter_observed
3737
```
3838

39+
#### Historical index providers
40+
41+
```@docs
42+
is_markovian
43+
```
44+
3945
#### Parameter timeseries
4046

4147
If the index provider contains parameters that change during the course of the simulation
@@ -67,6 +73,12 @@ getu
6773
setu
6874
```
6975

76+
#### Historical value providers
77+
78+
```@docs
79+
get_history_function
80+
```
81+
7082
### Parameter indexing
7183

7284
```@docs

src/SymbolicIndexingInterface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ export is_variable, variable_index, variable_symbols, is_parameter, parameter_in
1616
parameter_symbols, is_independent_variable, independent_variable_symbols,
1717
is_observed, observed, parameter_observed,
1818
ContinuousTimeseries, get_all_timeseries_indexes,
19-
is_time_dependent, constant_structure, symbolic_container,
19+
is_time_dependent, is_markovian, constant_structure, symbolic_container,
2020
all_variable_symbols, all_symbols, solvedvariables, allvariables, default_values,
2121
symbolic_evaluate
2222
include("index_provider_interface.jl")
@@ -26,7 +26,7 @@ include("symbol_cache.jl")
2626

2727
export parameter_values, set_parameter!, finalize_parameters_hook!,
2828
get_parameter_timeseries_collection, with_updated_parameter_timeseries_values,
29-
state_values, set_state!, current_time
29+
state_values, set_state!, current_time, get_history_function
3030
include("value_provider_interface.jl")
3131

3232
export ParameterTimeseriesCollection

src/index_provider_interface.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,13 @@ the order of states or a time index, which identifies the order of states. This
201201
does not need to be defined if [`is_observed`](@ref) always returns `false`. Thus,
202202
it is mandatory to always check `is_observed` before using this function.
203203
204-
See also: [`is_time_dependent`](@ref), [`constant_structure`](@ref)
204+
If `!is_markovian(indp)`, the returned function must have the signature
205+
`(u, h, p, t) -> [values...]` where `h` is the history function, which can be called
206+
to obtain past values of the state. The exact signature and semantics of `h` depend
207+
on how it is used inside the returned function. `h` is obtained from a value
208+
provider using [`get_history_function`](@ref).
209+
210+
See also: [`is_time_dependent`](@ref), [`is_markovian`](@ref), [`constant_structure`](@ref).
205211
"""
206212
observed(indp, sym) = observed(symbolic_container(indp), sym)
207213
observed(indp, sym, states) = observed(symbolic_container(indp), sym, states)
@@ -213,6 +219,27 @@ Check if `indp` has time as (one of) its independent variables.
213219
"""
214220
is_time_dependent(indp) = is_time_dependent(symbolic_container(indp))
215221

222+
"""
223+
is_markovian(indp)
224+
225+
Check if an index provider represents a Markovian system. Markovian systems do not require
226+
knowledge of past states to simulate. This function is only applicable to
227+
index providers for which `is_time_dependent(indp)` returns `true`.
228+
229+
Non-Markovian index providers return [`observed`](@ref) functions with a different signature.
230+
All value providers associated with a non-markovian index provider must implement
231+
[`get_history_function`](@ref).
232+
233+
Returns `true` by default.
234+
"""
235+
function is_markovian(indp)
236+
if hasmethod(symbolic_container, Tuple{typeof(indp)})
237+
is_markovian(symbolic_container(indp))
238+
else
239+
true
240+
end
241+
end
242+
216243
"""
217244
constant_structure(indp)
218245

src/state_indexing.jl

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,17 @@ struct GetIndepvar <: AbstractStateGetIndexer end
5656
(::GetIndepvar)(::IsTimeseriesTrait, prob) = current_time(prob)
5757
(::GetIndepvar)(::Timeseries, prob, i) = current_time(prob, i)
5858

59-
struct TimeDependentObservedFunction{I, F} <: AbstractStateGetIndexer
59+
struct TimeDependentObservedFunction{I, F, H} <: AbstractStateGetIndexer
6060
ts_idxs::I
6161
obsfn::F
6262
end
6363

64+
function TimeDependentObservedFunction{H}(ts_idxs, obsfn) where {H}
65+
return TimeDependentObservedFunction{typeof(ts_idxs), typeof(obsfn), H}(ts_idxs, obsfn)
66+
end
67+
68+
const NonMarkovianObservedFunction = TimeDependentObservedFunction{I, F, false} where {I, F}
69+
6470
indexer_timeseries_index(t::TimeDependentObservedFunction) = t.ts_idxs
6571
function is_indexer_timeseries(::Type{G}) where {G <:
6672
TimeDependentObservedFunction{ContinuousTimeseries}}
@@ -74,15 +80,26 @@ function (o::TimeDependentObservedFunction)(ts::IsTimeseriesTrait, prob, args...
7480
return o(ts, is_indexer_timeseries(o), prob, args...)
7581
end
7682

77-
function (o::TimeDependentObservedFunction)(ts::Timeseries, ::IndexerBoth, prob)
83+
function (o::TimeDependentObservedFunction)(::Timeseries, ::IndexerBoth, prob)
84+
return o.obsfn.(state_values(prob),
85+
(parameter_values(prob),),
86+
current_time(prob))
87+
end
88+
function (o::NonMarkovianObservedFunction)(::Timeseries, ::IndexerBoth, prob)
7889
return o.obsfn.(state_values(prob),
90+
(get_history_function(prob),),
7991
(parameter_values(prob),),
8092
current_time(prob))
8193
end
8294
function (o::TimeDependentObservedFunction)(
8395
::Timeseries, ::IndexerBoth, prob, i::Union{Int, CartesianIndex})
8496
return o.obsfn(state_values(prob, i), parameter_values(prob), current_time(prob, i))
8597
end
98+
function (o::NonMarkovianObservedFunction)(
99+
::Timeseries, ::IndexerBoth, prob, i::Union{Int, CartesianIndex})
100+
return o.obsfn(state_values(prob, i), get_history_function(prob),
101+
parameter_values(prob), current_time(prob, i))
102+
end
86103
function (o::TimeDependentObservedFunction)(ts::Timeseries, ::IndexerBoth, prob, ::Colon)
87104
return o(ts, prob)
88105
end
@@ -98,6 +115,10 @@ end
98115
function (o::TimeDependentObservedFunction)(::NotTimeseries, ::IndexerBoth, prob)
99116
return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob))
100117
end
118+
function (o::NonMarkovianObservedFunction)(::NotTimeseries, ::IndexerBoth, prob)
119+
return o.obsfn(state_values(prob), get_history_function(prob),
120+
parameter_values(prob), current_time(prob))
121+
end
101122

102123
function (o::TimeDependentObservedFunction)(
103124
::Timeseries, ::IndexerMixedTimeseries, prob, args...)
@@ -107,6 +128,11 @@ function (o::TimeDependentObservedFunction)(
107128
::NotTimeseries, ::IndexerMixedTimeseries, prob, args...)
108129
return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob))
109130
end
131+
function (o::NonMarkovianObservedFunction)(
132+
::NotTimeseries, ::IndexerMixedTimeseries, prob, args...)
133+
return o.obsfn(state_values(prob), get_history_function(prob),
134+
parameter_values(prob), current_time(prob))
135+
end
110136

111137
struct TimeIndependentObservedFunction{F} <: AbstractStateGetIndexer
112138
obsfn::F
@@ -137,7 +163,7 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
137163
ts_idxs = collect(ts_idxs)
138164
end
139165
fn = observed(sys, sym)
140-
return TimeDependentObservedFunction(ts_idxs, fn)
166+
return TimeDependentObservedFunction{is_markovian(sys)}(ts_idxs, fn)
141167
else
142168
return getp(sys, sym)
143169
end
@@ -256,7 +282,7 @@ for (t1, t2) in [
256282
else
257283
obs = observed(sys, sym_arr)
258284
getter = if is_time_dependent(sys)
259-
TimeDependentObservedFunction(ts_idxs, obs)
285+
TimeDependentObservedFunction{is_markovian(sys)}(ts_idxs, obs)
260286
else
261287
TimeIndependentObservedFunction(obs)
262288
end

src/value_provider_interface.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,16 @@ current_time(arr::AbstractVector) = arr
142142
current_time(valp, i) = current_time(valp)[i]
143143
current_time(valp, ::Colon) = current_time(valp)
144144

145+
"""
146+
get_history_function(valp)
147+
148+
Return the history function for a value provider. This is required for all value providers
149+
associated with an index provider `indp` for which `!is_markovian(indp)`.
150+
151+
See also: [`is_markovian`](@ref).
152+
"""
153+
function get_history_function end
154+
145155
###########
146156
# Utilities
147157
###########

test/state_indexing_test.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,34 @@ for (sym, val, check_inference) in [
291291
end
292292
@test getter(fs) == val
293293
end
294+
295+
struct NonMarkovianWrapper{S <: SymbolCache}
296+
sys::S
297+
end
298+
299+
SymbolicIndexingInterface.symbolic_container(hw::NonMarkovianWrapper) = hw.sys
300+
SymbolicIndexingInterface.is_markovian(::NonMarkovianWrapper) = false
301+
function SymbolicIndexingInterface.observed(hw::NonMarkovianWrapper, sym)
302+
let inner = observed(hw.sys, sym)
303+
fn(u, h, p, t) = inner(u .+ h(t - 0.1), p, t)
304+
end
305+
end
306+
function SymbolicIndexingInterface.get_history_function(fs::FakeSolution)
307+
t -> t .* ones(length(fs.u[1]))
308+
end
309+
function SymbolicIndexingInterface.get_history_function(fi::FakeIntegrator)
310+
t -> t .* ones(length(fi.u))
311+
end
312+
313+
sys = NonMarkovianWrapper(SymbolCache([:x, :y, :z], [:a, :b, :c], :t))
314+
u0 = [1.0, 2.0, 3.0]
315+
u = [u0 .* i for i in 1:11]
316+
p = [10.0, 20.0, 30.0]
317+
ts = 0.0:0.1:1.0
318+
319+
fi = FakeIntegrator(sys, u0, p, ts[1])
320+
fs = FakeSolution(sys, u, p, ts)
321+
getter = getu(sys, :(x + y))
322+
@test getter(fi) 2.8
323+
@test getter(fs) [3.0i + 2(ts[i] - 0.1) for i in 1:11]
324+
@test getter(fs, 1) 2.8

0 commit comments

Comments
 (0)