Skip to content

Commit 6c2e1b2

Browse files
authored
Improve logic for determining AbstractNamedDimsArray type from dimension names (#11)
1 parent b21c44e commit 6c2e1b2

File tree

2 files changed

+21
-12
lines changed

2 files changed

+21
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NamedDimsArrays"
22
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.0"
4+
version = "0.3.1"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractnameddimsarray.jl

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -140,18 +140,34 @@ end
140140

141141
Base.copy(a::AbstractNamedDimsArray) = nameddims(copy(dename(a)), nameddimsindices(a))
142142

143+
const NamedDimsIndices = Union{
144+
AbstractNamedUnitRange{<:Integer},AbstractNamedArray{<:Integer}
145+
}
146+
const NamedDimsAxis = AbstractNamedUnitRange{
147+
<:Integer,<:AbstractUnitRange,<:NamedDimsIndices
148+
}
149+
143150
# Generic constructor.
144151
function nameddims(a::AbstractArray, nameddimsindices)
145152
# TODO: Check the shape of `nameddimsindices` matches the shape of `a`.
146-
return nameddimstype(eltype(nameddimsindices))(
147-
a, to_nameddimsindices(a, nameddimsindices)
148-
)
153+
arrtype = mapreduce(nameddimsarraytype, combine_nameddimsarraytype, nameddimsindices)
154+
return arrtype(a, to_nameddimsindices(a, nameddimsindices))
149155
end
150156

151157
# Can overload this to get custom named dims array wrapper
152158
# depending on the dimension name types, for example
153159
# output an `ITensor` if the dimension names are `IndexName`s.
154-
nameddimstype(dimnametype::Type) = NamedDimsArray
160+
nameddimsarraytype(nameddim) = nameddimsarraytype(typeof(nameddim))
161+
nameddimsarraytype(nameddimtype::Type) = NamedDimsArray
162+
function nameddimsarraytype(nameddimtype::Type{<:NamedDimsIndices})
163+
return nameddimsarraytype(nametype(nameddimtype))
164+
end
165+
function combine_nameddimsarraytype(
166+
::Type{<:AbstractNamedDimsArray}, ::Type{<:AbstractNamedDimsArray}
167+
)
168+
return NamedDimsArray
169+
end
170+
combine_nameddimsarraytype(::Type{T}, ::Type{T}) where {T<:AbstractNamedDimsArray} = T
155171

156172
Base.axes(a::AbstractNamedDimsArray) = map(named, axes(dename(a)), nameddimsindices(a))
157173
Base.size(a::AbstractNamedDimsArray) = map(named, size(dename(a)), nameddimsindices(a))
@@ -175,13 +191,6 @@ Base.eltype(a::AbstractNamedDimsArray) = eltype(dename(a))
175191
Base.axes(a::AbstractNamedDimsArray, dimname::Name) = axes(a, dim(a, dimname))
176192
Base.size(a::AbstractNamedDimsArray, dimname::Name) = size(a, dim(a, dimname))
177193

178-
const NamedDimsIndices = Union{
179-
AbstractNamedUnitRange{<:Integer},AbstractNamedArray{<:Integer}
180-
}
181-
const NamedDimsAxis = AbstractNamedUnitRange{
182-
<:Integer,<:AbstractUnitRange,<:NamedDimsIndices
183-
}
184-
185194
to_nameddimsaxes(dims) = map(to_nameddimsaxis, dims)
186195
to_nameddimsaxis(ax::NamedDimsAxis) = ax
187196
to_nameddimsaxis(I::NamedDimsIndices) = named(dename(only(axes(I))), I)

0 commit comments

Comments
 (0)