@@ -140,18 +140,34 @@ end
140140
141141Base. copy (a:: AbstractNamedDimsArray ) = nameddims (copy (dename (a)), nameddimsindices (a))
142142
143+ const NamedDimsIndices = Union{
144+ AbstractNamedUnitRange{<: Integer },AbstractNamedArray{<: Integer }
145+ }
146+ const NamedDimsAxis = AbstractNamedUnitRange{
147+ <: Integer ,<: AbstractUnitRange ,<: NamedDimsIndices
148+ }
149+
143150# Generic constructor.
144151function nameddims (a:: AbstractArray , nameddimsindices)
145152 # TODO : Check the shape of `nameddimsindices` matches the shape of `a`.
146- return nameddimstype (eltype (nameddimsindices))(
147- a, to_nameddimsindices (a, nameddimsindices)
148- )
153+ arrtype = mapreduce (nameddimsarraytype, combine_nameddimsarraytype, nameddimsindices)
154+ return arrtype (a, to_nameddimsindices (a, nameddimsindices))
149155end
150156
151157# Can overload this to get custom named dims array wrapper
152158# depending on the dimension name types, for example
153159# output an `ITensor` if the dimension names are `IndexName`s.
154- nameddimstype (dimnametype:: Type ) = NamedDimsArray
160+ nameddimsarraytype (nameddim) = nameddimsarraytype (typeof (nameddim))
161+ nameddimsarraytype (nameddimtype:: Type ) = NamedDimsArray
162+ function nameddimsarraytype (nameddimtype:: Type{<:NamedDimsIndices} )
163+ return nameddimsarraytype (nametype (nameddimtype))
164+ end
165+ function combine_nameddimsarraytype (
166+ :: Type{<:AbstractNamedDimsArray} , :: Type{<:AbstractNamedDimsArray}
167+ )
168+ return NamedDimsArray
169+ end
170+ combine_nameddimsarraytype (:: Type{T} , :: Type{T} ) where {T<: AbstractNamedDimsArray } = T
155171
156172Base. axes (a:: AbstractNamedDimsArray ) = map (named, axes (dename (a)), nameddimsindices (a))
157173Base. size (a:: AbstractNamedDimsArray ) = map (named, size (dename (a)), nameddimsindices (a))
@@ -175,13 +191,6 @@ Base.eltype(a::AbstractNamedDimsArray) = eltype(dename(a))
175191Base. axes (a:: AbstractNamedDimsArray , dimname:: Name ) = axes (a, dim (a, dimname))
176192Base. size (a:: AbstractNamedDimsArray , dimname:: Name ) = size (a, dim (a, dimname))
177193
178- const NamedDimsIndices = Union{
179- AbstractNamedUnitRange{<: Integer },AbstractNamedArray{<: Integer }
180- }
181- const NamedDimsAxis = AbstractNamedUnitRange{
182- <: Integer ,<: AbstractUnitRange ,<: NamedDimsIndices
183- }
184-
185194to_nameddimsaxes (dims) = map (to_nameddimsaxis, dims)
186195to_nameddimsaxis (ax:: NamedDimsAxis ) = ax
187196to_nameddimsaxis (I:: NamedDimsIndices ) = named (dename (only (axes (I))), I)
0 commit comments