@@ -531,27 +531,39 @@ function (mpg::OnlyTimeseriesMPG)(
531531 throw (ParameterTimeseriesValueIndexMismatchError {NotTimeseries} (prob, mpg))
532532end
533533
534- struct AsParameterTupleWrapper{N, G <: AbstractParameterGetIndexer } < :
534+ struct AsParameterTupleWrapper{N, A, G <: AbstractParameterGetIndexer } < :
535535 AbstractParameterGetIndexer
536536 getter:: G
537537end
538538
539- AsParameterTupleWrapper {N} (getter:: G ) where {N, G} = AsParameterTupleWrapper {N, G} (getter)
539+ function AsParameterTupleWrapper {N} (getter:: G ) where {N, G}
540+ AsParameterTupleWrapper {N, Nothing, G} (getter)
541+ end
542+ function AsParameterTupleWrapper {N, A} (getter:: G ) where {N, A, G}
543+ AsParameterTupleWrapper {N, A, G} (getter)
544+ end
540545
541- function is_indexer_timeseries (:: Type{AsParameterTupleWrapper{N, G}} ) where {N, G}
546+ function is_indexer_timeseries (:: Type{AsParameterTupleWrapper{N, A, G}} ) where {N, A , G}
542547 is_indexer_timeseries (G)
543548end
544549function indexer_timeseries_index (atw:: AsParameterTupleWrapper )
545550 indexer_timeseries_index (atw. getter)
546551end
547- function as_timeseries_indexer (:: IndexerBoth , atw:: AsParameterTupleWrapper{N} ) where {N}
548- AsParameterTupleWrapper {N} (as_timeseries_indexer (atw. getter))
552+ function as_timeseries_indexer (
553+ :: IndexerBoth , atw:: AsParameterTupleWrapper{N, A} ) where {N, A}
554+ AsParameterTupleWrapper {N, A} (as_timeseries_indexer (atw. getter))
549555end
550- function as_not_timeseries_indexer (:: IndexerBoth , atw:: AsParameterTupleWrapper{N} ) where {N}
551- AsParameterTupleWrapper {N} (as_not_timeseries_indexer (atw. getter))
556+ function as_not_timeseries_indexer (
557+ :: IndexerBoth , atw:: AsParameterTupleWrapper{N, A} ) where {N, A}
558+ AsParameterTupleWrapper {N, A} (as_not_timeseries_indexer (atw. getter))
552559end
553560
554- wrap_tuple (:: AsParameterTupleWrapper{N} , val) where {N} = ntuple (i -> val[i], Val (N))
561+ function wrap_tuple (:: AsParameterTupleWrapper{N, Nothing} , val) where {N}
562+ ntuple (i -> val[i], Val (N))
563+ end
564+ function wrap_tuple (:: AsParameterTupleWrapper{N, A} , val) where {N, A}
565+ NamedTuple {A} (ntuple (i -> val[i], Val (N)))
566+ end
555567
556568function (atw:: AsParameterTupleWrapper )(ts:: IsTimeseriesTrait , prob, args... )
557569 atw (ts, is_indexer_timeseries (atw), prob, args... )
@@ -591,19 +603,24 @@ is_observed_getter(mpg::MultipleParametersGetter) = any(is_observed_getter, mpg.
591603for (t1, t2) in [
592604 (ArraySymbolic, Any),
593605 (ScalarSymbolic, Any),
594- (NotSymbolic, Union{<: Tuple , <: AbstractArray })
606+ (NotSymbolic, Union{<: Tuple , <: NamedTuple , <: AbstractArray })
595607]
596608 @eval function _getp (sys, :: NotSymbolic , :: $t1 , p:: $t2 )
597609 # We need to do it this way because if an `ODESystem` has `p[1], p[2], p[3]` as
598610 # parameters (all scalarized) then `is_observed(sys, p[2:3]) == true`. Then,
599611 # `getp` errors on older MTK that doesn't support `parameter_observed`.
600- getters = getp .((sys,), p)
612+ _p = p isa NamedTuple ? Tuple (p) : p
613+ getters = getp .((sys,), _p)
601614 num_observed = count (is_observed_getter, getters)
602615 supports_tuple = supports_tuple_observed (sys)
603- p_arr = p isa Tuple ? collect (p) : p
616+ p_arr = p isa Union{ Tuple, NamedTuple} ? collect (p) : p
604617
605618 if num_observed == 0
606- return MultipleParametersGetter (getters)
619+ getter = MultipleParametersGetter (getters)
620+ if p isa NamedTuple
621+ getter = AsParameterTupleWrapper {length(p), keys(p)} (getter)
622+ end
623+ return getter
607624 else
608625 pofn = supports_tuple ? parameter_observed (sys, p) :
609626 parameter_observed (sys, p_arr)
@@ -617,8 +634,12 @@ for (t1, t2) in [
617634 else
618635 getter = GetParameterObservedNoTime (pofn)
619636 end
620- return p isa Tuple && ! supports_tuple ?
621- AsParameterTupleWrapper {length(p)} (getter) : getter
637+ if p isa Tuple && ! supports_tuple
638+ getter = AsParameterTupleWrapper {length(p)} (getter)
639+ elseif p isa NamedTuple
640+ getter = AsParameterTupleWrapper {length(p), keys(p)} (getter)
641+ end
642+ return getter
622643 end
623644 end
624645end
698719for (t1, t2) in [
699720 (ArraySymbolic, Any),
700721 (ScalarSymbolic, Any),
701- (NotSymbolic, Union{<: Tuple , <: AbstractArray })
722+ (NotSymbolic, Union{<: Tuple , <: NamedTuple , <: AbstractArray })
702723]
703724 @eval function _setp (sys, :: NotSymbolic , :: $t1 , p:: $t2 )
725+ if p isa NamedTuple
726+ setters = NamedTuple {keys(p)} (setp .((sys,), values (p); run_hook = false ))
727+ return NamedTupleSetter (setters)
728+ end
704729 setters = setp .((sys,), p; run_hook = false )
705730 return MultipleSetters (setters)
706731 end
0 commit comments