Skip to content

Commit 3f6cd4b

Browse files
Merge pull request #121 from SciML/as/getu-nt
feat: support `NamedTuple` getters and setters
2 parents 6089047 + cc67b11 commit 3f6cd4b

File tree

4 files changed

+107
-32
lines changed

4 files changed

+107
-32
lines changed

src/parameter_indexing.jl

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -531,27 +531,39 @@ function (mpg::OnlyTimeseriesMPG)(
531531
throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, mpg))
532532
end
533533

534-
struct AsParameterTupleWrapper{N, G <: AbstractParameterGetIndexer} <:
534+
struct AsParameterTupleWrapper{N, A, G <: AbstractParameterGetIndexer} <:
535535
AbstractParameterGetIndexer
536536
getter::G
537537
end
538538

539-
AsParameterTupleWrapper{N}(getter::G) where {N, G} = AsParameterTupleWrapper{N, G}(getter)
539+
function AsParameterTupleWrapper{N}(getter::G) where {N, G}
540+
AsParameterTupleWrapper{N, Nothing, G}(getter)
541+
end
542+
function AsParameterTupleWrapper{N, A}(getter::G) where {N, A, G}
543+
AsParameterTupleWrapper{N, A, G}(getter)
544+
end
540545

541-
function is_indexer_timeseries(::Type{AsParameterTupleWrapper{N, G}}) where {N, G}
546+
function is_indexer_timeseries(::Type{AsParameterTupleWrapper{N, A, G}}) where {N, A, G}
542547
is_indexer_timeseries(G)
543548
end
544549
function indexer_timeseries_index(atw::AsParameterTupleWrapper)
545550
indexer_timeseries_index(atw.getter)
546551
end
547-
function as_timeseries_indexer(::IndexerBoth, atw::AsParameterTupleWrapper{N}) where {N}
548-
AsParameterTupleWrapper{N}(as_timeseries_indexer(atw.getter))
552+
function as_timeseries_indexer(
553+
::IndexerBoth, atw::AsParameterTupleWrapper{N, A}) where {N, A}
554+
AsParameterTupleWrapper{N, A}(as_timeseries_indexer(atw.getter))
549555
end
550-
function as_not_timeseries_indexer(::IndexerBoth, atw::AsParameterTupleWrapper{N}) where {N}
551-
AsParameterTupleWrapper{N}(as_not_timeseries_indexer(atw.getter))
556+
function as_not_timeseries_indexer(
557+
::IndexerBoth, atw::AsParameterTupleWrapper{N, A}) where {N, A}
558+
AsParameterTupleWrapper{N, A}(as_not_timeseries_indexer(atw.getter))
552559
end
553560

554-
wrap_tuple(::AsParameterTupleWrapper{N}, val) where {N} = ntuple(i -> val[i], Val(N))
561+
function wrap_tuple(::AsParameterTupleWrapper{N, Nothing}, val) where {N}
562+
ntuple(i -> val[i], Val(N))
563+
end
564+
function wrap_tuple(::AsParameterTupleWrapper{N, A}, val) where {N, A}
565+
NamedTuple{A}(ntuple(i -> val[i], Val(N)))
566+
end
555567

556568
function (atw::AsParameterTupleWrapper)(ts::IsTimeseriesTrait, prob, args...)
557569
atw(ts, is_indexer_timeseries(atw), prob, args...)
@@ -591,19 +603,24 @@ is_observed_getter(mpg::MultipleParametersGetter) = any(is_observed_getter, mpg.
591603
for (t1, t2) in [
592604
(ArraySymbolic, Any),
593605
(ScalarSymbolic, Any),
594-
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
606+
(NotSymbolic, Union{<:Tuple, <:NamedTuple, <:AbstractArray})
595607
]
596608
@eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2)
597609
# We need to do it this way because if an `ODESystem` has `p[1], p[2], p[3]` as
598610
# parameters (all scalarized) then `is_observed(sys, p[2:3]) == true`. Then,
599611
# `getp` errors on older MTK that doesn't support `parameter_observed`.
600-
getters = getp.((sys,), p)
612+
_p = p isa NamedTuple ? Tuple(p) : p
613+
getters = getp.((sys,), _p)
601614
num_observed = count(is_observed_getter, getters)
602615
supports_tuple = supports_tuple_observed(sys)
603-
p_arr = p isa Tuple ? collect(p) : p
616+
p_arr = p isa Union{Tuple, NamedTuple} ? collect(p) : p
604617

