diff --git a/src/parameter_indexing.jl b/src/parameter_indexing.jl index de21a38..b441070 100644 --- a/src/parameter_indexing.jl +++ b/src/parameter_indexing.jl @@ -531,27 +531,39 @@ function (mpg::OnlyTimeseriesMPG)( throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, mpg)) end -struct AsParameterTupleWrapper{N, G <: AbstractParameterGetIndexer} <: +struct AsParameterTupleWrapper{N, A, G <: AbstractParameterGetIndexer} <: AbstractParameterGetIndexer getter::G end -AsParameterTupleWrapper{N}(getter::G) where {N, G} = AsParameterTupleWrapper{N, G}(getter) +function AsParameterTupleWrapper{N}(getter::G) where {N, G} + AsParameterTupleWrapper{N, Nothing, G}(getter) +end +function AsParameterTupleWrapper{N, A}(getter::G) where {N, A, G} + AsParameterTupleWrapper{N, A, G}(getter) +end -function is_indexer_timeseries(::Type{AsParameterTupleWrapper{N, G}}) where {N, G} +function is_indexer_timeseries(::Type{AsParameterTupleWrapper{N, A, G}}) where {N, A, G} is_indexer_timeseries(G) end function indexer_timeseries_index(atw::AsParameterTupleWrapper) indexer_timeseries_index(atw.getter) end -function as_timeseries_indexer(::IndexerBoth, atw::AsParameterTupleWrapper{N}) where {N} - AsParameterTupleWrapper{N}(as_timeseries_indexer(atw.getter)) +function as_timeseries_indexer( + ::IndexerBoth, atw::AsParameterTupleWrapper{N, A}) where {N, A} + AsParameterTupleWrapper{N, A}(as_timeseries_indexer(atw.getter)) end -function as_not_timeseries_indexer(::IndexerBoth, atw::AsParameterTupleWrapper{N}) where {N} - AsParameterTupleWrapper{N}(as_not_timeseries_indexer(atw.getter)) +function as_not_timeseries_indexer( + ::IndexerBoth, atw::AsParameterTupleWrapper{N, A}) where {N, A} + AsParameterTupleWrapper{N, A}(as_not_timeseries_indexer(atw.getter)) end -wrap_tuple(::AsParameterTupleWrapper{N}, val) where {N} = ntuple(i -> val[i], Val(N)) +function wrap_tuple(::AsParameterTupleWrapper{N, Nothing}, val) where {N} + ntuple(i -> val[i], Val(N)) +end +function wrap_tuple(::AsParameterTupleWrapper{N, A}, val) where {N, A} + NamedTuple{A}(ntuple(i -> val[i], Val(N))) +end function (atw::AsParameterTupleWrapper)(ts::IsTimeseriesTrait, prob, args...) atw(ts, is_indexer_timeseries(atw), prob, args...) @@ -591,19 +603,24 @@ is_observed_getter(mpg::MultipleParametersGetter) = any(is_observed_getter, mpg. for (t1, t2) in [ (ArraySymbolic, Any), (ScalarSymbolic, Any), - (NotSymbolic, Union{<:Tuple, <:AbstractArray}) + (NotSymbolic, Union{<:Tuple, <:NamedTuple, <:AbstractArray}) ] @eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2) # We need to do it this way because if an `ODESystem` has `p[1], p[2], p[3]` as # parameters (all scalarized) then `is_observed(sys, p[2:3]) == true`. Then, # `getp` errors on older MTK that doesn't support `parameter_observed`. - getters = getp.((sys,), p) + _p = p isa NamedTuple ? Tuple(p) : p + getters = getp.((sys,), _p) num_observed = count(is_observed_getter, getters) supports_tuple = supports_tuple_observed(sys) - p_arr = p isa Tuple ? collect(p) : p + p_arr = p isa Union{Tuple, NamedTuple} ? collect(p) : p if num_observed == 0 - return MultipleParametersGetter(getters) + getter = MultipleParametersGetter(getters) + if p isa NamedTuple + getter = AsParameterTupleWrapper{length(p), keys(p)}(getter) + end + return getter else pofn = supports_tuple ? parameter_observed(sys, p) : parameter_observed(sys, p_arr) @@ -617,8 +634,12 @@ for (t1, t2) in [ else getter = GetParameterObservedNoTime(pofn) end - return p isa Tuple && !supports_tuple ? - AsParameterTupleWrapper{length(p)}(getter) : getter + if p isa Tuple && !supports_tuple + getter = AsParameterTupleWrapper{length(p)}(getter) + elseif p isa NamedTuple + getter = AsParameterTupleWrapper{length(p), keys(p)}(getter) + end + return getter end end end @@ -698,9 +719,13 @@ end for (t1, t2) in [ (ArraySymbolic, Any), (ScalarSymbolic, Any), - (NotSymbolic, Union{<:Tuple, <:AbstractArray}) + (NotSymbolic, Union{<:Tuple, <:NamedTuple, <:AbstractArray}) ] @eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2) + if p isa NamedTuple + setters = NamedTuple{keys(p)}(setp.((sys,), values(p); run_hook = false)) + return NamedTupleSetter(setters) + end setters = setp.((sys,), p; run_hook = false) return MultipleSetters(setters) end diff --git a/src/state_indexing.jl b/src/state_indexing.jl index dc381ea..a8bb306 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -221,13 +221,17 @@ function (mg::MultipleGetters)(::NotTimeseries, ::IndexerMixedTimeseries, prob, return map(g -> g(prob), mg.getters) end -struct AsTupleWrapper{N, G} <: AbstractStateGetIndexer +struct AsTupleWrapper{N, A, G} <: AbstractStateGetIndexer getter::G end -AsTupleWrapper{N}(getter::G) where {N, G} = AsTupleWrapper{N, G}(getter) +AsTupleWrapper{N}(getter::G) where {N, G} = AsTupleWrapper{N, Nothing, G}(getter) +AsTupleWrapper{N, A}(getter::G) where {N, A, G} = AsTupleWrapper{N, A, G}(getter) -wrap_tuple(::AsTupleWrapper{N}, val) where {N} = ntuple(i -> val[i], Val(N)) +wrap_tuple(::AsTupleWrapper{N, Nothing}, val) where {N} = ntuple(i -> val[i], Val(N)) +function wrap_tuple(::AsTupleWrapper{N, A}, val) where {N, A} + NamedTuple{A}(ntuple(i -> val[i], Val(N))) +end function (atw::AsTupleWrapper)(::Timeseries, prob) return wrap_tuple.((atw,), atw.getter(prob)) @@ -245,13 +249,13 @@ end for (t1, t2) in [ (ScalarSymbolic, Any), (ArraySymbolic, Any), - (NotSymbolic, Union{<:Tuple, <:AbstractArray}) + (NotSymbolic, Union{<:Tuple, <:NamedTuple, <:AbstractArray}) ] @eval function _getsym(sys, ::NotSymbolic, elt::$t1, sym::$t2) if isempty(sym) return MultipleGetters(ContinuousTimeseries(), sym) end - sym_arr = sym isa Tuple ? collect(sym) : sym + sym_arr = sym isa Union{Tuple, NamedTuple} ? collect(sym) : sym supports_tuple = supports_tuple_observed(sys) num_observed = 0 for s in sym @@ -266,6 +270,8 @@ for (t1, t2) in [ getter = TimeIndependentObservedFunction(obs) if sym isa Tuple getter = AsTupleWrapper{length(sym)}(getter) + elseif sym isa NamedTuple + getter = AsTupleWrapper{length(sym), keys(sym)}(getter) end return getter end @@ -280,9 +286,14 @@ for (t1, t2) in [ ts_idxs = collect(ts_idxs) end - if num_observed == 0 || num_observed == 1 && sym isa Tuple - getters = getsym.((sys,), sym) - return MultipleGetters(ts_idxs, getters) + if num_observed == 0 || num_observed == 1 && sym isa Union{Tuple, NamedTuple} + _sym = sym isa NamedTuple ? Tuple(sym) : sym + getters = getsym.((sys,), _sym) + getter = MultipleGetters(ts_idxs, getters) + if sym isa NamedTuple + getter = AsTupleWrapper{length(sym), keys(sym)}(getter) + end + return getter else obs = supports_tuple ? observed(sys, sym) : observed(sys, sym_arr) getter = if is_time_dependent(sys) @@ -292,6 +303,8 @@ for (t1, t2) in [ end if sym isa Tuple && !supports_tuple getter = AsTupleWrapper{length(sym)}(getter) + elseif sym isa NamedTuple + getter = AsTupleWrapper{length(sym), keys(sym)}(getter) end return getter end @@ -351,12 +364,39 @@ function _setsym(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym) error("Invalid symbol $sym for `setsym`") end +struct NamedTupleSetter{S <: NamedTuple} <: AbstractSetIndexer + setter::S +end + +function (nts::NamedTupleSetter)(prob, val) + _generated_setter(nts, prob, val) +end + +@generated function _generated_setter( + nts::NamedTupleSetter{<:NamedTuple{N1}}, prob, val::NamedTuple{N2}) where {N1, N2} + expr = Expr(:block) + for (i, name) in enumerate(N2) + idx = findfirst(isequal(name), N1) + if idx === nothing + throw(ArgumentError(""" + Invalid name $(name) in value. Must be one of $(N1). + """)) + end + push!(expr.args, :(nts.setter[$idx](prob, val[$i]))) + end + return expr +end + for (t1, t2) in [ (ScalarSymbolic, Any), (ArraySymbolic, Any), - (NotSymbolic, Union{<:Tuple, <:AbstractArray}) + (NotSymbolic, Union{<:Tuple, <:NamedTuple, <:AbstractArray}) ] @eval function _setsym(sys, ::NotSymbolic, ::$t1, sym::$t2) + if sym isa NamedTuple + setters = NamedTuple{keys(sym)}(setsym.((sys,), values(sym))) + return NamedTupleSetter(setters) + end setters = setsym.((sys,), sym) return MultipleSetters(setters) end diff --git a/test/parameter_indexing_test.jl b/test/parameter_indexing_test.jl index 7017b74..b3d6e73 100644 --- a/test/parameter_indexing_test.jl +++ b/test/parameter_indexing_test.jl @@ -55,7 +55,10 @@ for sys in [ ([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false), ([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false), ((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true), - ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)] + ((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true), + ((a = :a, b = [:a, :b], c = (d = :c, e = :a)), + (a = p[1], b = p[1:2], c = (d = p[3], e = p[1])), + (a = new_p[1], b = new_p[1:2], c = (d = new_p[3], e = new_p[1])), true)] get = getp(sys, sym) set! = setp(sys, sym) if check_inference diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index e52276f..b2db2f7 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -50,7 +50,12 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true) ((1, (:y, :z)), (u[1], (u[2], u[3])), (4.0, (5.0, 6.0)), true) ((:x, [:y], (:z,)), (u[1], [u[2]], (u[3],)), - (4.0, [5.0], (6.0,)), true)] + (4.0, [5.0], (6.0,)), true) + ((a = :x, b = [:x, :y], c = (d = :z, e = :x)), + (a = u[1], b = u[1:2], + c = (d = u[3], e = u[1])), + (a = 4.0, b = [4.0, 5.0], + c = (d = 6.0, e = 4.0)), true)] get = getsym(sys, sym) set! = setsym(sys, sym) if check_inference @@ -86,12 +91,14 @@ for (sym, val, newval, check_inference) in [(:x, u[1], 4.0, true) continue end - setter = setsym_oop(sys, sym) - svals, pvals = setter(fi, newval) - @test svals ≈ new_states - @test pvals == parameter_values(fi) - @test_throws ArgumentError setter(state_values(fi), newval) - @test_throws ArgumentError setter(parameter_values(fi), newval) + if !(sym isa NamedTuple) + setter = setsym_oop(sys, sym) + svals, pvals = setter(fi, newval) + @test svals ≈ new_states + @test pvals == parameter_values(fi) + @test_throws ArgumentError setter(state_values(fi), newval) + @test_throws ArgumentError setter(parameter_values(fi), newval) + end end for (sym, val, check_inference) in [