diff --git a/docs/src/api.md b/docs/src/api.md index dfaadfb..21ba613 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -36,6 +36,12 @@ observed parameter_observed ``` +#### Historical index providers + +```@docs +is_markovian +``` + #### Parameter timeseries If the index provider contains parameters that change during the course of the simulation @@ -67,6 +73,12 @@ getu setu ``` +#### Historical value providers + +```@docs +get_history_function +``` + ### Parameter indexing ```@docs diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index a368dc0..aa89975 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -16,7 +16,7 @@ export is_variable, variable_index, variable_symbols, is_parameter, parameter_in parameter_symbols, is_independent_variable, independent_variable_symbols, is_observed, observed, parameter_observed, ContinuousTimeseries, get_all_timeseries_indexes, - is_time_dependent, constant_structure, symbolic_container, + is_time_dependent, is_markovian, constant_structure, symbolic_container, all_variable_symbols, all_symbols, solvedvariables, allvariables, default_values, symbolic_evaluate include("index_provider_interface.jl") @@ -26,7 +26,7 @@ include("symbol_cache.jl") export parameter_values, set_parameter!, finalize_parameters_hook!, get_parameter_timeseries_collection, with_updated_parameter_timeseries_values, - state_values, set_state!, current_time + state_values, set_state!, current_time, get_history_function include("value_provider_interface.jl") export ParameterTimeseriesCollection diff --git a/src/index_provider_interface.jl b/src/index_provider_interface.jl index 593f271..f2932d5 100644 --- a/src/index_provider_interface.jl +++ b/src/index_provider_interface.jl @@ -201,7 +201,13 @@ the order of states or a time index, which identifies the order of states. This does not need to be defined if [`is_observed`](@ref) always returns `false`. Thus, it is mandatory to always check `is_observed` before using this function. -See also: [`is_time_dependent`](@ref), [`constant_structure`](@ref) +If `!is_markovian(indp)`, the returned function must have the signature +`(u, h, p, t) -> [values...]` where `h` is the history function, which can be called +to obtain past values of the state. The exact signature and semantics of `h` depend +on how it is used inside the returned function. `h` is obtained from a value +provider using [`get_history_function`](@ref). + +See also: [`is_time_dependent`](@ref), [`is_markovian`](@ref), [`constant_structure`](@ref). """ observed(indp, sym) = observed(symbolic_container(indp), sym) observed(indp, sym, states) = observed(symbolic_container(indp), sym, states) @@ -213,6 +219,27 @@ Check if `indp` has time as (one of) its independent variables. """ is_time_dependent(indp) = is_time_dependent(symbolic_container(indp)) +""" + is_markovian(indp) + +Check if an index provider represents a Markovian system. Markovian systems do not require +knowledge of past states to simulate. This function is only applicable to +index providers for which `is_time_dependent(indp)` returns `true`. + +Non-Markovian index providers return [`observed`](@ref) functions with a different signature. +All value providers associated with a non-markovian index provider must implement +[`get_history_function`](@ref). + +Returns `true` by default. +""" +function is_markovian(indp) + if hasmethod(symbolic_container, Tuple{typeof(indp)}) + is_markovian(symbolic_container(indp)) + else + true + end +end + """ constant_structure(indp) diff --git a/src/state_indexing.jl b/src/state_indexing.jl index af51766..a9a6b75 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -56,11 +56,17 @@ struct GetIndepvar <: AbstractStateGetIndexer end (::GetIndepvar)(::IsTimeseriesTrait, prob) = current_time(prob) (::GetIndepvar)(::Timeseries, prob, i) = current_time(prob, i) -struct TimeDependentObservedFunction{I, F} <: AbstractStateGetIndexer +struct TimeDependentObservedFunction{I, F, H} <: AbstractStateGetIndexer ts_idxs::I obsfn::F end +function TimeDependentObservedFunction{H}(ts_idxs, obsfn) where {H} + return TimeDependentObservedFunction{typeof(ts_idxs), typeof(obsfn), H}(ts_idxs, obsfn) +end + +const NonMarkovianObservedFunction = TimeDependentObservedFunction{I, F, false} where {I, F} + indexer_timeseries_index(t::TimeDependentObservedFunction) = t.ts_idxs function is_indexer_timeseries(::Type{G}) where {G <: TimeDependentObservedFunction{ContinuousTimeseries}} @@ -74,8 +80,14 @@ function (o::TimeDependentObservedFunction)(ts::IsTimeseriesTrait, prob, args... return o(ts, is_indexer_timeseries(o), prob, args...) end -function (o::TimeDependentObservedFunction)(ts::Timeseries, ::IndexerBoth, prob) +function (o::TimeDependentObservedFunction)(::Timeseries, ::IndexerBoth, prob) + return o.obsfn.(state_values(prob), + (parameter_values(prob),), + current_time(prob)) +end +function (o::NonMarkovianObservedFunction)(::Timeseries, ::IndexerBoth, prob) return o.obsfn.(state_values(prob), + (get_history_function(prob),), (parameter_values(prob),), current_time(prob)) end @@ -83,6 +95,11 @@ function (o::TimeDependentObservedFunction)( ::Timeseries, ::IndexerBoth, prob, i::Union{Int, CartesianIndex}) return o.obsfn(state_values(prob, i), parameter_values(prob), current_time(prob, i)) end +function (o::NonMarkovianObservedFunction)( + ::Timeseries, ::IndexerBoth, prob, i::Union{Int, CartesianIndex}) + return o.obsfn(state_values(prob, i), get_history_function(prob), + parameter_values(prob), current_time(prob, i)) +end function (o::TimeDependentObservedFunction)(ts::Timeseries, ::IndexerBoth, prob, ::Colon) return o(ts, prob) end @@ -98,6 +115,10 @@ end function (o::TimeDependentObservedFunction)(::NotTimeseries, ::IndexerBoth, prob) return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob)) end +function (o::NonMarkovianObservedFunction)(::NotTimeseries, ::IndexerBoth, prob) + return o.obsfn(state_values(prob), get_history_function(prob), + parameter_values(prob), current_time(prob)) +end function (o::TimeDependentObservedFunction)( ::Timeseries, ::IndexerMixedTimeseries, prob, args...) @@ -107,6 +128,11 @@ function (o::TimeDependentObservedFunction)( ::NotTimeseries, ::IndexerMixedTimeseries, prob, args...) return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob)) end +function (o::NonMarkovianObservedFunction)( + ::NotTimeseries, ::IndexerMixedTimeseries, prob, args...) + return o.obsfn(state_values(prob), get_history_function(prob), + parameter_values(prob), current_time(prob)) +end struct TimeIndependentObservedFunction{F} <: AbstractStateGetIndexer obsfn::F @@ -137,7 +163,7 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) ts_idxs = collect(ts_idxs) end fn = observed(sys, sym) - return TimeDependentObservedFunction(ts_idxs, fn) + return TimeDependentObservedFunction{is_markovian(sys)}(ts_idxs, fn) else return getp(sys, sym) end @@ -256,7 +282,7 @@ for (t1, t2) in [ else obs = observed(sys, sym_arr) getter = if is_time_dependent(sys) - TimeDependentObservedFunction(ts_idxs, obs) + TimeDependentObservedFunction{is_markovian(sys)}(ts_idxs, obs) else TimeIndependentObservedFunction(obs) end diff --git a/src/value_provider_interface.jl b/src/value_provider_interface.jl index 7903017..1ae30b5 100644 --- a/src/value_provider_interface.jl +++ b/src/value_provider_interface.jl @@ -142,6 +142,16 @@ current_time(arr::AbstractVector) = arr current_time(valp, i) = current_time(valp)[i] current_time(valp, ::Colon) = current_time(valp) +""" + get_history_function(valp) + +Return the history function for a value provider. This is required for all value providers +associated with an index provider `indp` for which `!is_markovian(indp)`. + +See also: [`is_markovian`](@ref). +""" +function get_history_function end + ########### # Utilities ########### diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index c2b90b6..8e59779 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -291,3 +291,34 @@ for (sym, val, check_inference) in [ end @test getter(fs) == val end + +struct NonMarkovianWrapper{S <: SymbolCache} + sys::S +end + +SymbolicIndexingInterface.symbolic_container(hw::NonMarkovianWrapper) = hw.sys +SymbolicIndexingInterface.is_markovian(::NonMarkovianWrapper) = false +function SymbolicIndexingInterface.observed(hw::NonMarkovianWrapper, sym) + let inner = observed(hw.sys, sym) + fn(u, h, p, t) = inner(u .+ h(t - 0.1), p, t) + end +end +function SymbolicIndexingInterface.get_history_function(fs::FakeSolution) + t -> t .* ones(length(fs.u[1])) +end +function SymbolicIndexingInterface.get_history_function(fi::FakeIntegrator) + t -> t .* ones(length(fi.u)) +end + +sys = NonMarkovianWrapper(SymbolCache([:x, :y, :z], [:a, :b, :c], :t)) +u0 = [1.0, 2.0, 3.0] +u = [u0 .* i for i in 1:11] +p = [10.0, 20.0, 30.0] +ts = 0.0:0.1:1.0 + +fi = FakeIntegrator(sys, u0, p, ts[1]) +fs = FakeSolution(sys, u, p, ts) +getter = getu(sys, :(x + y)) +@test getter(fi) ≈ 2.8 +@test getter(fs) ≈ [3.0i + 2(ts[i] - 0.1) for i in 1:11] +@test getter(fs, 1) ≈ 2.8