605618
if num_observed == 0
606-
return MultipleParametersGetter(getters)
619+
getter = MultipleParametersGetter(getters)
620+
if p isa NamedTuple
621+
getter = AsParameterTupleWrapper{length(p), keys(p)}(getter)
622+
end
623+
return getter
607624
else
608625
pofn = supports_tuple ? parameter_observed(sys, p) :
609626
parameter_observed(sys, p_arr)
@@ -617,8 +634,12 @@ for (t1, t2) in [
617634
else
618635
getter = GetParameterObservedNoTime(pofn)
619636
end
620-
return p isa Tuple && !supports_tuple ?
621-
AsParameterTupleWrapper{length(p)}(getter) : getter
637+
if p isa Tuple && !supports_tuple
638+
getter = AsParameterTupleWrapper{length(p)}(getter)
639+
elseif p isa NamedTuple
640+
getter = AsParameterTupleWrapper{length(p), keys(p)}(getter)
641+
end
642+
return getter
622643
end
623644
end
624645
end
@@ -698,9 +719,13 @@ end
698719
for (t1, t2) in [
699720
(ArraySymbolic, Any),
700721
(ScalarSymbolic, Any),
701-
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
722+
(NotSymbolic, Union{<:Tuple, <:NamedTuple, <:AbstractArray})
702723
]
703724
@eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2)
725+
if p isa NamedTuple
726+
setters = NamedTuple{keys(p)}(setp.((sys,), values(p); run_hook = false))
727+
return NamedTupleSetter(setters)
728+
end
704729
setters = setp.((sys,), p; run_hook = false)
705730
return MultipleSetters(setters)
706731
end

