@@ -39,17 +39,22 @@ is_timeseries(::Type) = NotTimeseries()
3939
4040"""
4141 state_values(p)
42+ state_values(p, i)
4243
4344Return an indexable collection containing the values of all states in the integrator or
4445problem `p`. If `is_timeseries(p)` is [`Timeseries`](@ref), return a vector of arrays,
45- each of which contain the state values at the corresponding timestep.
46+ each of which contain the state values at the corresponding timestep. In this case, the
47+ two-argument version of the function can also be implemented to efficiently return
48+ the state values at timestep `i`. By default, the two-argument method calls
49+ `state_values(p)[i]`
4650
4751If this function is called with an `AbstractArray`, it will return the same array.
4852
4953See: [`is_timeseries`](@ref)
5054"""
5155function state_values end
5256state_values (arr:: AbstractArray ) = arr
57+ state_values (arr, i) = state_values (arr)[i]
5358
5459"""
5560 set_state!(sys, val, idx)
6772
6873"""
6974 current_time(p)
75+ current_time(p, i)
7076
7177Return the current time in the integrator or problem `p`. If
7278`is_timeseries(p)` is [`Timeseries`](@ref), return the vector of timesteps at which
73- the state value is saved.
79+ the state value is saved. In this case, the two-argument version of the function can
80+ also be implemented to efficiently return the time at timestep `i`. By default, the two-
81+ argument method calls `current_time(p)[i]`
7482
7583
7684See: [`is_timeseries`](@ref)
7785"""
7886function current_time end
7987
88+ current_time (p, i) = current_time (p)[i]
89+
8090"""
8191 getu(sys, sym)
8292
@@ -85,7 +95,8 @@ the value of the symbolic `sym`. If `sym` is not an observed quantity, the retur
8595function can also directly be called with an array of values representing the state
8696vector. `sym` can be a direct index into the state vector, a symbolic state, a symbolic
8797expression involving symbolic quantities in the system `sys`, or an array/tuple of the
88- aforementioned.
98+ aforementioned. If the returned function is called with a timeseries object, it can also
99+ be given a second argument representing the index at which to find the value of `sym`.
89100
90101At minimum, this requires that the integrator, problem or solution implement
91102[`state_values`](@ref). To support symbolic expressions, the integrator or problem
@@ -103,11 +114,16 @@ end
103114
104115function _getu (sys, :: NotSymbolic , :: NotSymbolic , sym)
105116 _getter (:: Timeseries , prob) = getindex .(state_values (prob), (sym,))
117+ _getter (:: Timeseries , prob, i) = getindex (state_values (prob, i), sym)
106118 _getter (:: NotTimeseries , prob) = state_values (prob)[sym]
107119 return let _getter = _getter
108120 function getter (prob)
109121 return _getter (is_timeseries (prob), prob)
110122 end
123+ function getter (prob, i)
124+ return _getter (is_timeseries (prob), prob, i)
125+ end
126+ getter
111127 end
112128end
113129
@@ -123,6 +139,9 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
123139 (parameter_values (prob),),
124140 current_time (prob))
125141 end
142+ function _getter2 (:: Timeseries , prob, i)
143+ return fn (state_values (prob, i), parameter_values (prob), current_time (prob, i))
144+ end
126145 function _getter2 (:: NotTimeseries , prob)
127146 return fn (state_values (prob), parameter_values (prob), current_time (prob))
128147 end
@@ -131,55 +150,46 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
131150 function getter2 (prob)
132151 return _getter2 (is_timeseries (prob), prob)
133152 end
153+ function getter2 (prob, i)
154+ return _getter2 (is_timeseries (prob), prob, i)
155+ end
156+ getter2
134157 end
135158 else
136- function _getter3 (:: Timeseries , prob)
137- return fn .(state_values (prob), (parameter_values (prob),))
138- end
139- function _getter3 (:: NotTimeseries , prob)
140- return fn (state_values (prob), parameter_values (prob))
141- end
142-
143- return let _getter3 = _getter3
159+ # if there is no time, there is no timeseries
160+ return let fn = fn
144161 function getter3 (prob)
145- return _getter3 ( is_timeseries (prob), prob)
162+ return fn ( state_values (prob), parameter_values ( prob) )
146163 end
147164 end
148165 end
149166 end
150167 error (" Invalid symbol $sym for `getu`" )
151168end
152169
153- struct TimeseriesIndexWrapper{T, I}
154- timeseries:: T
155- idx:: I
156- end
157-
158- state_values (t:: TimeseriesIndexWrapper ) = state_values (t. timeseries)[t. idx]
159- parameter_values (t:: TimeseriesIndexWrapper ) = parameter_values (t. timeseries)
160- current_time (t:: TimeseriesIndexWrapper ) = current_time (t. timeseries)[t. idx]
161-
162170for (t1, t2) in [(ScalarSymbolic, Any), (ArraySymbolic, Any), (NotSymbolic, Union{<: Tuple , <: AbstractArray })]
163171 @eval function _getu (sys, :: NotSymbolic , :: $t1 , sym:: $t2 )
164172 getters = getu .((sys,), sym)
165- _call (getter, prob) = getter (prob)
166-
173+ _call (getter, args... ) = getter (args... )
167174 return let getters = getters, _call = _call
168175 _getter (:: NotTimeseries , prob) = map (g -> g (prob), getters)
169176 function _getter (:: Timeseries , prob)
170- tiws = TimeseriesIndexWrapper .((prob,), eachindex (state_values (prob)))
171- # Ideally this should recursively call `_getter` but that leads to type-instability
172- # since the reference to itself is boxed
173- # Turning this broadcasted `_call` into a map also makes this type-unstable
174-
175- return map (tiw -> _call .(getters, (tiw,)), tiws)
177+ broadcast (i -> map (g -> _call (g, prob, i), getters),
178+ eachindex (state_values (prob)))
179+ end
180+ function _getter (:: Timeseries , prob, i)
181+ return map (g -> _call (g, prob, i), getters)
176182 end
177183
178184 # Need another scope for this to not box `_getter`
179185 let _getter = _getter
180186 function getter (prob)
181187 return _getter (is_timeseries (prob), prob)
182188 end
189+ function getter (prob, i)
190+ return _getter (is_timeseries (prob), prob, i)
191+ end
192+ getter
183193 end
184194 end
185195 end
0 commit comments