Skip to content

Commit 124a6b8

Browse files
Merge pull request #107 from SciML/as/tuple-observed
feat: add support for directly generating tuple observed functions
2 parents 585be57 + 5cea80e commit 124a6b8

File tree

4 files changed

+59
-5
lines changed

4 files changed

+59
-5
lines changed

src/index_provider_interface.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,24 @@ See also: [`is_time_dependent`](@ref), [`is_markovian`](@ref), [`constant_struct
212212
observed(indp, sym) = observed(symbolic_container(indp), sym)
213213
observed(indp, sym, states) = observed(symbolic_container(indp), sym, states)
214214

215+
"""
216+
supports_tuple_observed(indp)
217+
218+
Check if the given index provider supports generating observed functions for tuples of
219+
symbolic variables. Falls back using `symbolic_container`, and returns `false` by
220+
default.
221+
222+
See also: [`observed`](@ref), [`parameter_observed`](@ref), [`symbolic_container`](@ref).
223+
"""
224+
function supports_tuple_observed(indp)
225+
if hasmethod(symbolic_container, Tuple{typeof(indp)}) &&
226+
(sc = symbolic_container(indp)) !== indp
227+
supports_tuple_observed(sc)
228+
else
229+
false
230+
end
231+
end
232+
215233
"""
216234
is_time_dependent(indp)
217235

src/parameter_indexing.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -599,12 +599,14 @@ for (t1, t2) in [
599599
# `getp` errors on older MTK that doesn't support `parameter_observed`.
600600
getters = getp.((sys,), p)
601601
num_observed = count(is_observed_getter, getters)
602+
supports_tuple = supports_tuple_observed(sys)
602603
p_arr = p isa Tuple ? collect(p) : p
603604

604605
if num_observed == 0
605606
return MultipleParametersGetter(getters)
606607
else
607-
pofn = parameter_observed(sys, p_arr)
608+
pofn = supports_tuple ? parameter_observed(sys, p) :
609+
parameter_observed(sys, p_arr)
608610
if pofn === nothing
609611
return MultipleParametersGetter.(getters)
610612
end
@@ -615,7 +617,8 @@ for (t1, t2) in [
615617
else
616618
getter = GetParameterObservedNoTime(pofn)
617619
end
618-
return p isa Tuple ? AsParameterTupleWrapper{length(p)}(getter) : getter
620+
return p isa Tuple && !supports_tuple ?
621+
AsParameterTupleWrapper{length(p)}(getter) : getter
619622
end
620623
end
621624
end

src/state_indexing.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ for (t1, t2) in [
252252
return MultipleGetters(ContinuousTimeseries(), sym)
253253
end
254254
sym_arr = sym isa Tuple ? collect(sym) : sym
255+
supports_tuple = supports_tuple_observed(sys)
255256
num_observed = 0
256257
for s in sym
257258
num_observed += is_observed(sys, s)
@@ -261,7 +262,7 @@ for (t1, t2) in [
261262
if num_observed == 0 || num_observed == 1 && sym isa Tuple
262263
return MultipleGetters(nothing, getsym.((sys,), sym))
263264
else
264-
obs = observed(sys, sym_arr)
265+
obs = supports_tuple ? observed(sys, sym) : observed(sys, sym_arr)
265266
getter = TimeIndependentObservedFunction(obs)
266267
if sym isa Tuple
267268
getter = AsTupleWrapper{length(sym)}(getter)
@@ -283,13 +284,13 @@ for (t1, t2) in [
283284
getters = getsym.((sys,), sym)
284285
return MultipleGetters(ts_idxs, getters)
285286
else
286-
obs = observed(sys, sym_arr)
287+
obs = supports_tuple ? observed(sys, sym) : observed(sys, sym_arr)
287288
getter = if is_time_dependent(sys)
288289
TimeDependentObservedFunction{is_markovian(sys)}(ts_idxs, obs)
289290
else
290291
TimeIndependentObservedFunction(obs)
291292
end
292-
if sym isa Tuple
293+
if sym isa Tuple && !supports_tuple
293294
getter = AsTupleWrapper{length(sym)}(getter)
294295
end
295296
return getter

test/state_indexing_test.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,3 +346,35 @@ getter = getsym(sys, :(x + y))
346346
@test getter(fi) 2.8
347347
@test getter(fs) [3.0i + 2(ts[i] - 0.1) for i in 1:11]
348348
@test getter(fs, 1) 2.8
349+
350+
struct TupleObservedWrapper{S}
351+
sys::S
352+
end
353+
SymbolicIndexingInterface.symbolic_container(t::TupleObservedWrapper) = t.sys
354+
SymbolicIndexingInterface.supports_tuple_observed(::TupleObservedWrapper) = true
355+
356+
@testset "Tuple observed" begin
357+
sc = SymbolCache([:x, :y, :z], [:a, :b, :c])
358+
sys = TupleObservedWrapper(sc)
359+
ps = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3])
360+
getter = getsym(sys, (:(x + y), :(y + z)))
361+
@test all(getter(ps) .≈ (3.0, 5.0))
362+
@test getter(ps) isa Tuple
363+
@test_nowarn @inferred getter(ps)
364+
getter = getsym(sys, (:(a + b), :(b + c)))
365+
@test all(getter(ps) .≈ (0.3, 0.5))
366+
@test getter(ps) isa Tuple
367+
@test_nowarn @inferred getter(ps)
368+
369+
sc = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
370+
sys = TupleObservedWrapper(sc)
371+
ps = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.1)
372+
getter = getsym(sys, (:(x + y), :(y + t)))
373+
@test all(getter(ps) .≈ (3.0, 2.1))
374+
@test getter(ps) isa Tuple
375+
@test_nowarn @inferred getter(ps)
376+
getter = getsym(sys, (:(a + b), :(b + c)))
377+
@test all(getter(ps) .≈ (0.3, 0.5))
378+
@test getter(ps) isa Tuple
379+
@test_nowarn @inferred getter(ps)
380+
end

0 commit comments

Comments
 (0)