Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ observed
parameter_observed
```

#### Historical index providers

```@docs
is_markovian
```

#### Parameter timeseries

If the index provider contains parameters that change during the course of the simulation
Expand Down Expand Up @@ -67,6 +73,12 @@ getu
setu
```

#### Historical value providers

```@docs
get_history_function
```

### Parameter indexing

```@docs
Expand Down
4 changes: 2 additions & 2 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export is_variable, variable_index, variable_symbols, is_parameter, parameter_in
parameter_symbols, is_independent_variable, independent_variable_symbols,
is_observed, observed, parameter_observed,
ContinuousTimeseries, get_all_timeseries_indexes,
is_time_dependent, constant_structure, symbolic_container,
is_time_dependent, is_markovian, constant_structure, symbolic_container,
all_variable_symbols, all_symbols, solvedvariables, allvariables, default_values,
symbolic_evaluate
include("index_provider_interface.jl")
Expand All @@ -26,7 +26,7 @@ include("symbol_cache.jl")

export parameter_values, set_parameter!, finalize_parameters_hook!,
get_parameter_timeseries_collection, with_updated_parameter_timeseries_values,
state_values, set_state!, current_time
state_values, set_state!, current_time, get_history_function
include("value_provider_interface.jl")

export ParameterTimeseriesCollection
Expand Down
29 changes: 28 additions & 1 deletion src/index_provider_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,13 @@ the order of states or a time index, which identifies the order of states. This
does not need to be defined if [`is_observed`](@ref) always returns `false`. Thus,
it is mandatory to always check `is_observed` before using this function.

See also: [`is_time_dependent`](@ref), [`constant_structure`](@ref)
If `!is_markovian(indp)`, the returned function must have the signature
`(u, h, p, t) -> [values...]` where `h` is the history function, which can be called
to obtain past values of the state. The exact signature and semantics of `h` depend
on how it is used inside the returned function. `h` is obtained from a value
provider using [`get_history_function`](@ref).

See also: [`is_time_dependent`](@ref), [`is_markovian`](@ref), [`constant_structure`](@ref).
"""
observed(indp, sym) = observed(symbolic_container(indp), sym)
observed(indp, sym, states) = observed(symbolic_container(indp), sym, states)
Expand All @@ -213,6 +219,27 @@ Check if `indp` has time as (one of) its independent variables.
"""
is_time_dependent(indp) = is_time_dependent(symbolic_container(indp))

"""
is_markovian(indp)

Check if an index provider represents a Markovian system. Markovian systems do not require
knowledge of past states to simulate. This function is only applicable to
index providers for which `is_time_dependent(indp)` returns `true`.

Non-Markovian index providers return [`observed`](@ref) functions with a different signature.
All value providers associated with a non-markovian index provider must implement
[`get_history_function`](@ref).

Returns `true` by default.
"""
function is_markovian(indp)
if hasmethod(symbolic_container, Tuple{typeof(indp)})
is_markovian(symbolic_container(indp))
else
true
end
end

"""
constant_structure(indp)

Expand Down
34 changes: 30 additions & 4 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,17 @@ struct GetIndepvar <: AbstractStateGetIndexer end
(::GetIndepvar)(::IsTimeseriesTrait, prob) = current_time(prob)
(::GetIndepvar)(::Timeseries, prob, i) = current_time(prob, i)

struct TimeDependentObservedFunction{I, F} <: AbstractStateGetIndexer
struct TimeDependentObservedFunction{I, F, H} <: AbstractStateGetIndexer
ts_idxs::I
obsfn::F
end

function TimeDependentObservedFunction{H}(ts_idxs, obsfn) where {H}
return TimeDependentObservedFunction{typeof(ts_idxs), typeof(obsfn), H}(ts_idxs, obsfn)
end

const NonMarkovianObservedFunction = TimeDependentObservedFunction{I, F, false} where {I, F}

indexer_timeseries_index(t::TimeDependentObservedFunction) = t.ts_idxs
function is_indexer_timeseries(::Type{G}) where {G <:
TimeDependentObservedFunction{ContinuousTimeseries}}
Expand All @@ -74,15 +80,26 @@ function (o::TimeDependentObservedFunction)(ts::IsTimeseriesTrait, prob, args...
return o(ts, is_indexer_timeseries(o), prob, args...)
end

function (o::TimeDependentObservedFunction)(ts::Timeseries, ::IndexerBoth, prob)
function (o::TimeDependentObservedFunction)(::Timeseries, ::IndexerBoth, prob)
return o.obsfn.(state_values(prob),
(parameter_values(prob),),
current_time(prob))
end
function (o::NonMarkovianObservedFunction)(::Timeseries, ::IndexerBoth, prob)
return o.obsfn.(state_values(prob),
(get_history_function(prob),),
(parameter_values(prob),),
current_time(prob))
end
function (o::TimeDependentObservedFunction)(
::Timeseries, ::IndexerBoth, prob, i::Union{Int, CartesianIndex})
return o.obsfn(state_values(prob, i), parameter_values(prob), current_time(prob, i))
end
function (o::NonMarkovianObservedFunction)(
::Timeseries, ::IndexerBoth, prob, i::Union{Int, CartesianIndex})
return o.obsfn(state_values(prob, i), get_history_function(prob),
parameter_values(prob), current_time(prob, i))
end
function (o::TimeDependentObservedFunction)(ts::Timeseries, ::IndexerBoth, prob, ::Colon)
return o(ts, prob)
end
Expand All @@ -98,6 +115,10 @@ end
function (o::TimeDependentObservedFunction)(::NotTimeseries, ::IndexerBoth, prob)
return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob))
end
function (o::NonMarkovianObservedFunction)(::NotTimeseries, ::IndexerBoth, prob)
return o.obsfn(state_values(prob), get_history_function(prob),
parameter_values(prob), current_time(prob))
end

function (o::TimeDependentObservedFunction)(
::Timeseries, ::IndexerMixedTimeseries, prob, args...)
Expand All @@ -107,6 +128,11 @@ function (o::TimeDependentObservedFunction)(
::NotTimeseries, ::IndexerMixedTimeseries, prob, args...)
return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob))
end
function (o::NonMarkovianObservedFunction)(
::NotTimeseries, ::IndexerMixedTimeseries, prob, args...)
return o.obsfn(state_values(prob), get_history_function(prob),
parameter_values(prob), current_time(prob))
end

struct TimeIndependentObservedFunction{F} <: AbstractStateGetIndexer
obsfn::F
Expand Down Expand Up @@ -137,7 +163,7 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
ts_idxs = collect(ts_idxs)
end
fn = observed(sys, sym)
return TimeDependentObservedFunction(ts_idxs, fn)
return TimeDependentObservedFunction{is_markovian(sys)}(ts_idxs, fn)
else
return getp(sys, sym)
end
Expand Down Expand Up @@ -256,7 +282,7 @@ for (t1, t2) in [
else
obs = observed(sys, sym_arr)
getter = if is_time_dependent(sys)
TimeDependentObservedFunction(ts_idxs, obs)
TimeDependentObservedFunction{is_markovian(sys)}(ts_idxs, obs)
else
TimeIndependentObservedFunction(obs)
end
Expand Down
10 changes: 10 additions & 0 deletions src/value_provider_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,16 @@ current_time(arr::AbstractVector) = arr
current_time(valp, i) = current_time(valp)[i]
current_time(valp, ::Colon) = current_time(valp)

"""
get_history_function(valp)

Return the history function for a value provider. This is required for all value providers
associated with an index provider `indp` for which `!is_markovian(indp)`.

See also: [`is_markovian`](@ref).
"""
function get_history_function end

###########
# Utilities
###########
Expand Down
31 changes: 31 additions & 0 deletions test/state_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,34 @@ for (sym, val, check_inference) in [
end
@test getter(fs) == val
end

struct NonMarkovianWrapper{S <: SymbolCache}
sys::S
end

SymbolicIndexingInterface.symbolic_container(hw::NonMarkovianWrapper) = hw.sys
SymbolicIndexingInterface.is_markovian(::NonMarkovianWrapper) = false
function SymbolicIndexingInterface.observed(hw::NonMarkovianWrapper, sym)
let inner = observed(hw.sys, sym)
fn(u, h, p, t) = inner(u .+ h(t - 0.1), p, t)
end
end
function SymbolicIndexingInterface.get_history_function(fs::FakeSolution)
t -> t .* ones(length(fs.u[1]))
end
function SymbolicIndexingInterface.get_history_function(fi::FakeIntegrator)
t -> t .* ones(length(fi.u))
end

sys = NonMarkovianWrapper(SymbolCache([:x, :y, :z], [:a, :b, :c], :t))
u0 = [1.0, 2.0, 3.0]
u = [u0 .* i for i in 1:11]
p = [10.0, 20.0, 30.0]
ts = 0.0:0.1:1.0

fi = FakeIntegrator(sys, u0, p, ts[1])
fs = FakeSolution(sys, u, p, ts)
getter = getu(sys, :(x + y))
@test getter(fi) ≈ 2.8
@test getter(fs) ≈ [3.0i + 2(ts[i] - 0.1) for i in 1:11]
@test getter(fs, 1) ≈ 2.8
Loading