@@ -140,18 +140,34 @@ end
140
140
141
141
Base. copy (a:: AbstractNamedDimsArray ) = nameddims (copy (dename (a)), nameddimsindices (a))
142
142
143
+ const NamedDimsIndices = Union{
144
+ AbstractNamedUnitRange{<: Integer },AbstractNamedArray{<: Integer }
145
+ }
146
+ const NamedDimsAxis = AbstractNamedUnitRange{
147
+ <: Integer ,<: AbstractUnitRange ,<: NamedDimsIndices
148
+ }
149
+
143
150
# Generic constructor.
144
151
function nameddims (a:: AbstractArray , nameddimsindices)
145
152
# TODO : Check the shape of `nameddimsindices` matches the shape of `a`.
146
- return nameddimstype (eltype (nameddimsindices))(
147
- a, to_nameddimsindices (a, nameddimsindices)
148
- )
153
+ arrtype = mapreduce (nameddimsarraytype, combine_nameddimsarraytype, nameddimsindices)
154
+ return arrtype (a, to_nameddimsindices (a, nameddimsindices))
149
155
end
150
156
151
157
# Can overload this to get custom named dims array wrapper
152
158
# depending on the dimension name types, for example
153
159
# output an `ITensor` if the dimension names are `IndexName`s.
154
- nameddimstype (dimnametype:: Type ) = NamedDimsArray
160
+ nameddimsarraytype (nameddim) = nameddimsarraytype (typeof (nameddim))
161
+ nameddimsarraytype (nameddimtype:: Type ) = NamedDimsArray
162
+ function nameddimsarraytype (nameddimtype:: Type{<:NamedDimsIndices} )
163
+ return nameddimsarraytype (nametype (nameddimtype))
164
+ end
165
+ function combine_nameddimsarraytype (
166
+ :: Type{<:AbstractNamedDimsArray} , :: Type{<:AbstractNamedDimsArray}
167
+ )
168
+ return NamedDimsArray
169
+ end
170
+ combine_nameddimsarraytype (:: Type{T} , :: Type{T} ) where {T<: AbstractNamedDimsArray } = T
155
171
156
172
Base. axes (a:: AbstractNamedDimsArray ) = map (named, axes (dename (a)), nameddimsindices (a))
157
173
Base. size (a:: AbstractNamedDimsArray ) = map (named, size (dename (a)), nameddimsindices (a))
@@ -175,13 +191,6 @@ Base.eltype(a::AbstractNamedDimsArray) = eltype(dename(a))
175
191
Base. axes (a:: AbstractNamedDimsArray , dimname:: Name ) = axes (a, dim (a, dimname))
176
192
Base. size (a:: AbstractNamedDimsArray , dimname:: Name ) = size (a, dim (a, dimname))
177
193
178
- const NamedDimsIndices = Union{
179
- AbstractNamedUnitRange{<: Integer },AbstractNamedArray{<: Integer }
180
- }
181
- const NamedDimsAxis = AbstractNamedUnitRange{
182
- <: Integer ,<: AbstractUnitRange ,<: NamedDimsIndices
183
- }
184
-
185
194
to_nameddimsaxes (dims) = map (to_nameddimsaxis, dims)
186
195
to_nameddimsaxis (ax:: NamedDimsAxis ) = ax
187
196
to_nameddimsaxis (I:: NamedDimsIndices ) = named (dename (only (axes (I))), I)
0 commit comments