src/state_indexing.jl

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -221,13 +221,17 @@ function (mg::MultipleGetters)(::NotTimeseries, ::IndexerMixedTimeseries, prob,
221221
return map(g -> g(prob), mg.getters)
222222
end
223223

224-
struct AsTupleWrapper{N, G} <: AbstractStateGetIndexer
224+
struct AsTupleWrapper{N, A, G} <: AbstractStateGetIndexer
225225
getter::G
226226
end
227227

228-
AsTupleWrapper{N}(getter::G) where {N, G} = AsTupleWrapper{N, G}(getter)
228+
AsTupleWrapper{N}(getter::G) where {N, G} = AsTupleWrapper{N, Nothing, G}(getter)
229+
AsTupleWrapper{N, A}(getter::G) where {N, A, G} = AsTupleWrapper{N, A, G}(getter)
229230

230-
wrap_tuple(::AsTupleWrapper{N}, val) where {N} = ntuple(i -> val[i], Val(N))
231+
wrap_tuple(::AsTupleWrapper{N, Nothing}, val) where {N} = ntuple(i -> val[i], Val(N))
232+
function wrap_tuple(::AsTupleWrapper{N, A}, val) where {N, A}
233+
NamedTuple{A}(ntuple(i -> val[i], Val(N)))
234+
end
231235

232236
function (atw::AsTupleWrapper)(::Timeseries, prob)
233237
return wrap_tuple.((atw,), atw.getter(prob))
@@ -245,13 +249,13 @@ end
245249
for (t1, t2) in [
246250
(ScalarSymbolic, Any),
247251
(ArraySymbolic, Any),
248-
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
252+
(NotSymbolic, Union{<:Tuple, <:NamedTuple, <:AbstractArray})
249253
]
250254
@eval function _getsym(sys, ::NotSymbolic, elt::$t1, sym::$t2)
251255
if isempty(sym)
252256
return MultipleGetters(ContinuousTimeseries(), sym)
253257
end
254-
sym_arr = sym isa Tuple ? collect(sym) : sym
258+
sym_arr = sym isa Union{Tuple, NamedTuple} ? collect(sym) : sym
255259
supports_tuple = supports_tuple_observed(sys)
256260
num_observed = 0
257261
for s in sym
@@ -266,6 +270,8 @@ for (t1, t2) in [
266270
getter = TimeIndependentObservedFunction(obs)
267271
if sym isa Tuple
268272
getter = AsTupleWrapper{length(sym)}(getter)
273+
elseif sym isa NamedTuple
274+
getter = AsTupleWrapper{length(sym), keys(sym)}(getter)
269275
end
270276
return getter
271277
end
@@ -280,9 +286,14 @@ for (t1, t2) in [
280286
ts_idxs = collect(ts_idxs)
281287
end
282288

283-
if num_observed == 0 || num_observed == 1 && sym isa Tuple
284-
getters = getsym.((sys,), sym)
285-
return MultipleGetters(ts_idxs, getters)
289+
if num_observed == 0 || num_observed == 1 && sym isa Union{Tuple, NamedTuple}
290+
_sym = sym isa NamedTuple ? Tuple(sym) : sym
291+
getters = getsym.((sys,), _sym)
292+
getter = MultipleGetters(ts_idxs, getters)
293+
if sym isa NamedTuple
294+
getter = AsTupleWrapper{length(sym), keys(sym)}(getter)
295+
end
296+
return getter
286297
else
287298
obs = supports_tuple ? observed(sys, sym) : observed(sys, sym_arr)
288299
getter = if is_time_dependent(sys)
@@ -292,6 +303,8 @@ for (t1, t2) in [
292303
end
293304
if sym isa Tuple && !supports_tuple
294305
getter = AsTupleWrapper{length(sym)}(getter)
306+
elseif sym isa NamedTuple
307+
getter = AsTupleWrapper{length(sym), keys(sym)}(getter)
295308
end
296309
return getter
297310
end
@@ -351,12 +364,39 @@ function _setsym(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
351364
error("Invalid symbol $sym for `setsym`")
352365
end
353366

367+
struct NamedTupleSetter{S <: NamedTuple} <: AbstractSetIndexer
368+
setter::S
369+
end
370+
371+
function (nts::NamedTupleSetter)(prob, val)
372+
_generated_setter(nts, prob, val)
373+
end
374+
375+
@generated function _generated_setter(
376+
nts::NamedTupleSetter{<:NamedTuple{N1}}, prob, val::NamedTuple{N2}) where {N1, N2}
377+
expr = Expr(:block)
378+
for (i, name) in enumerate(N2)
379+
idx = findfirst(isequal(name), N1)
380+
if idx === nothing
381+
throw(ArgumentError("""
382+
Invalid name $(name) in value. Must be one of $(N1).
383+
"""))
384+
end
385+
push!(expr.args, :(nts.setter[$idx](prob, val[$i])))
386+
end
387+
return expr
388+
end
389+
354390
for (t1, t2) in [
355391
(ScalarSymbolic, Any),
356392
(ArraySymbolic, Any),
357-
(NotSymbolic, Union{<:Tuple, <:AbstractArray})
393+
(NotSymbolic, Union{<:Tuple, <:NamedTuple, <:AbstractArray})
358394
]
359395
@eval function _setsym(sys, ::NotSymbolic, ::$t1, sym::$t2)
396+
if sym isa NamedTuple
397+
setters = NamedTuple{keys(sym)}(setsym.((sys,), values(sym)))
398+
return NamedTupleSetter(setters)
399+
end
360400
setters = setsym.((sys,), sym)
361401
return MultipleSetters(setters)
362402
end

test/parameter_indexing_test.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,10 @@ for sys in [
5555
([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
5656
([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
5757
((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
58-
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)]
58+
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
59+
((a = :a, b = [:a, :b], c = (d = :c, e = :a)),
60+
(a = p[1], b = p[1:2], c = (d = p[3], e = p[1])),
61+
(a = new_p[1], b = new_p[1:2], c = (d = new_p[3], e = new_p[1])), true)]
5962
get = getp(sys, sym)
6063
set! = setp(sys, sym)
6164
if check_inference

test/state_indexing_test.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,12 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true)
5050
((1, (:y, :z)), (u[1], (u[2], u[3])),
5151
(4.0, (5.0, 6.0)), true)
5252
((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)),
53-
(4.0, [5.0], (6.0,)), true)]
53+
(4.0, [5.0], (6.0,)), true)
54+
((a = :x, b = [:x, :y], c = (d = :z, e = :x)),
55+
(a = u[1], b = u[1:2],
56+
c = (d = u[3], e = u[1])),
57+
(a = 4.0, b = [4.0, 5.0],
58+
c = (d = 6.0, e = 4.0)), true)]
5459
get = getsym(sys, sym)
5560
set! = setsym(sys, sym)
5661
if check_inference
@@ -86,12 +91,14 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true)
8691
continue
8792
end
8893

89-
setter = setsym_oop(sys, sym)
90-
svals, pvals = setter(fi, newval)
91-
@test svals new_states
92-
@test pvals == parameter_values(fi)
93-
@test_throws ArgumentError setter(state_values(fi), newval)
94-
@test_throws ArgumentError setter(parameter_values(fi), newval)
94+
if !(sym isa NamedTuple)
95+
setter = setsym_oop(sys, sym)
96+
svals, pvals = setter(fi, newval)
97+
@test svals new_states
98+
@test pvals == parameter_values(fi)
99+
@test_throws ArgumentError setter(state_values(fi), newval)
100+
@test_throws ArgumentError setter(parameter_values(fi), newval)
101+
end
95102
end
96103

97104
for (sym, val, check_inference) in [

0 commit comments

Comments
 (0)