Skip to content

Commit b8c9101

Browse files
refactor: improve getu performance for vectors involving observed quantities
1 parent dab25be commit b8c9101

File tree

1 file changed

+58
-18
lines changed

1 file changed

+58
-18
lines changed

src/state_indexing.jl

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -189,27 +189,67 @@ for (t1, t2) in [
189189
(NotSymbolic, Union{<:Tuple, <:AbstractArray}),
190190
]
191191
@eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2)
192-
getters = getu.((sys,), sym)
193-
_call(getter, args...) = getter(args...)
194-
return let getters = getters, _call = _call
195-
_getter(::NotTimeseries, prob) = map(g -> g(prob), getters)
196-
function _getter(::Timeseries, prob)
197-
broadcast(i -> map(g -> _call(g, prob, i), getters),
198-
eachindex(state_values(prob)))
199-
end
200-
function _getter(::Timeseries, prob, i)
201-
return map(g -> _call(g, prob, i), getters)
202-
end
192+
num_observed = count(x -> is_observed(sys, x), sym)
193+
if num_observed <= 1
194+
getters = getu.((sys,), sym)
195+
_call(getter, args...) = getter(args...)
196+
return let getters = getters, _call = _call
197+
_getter(::NotTimeseries, prob) = map(g -> g(prob), getters)
198+
function _getter(::Timeseries, prob)
199+
broadcast(i -> map(g -> _call(g, prob, i), getters),
200+
eachindex(state_values(prob)))
201+
end
202+
function _getter(::Timeseries, prob, i)
203+
return map(g -> _call(g, prob, i), getters)
204+
end
203205

204-
# Need another scope for this to not box `_getter`
205-
let _getter = _getter
206-
function getter(prob)
207-
return _getter(is_timeseries(prob), prob)
206+
# Need another scope for this to not box `_getter`
207+
let _getter = _getter
208+
function getter(prob)
209+
return _getter(is_timeseries(prob), prob)
210+
end
211+
function getter(prob, i)
212+
return _getter(is_timeseries(prob), prob, i)
213+
end
214+
getter
215+
end
216+
end
217+
else
218+
obs = observed(sys, sym isa Tuple ? collect(sym) : sym)
219+
return let obs = obs, is_tuple = sym isa Tuple
220+
function _getter2(::NotTimeseries, prob)
221+
obs(state_values(prob), parameter_values(prob), current_time(prob))
222+
end
223+
function _getter2(::Timeseries, prob)
224+
obs.(state_values(prob), (parameter_values(prob),), current_time(prob))
208225
end
209-
function getter(prob, i)
210-
return _getter(is_timeseries(prob), prob, i)
226+
function _getter2(::Timeseries, prob, i)
227+
obs(state_values(prob, i),
228+
parameter_values(prob),
229+
current_time(prob, i))
230+
end
231+
232+
if is_tuple
233+
let _getter2 = _getter2
234+
function getter2(prob)
235+
Tuple(_getter2(is_timeseries(prob), prob))
236+
end
237+
function getter2(prob, i)
238+
Tuple(_getter2(is_timeseries(prob), prob, i))
239+
end
240+
getter2
241+
end
242+
else
243+
let _getter2 = _getter2
244+
function getter3(prob)
245+
_getter2(is_timeseries(prob), prob)
246+
end
247+
function getter3(prob, i)
248+
_getter2(is_timeseries(prob), prob, i)
249+
end
250+
getter3
251+
end
211252
end
212-
getter
213253
end
214254
end
215255
end

0 commit comments

Comments
 (0)