Skip to content

Commit 94b5a85

Browse files
feat: support indexing Timeseries at specific time in getu
1 parent 15c5f0e commit 94b5a85

File tree

2 files changed

+45
-29
lines changed

2 files changed

+45
-29
lines changed

src/state_indexing.jl

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,22 @@ is_timeseries(::Type) = NotTimeseries()
3939

4040
"""
4141
state_values(p)
42+
state_values(p, i)
4243
4344
Return an indexable collection containing the values of all states in the integrator or
4445
problem `p`. If `is_timeseries(p)` is [`Timeseries`](@ref), return a vector of arrays,
45-
each of which contain the state values at the corresponding timestep.
46+
each of which contain the state values at the corresponding timestep. In this case, the
47+
two-argument version of the function can also be implemented to efficiently return
48+
the state values at timestep `i`. By default, the two-argument method calls
49+
`state_values(p)[i]`
4650
4751
If this function is called with an `AbstractArray`, it will return the same array.
4852
4953
See: [`is_timeseries`](@ref)
5054
"""
5155
function state_values end
5256
state_values(arr::AbstractArray) = arr
57+
state_values(arr, i) = state_values(arr)[i]
5358

5459
"""
5560
set_state!(sys, val, idx)
@@ -67,16 +72,21 @@ end
6772

6873
"""
6974
current_time(p)
75+
current_time(p, i)
7076
7177
Return the current time in the integrator or problem `p`. If
7278
`is_timeseries(p)` is [`Timeseries`](@ref), return the vector of timesteps at which
73-
the state value is saved.
79+
the state value is saved. In this case, the two-argument version of the function can
80+
also be implemented to efficiently return the time at timestep `i`. By default, the two-
81+
argument method calls `current_time(p)[i]`
7482
7583
7684
See: [`is_timeseries`](@ref)
7785
"""
7886
function current_time end
7987

88+
current_time(p, i) = current_time(p)[i]
89+
8090
"""
8191
getu(sys, sym)
8292
@@ -85,7 +95,8 @@ the value of the symbolic `sym`. If `sym` is not an observed quantity, the retur
8595
function can also directly be called with an array of values representing the state
8696
vector. `sym` can be a direct index into the state vector, a symbolic state, a symbolic
8797
expression involving symbolic quantities in the system `sys`, or an array/tuple of the
88-
aforementioned.
98+
aforementioned. If the returned function is called with a timeseries object, it can also
99+
be given a second argument representing the index at which to find the value of `sym`.
89100
90101
At minimum, this requires that the integrator, problem or solution implement
91102
[`state_values`](@ref). To support symbolic expressions, the integrator or problem
@@ -103,11 +114,16 @@ end
103114

104115
function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym)
105116
_getter(::Timeseries, prob) = getindex.(state_values(prob), (sym,))
117+
_getter(::Timeseries, prob, i) = getindex(state_values(prob, i), sym)
106118
_getter(::NotTimeseries, prob) = state_values(prob)[sym]
107119
return let _getter = _getter
108120
function getter(prob)
109121
return _getter(is_timeseries(prob), prob)
110122
end
123+
function getter(prob, i)
124+
return _getter(is_timeseries(prob), prob, i)
125+
end
126+
getter
111127
end
112128
end
113129

@@ -123,6 +139,9 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
123139
(parameter_values(prob),),
124140
current_time(prob))
125141
end
142+
function _getter2(::Timeseries, prob, i)
143+
return fn(state_values(prob, i), parameter_values(prob), current_time(prob, i))
144+
end
126145
function _getter2(::NotTimeseries, prob)
127146
return fn(state_values(prob), parameter_values(prob), current_time(prob))
128147
end
@@ -131,55 +150,46 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
131150
function getter2(prob)
132151
return _getter2(is_timeseries(prob), prob)
133152
end
153+
function getter2(prob, i)
154+
return _getter2(is_timeseries(prob), prob, i)
155+
end
156+
getter2
134157
end
135158
else
136-
function _getter3(::Timeseries, prob)
137-
return fn.(state_values(prob), (parameter_values(prob),))
138-
end
139-
function _getter3(::NotTimeseries, prob)
140-
return fn(state_values(prob), parameter_values(prob))
141-
end
142-
143-
return let _getter3 = _getter3
159+
# if there is no time, there is no timeseries
160+
return let fn = fn
144161
function getter3(prob)
145-
return _getter3(is_timeseries(prob), prob)
162+
return fn(state_values(prob), parameter_values(prob))
146163
end
147164
end
148165
end
149166
end
150167
error("Invalid symbol $sym for `getu`")
151168
end
152169

153-
struct TimeseriesIndexWrapper{T, I}
154-
timeseries::T
155-
idx::I
156-
end
157-
158-
state_values(t::TimeseriesIndexWrapper) = state_values(t.timeseries)[t.idx]
159-
parameter_values(t::TimeseriesIndexWrapper) = parameter_values(t.timeseries)
160-
current_time(t::TimeseriesIndexWrapper) = current_time(t.timeseries)[t.idx]
161-
162170
for (t1, t2) in [(ScalarSymbolic, Any), (ArraySymbolic, Any), (NotSymbolic, Union{<:Tuple, <:AbstractArray})]
163171
@eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2)
164172
getters = getu.((sys,), sym)
165-
_call(getter, prob) = getter(prob)
166-
173+
_call(getter, args...) = getter(args...)
167174
return let getters = getters, _call = _call
168175
_getter(::NotTimeseries, prob) = map(g -> g(prob), getters)
169176
function _getter(::Timeseries, prob)
170-
tiws = TimeseriesIndexWrapper.((prob,), eachindex(state_values(prob)))
171-
# Ideally this should recursively call `_getter` but that leads to type-instability
172-
# since the reference to itself is boxed
173-
# Turning this broadcasted `_call` into a map also makes this type-unstable
174-
175-
return map(tiw -> _call.(getters, (tiw,)), tiws)
177+
broadcast(i -> map(g -> _call(g, prob, i), getters),
178+
eachindex(state_values(prob)))
179+
end
180+
function _getter(::Timeseries, prob, i)
181+
return map(g -> _call(g, prob, i), getters)
176182
end
177183

178184
# Need another scope for this to not box `_getter`
179185
let _getter = _getter
180186
function getter(prob)
181187
return _getter(is_timeseries(prob), prob)
182188
end
189+
function getter(prob, i)
190+
return _getter(is_timeseries(prob), prob, i)
191+
end
192+
getter
183193
end
184194
end
185195
end

test/state_indexing_test.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,10 @@ for (sym, ans, check_inference) in [
103103
@inferred get(sol)
104104
end
105105
@test get(sol) == ans
106+
for i in eachindex(u)
107+
if check_inference
108+
@inferred get(sol, i)
109+
end
110+
@test get(sol, i) == ans[i]
111+
end
106112
end

0 commit comments

Comments
 (0)