Skip to content

Commit ebcc225

Browse files
authored
Preserve wrapper type better (#14)
1 parent d0932f0 commit ebcc225

File tree

2 files changed

+48
-18
lines changed

2 files changed

+48
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NamedDimsArrays"
22
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.2"
4+
version = "0.3.3"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractnameddimsarray.jl

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Derive: Derive, @derive, AbstractArrayInterface
2+
using TypeParameterAccessors: unspecify_type_parameters
23

34
# Some of the interface is inspired by:
45
# https://github.com/ITensor/ITensors.jl
@@ -138,7 +139,9 @@ function checked_indexin(x::AbstractUnitRange, y::AbstractUnitRange)
138139
return findfirst(==(first(x)), y):findfirst(==(last(x)), y)
139140
end
140141

141-
Base.copy(a::AbstractNamedDimsArray) = nameddims(copy(dename(a)), nameddimsindices(a))
142+
function Base.copy(a::AbstractNamedDimsArray)
143+
return nameddimsarraytype(a)(copy(dename(a)), nameddimsindices(a))
144+
end
142145

143146
const NamedDimsIndices = Union{
144147
AbstractNamedUnitRange{<:Integer},AbstractNamedArray{<:Integer}
@@ -174,28 +177,30 @@ using Base.Broadcast: Broadcasted, Style
174177
struct NaiveOrderedSet{Values}
175178
values::Values
176179
end
177-
Base.Tuple(s::NaiveOrderedSet) = s.values
178-
Base.length(s::NaiveOrderedSet) = length(Tuple(s))
179-
Base.axes(s::NaiveOrderedSet) = axes(Tuple(s))
180-
Base.:(==)(s1::NaiveOrderedSet, s2::NaiveOrderedSet) = issetequal(Tuple(s1), Tuple(s2))
181-
Base.iterate(s::NaiveOrderedSet, args...) = iterate(Tuple(s), args...)
182-
Base.getindex(s::NaiveOrderedSet, I::Int) = Tuple(s)[I]
183-
Base.invperm(s::NaiveOrderedSet) = NaiveOrderedSet(invperm(Tuple(s)))
180+
Base.values(s::NaiveOrderedSet) = s.values
181+
Base.Tuple(s::NaiveOrderedSet) = Tuple(values(s))
182+
Base.length(s::NaiveOrderedSet) = length(values(s))
183+
Base.axes(s::NaiveOrderedSet) = axes(values(s))
184+
Base.:(==)(s1::NaiveOrderedSet, s2::NaiveOrderedSet) = issetequal(values(s1), values(s2))
185+
Base.iterate(s::NaiveOrderedSet, args...) = iterate(values(s), args...)
186+
Base.getindex(s::NaiveOrderedSet, I::Int) = values(s)[I]
187+
Base.invperm(s::NaiveOrderedSet) = NaiveOrderedSet(invperm(values(s)))
184188
Base.Broadcast._axes(::Broadcasted, axes::NaiveOrderedSet) = axes
185189
Base.Broadcast.BroadcastStyle(::Type{<:NaiveOrderedSet}) = Style{NaiveOrderedSet}()
186190
Base.Broadcast.broadcastable(s::NaiveOrderedSet) = s
191+
Base.to_shape(s::NaiveOrderedSet) = s
187192

188193
function Base.copy(
189194
bc::Broadcasted{Style{NaiveOrderedSet},<:Any,<:Any,<:Tuple{<:NaiveOrderedSet}}
190195
)
191-
return NaiveOrderedSet(bc.f.(Tuple(only(bc.args))))
196+
return NaiveOrderedSet(bc.f.(values(only(bc.args))))
192197
end
193198
# Multiple arguments not supported.
194199
function Base.copy(bc::Broadcasted{Style{NaiveOrderedSet}})
195200
return error("This broadcasting expression of `NaiveOrderedSet` is not supported.")
196201
end
197202
function Base.map(f::Function, s::NaiveOrderedSet)
198-
return NaiveOrderedSet(map(f, Tuple(s)))
203+
return NaiveOrderedSet(map(f, values(s)))
199204
end
200205

201206
function Base.axes(a::AbstractNamedDimsArray)
@@ -228,15 +233,36 @@ to_nameddimsaxes(dims) = map(to_nameddimsaxis, dims)
228233
to_nameddimsaxis(ax::NamedDimsAxis) = ax
229234
to_nameddimsaxis(I::NamedDimsIndices) = named(dename(only(axes(I))), I)
230235

236+
nameddimsarraytype(a::AbstractNamedDimsArray) = nameddimsarraytype(typeof(a))
237+
nameddimsarraytype(a::Type{<:AbstractNamedDimsArray}) = unspecify_type_parameters(a)
238+
239+
function similar_nameddims(a::AbstractNamedDimsArray, elt::Type, inds)
240+
ax = to_nameddimsaxes(inds)
241+
return nameddimsarraytype(a)(similar(dename(a), elt, dename.(Tuple(ax))), name.(ax))
242+
end
243+
244+
function similar_nameddims(a::AbstractArray, elt::Type, inds)
245+
ax = to_nameddimsaxes(inds)
246+
return nameddims(similar(a, elt, dename.(Tuple(ax))), name.(ax))
247+
end
248+
249+
# Base.similar gets the eltype at compile time.
250+
function Base.similar(a::AbstractNamedDimsArray)
251+
return similar(a, eltype(a))
252+
end
253+
231254
function Base.similar(
232255
a::AbstractArray, elt::Type, inds::Tuple{NamedDimsIndices,Vararg{NamedDimsIndices}}
233256
)
234-
ax = to_nameddimsaxes(inds)
235-
return nameddims(similar(unname(a), elt, dename.(ax)), name.(ax))
257+
return similar_nameddims(a, elt, inds)
258+
end
259+
260+
function Base.similar(a::AbstractArray, elt::Type, inds::NaiveOrderedSet)
261+
return similar_nameddims(a, elt, inds)
236262
end
237263

238264
function setnameddimsindices(a::AbstractNamedDimsArray, nameddimsindices)
239-
return nameddims(dename(a), nameddimsindices)
265+
return nameddimsarraytype(a)(dename(a), nameddimsindices)
240266
end
241267
function replacenameddimsindices(f, a::AbstractNamedDimsArray)
242268
return setnameddimsindices(a, replace(f, nameddimsindices(a)))
@@ -530,7 +556,7 @@ function Base.setindex!(
530556
Irest::NamedViewIndex...,
531557
)
532558
I = (I1, Irest...)
533-
setindex!(a, nameddims(value, I), I...)
559+
setindex!(a, nameddimsarraytype(a)(value, I), I...)
534560
return a
535561
end
536562
function Base.setindex!(
@@ -560,7 +586,7 @@ function aligndims(a::AbstractArray, dims)
560586
"Dimension name mismatch $(nameddimsindices(a)), $(new_nameddimsindices)."
561587
),
562588
)
563-
return nameddims(permutedims(dename(a), perm), new_nameddimsindices)
589+
return nameddimsarraytype(a)(permutedims(dename(a), perm), new_nameddimsindices)
564590
end
565591

566592
function aligneddims(a::AbstractArray, dims)
@@ -572,7 +598,7 @@ function aligneddims(a::AbstractArray, dims)
572598
"Dimension name mismatch $(nameddimsindices(a)), $(new_nameddimsindices)."
573599
),
574600
)
575-
return nameddims(PermutedDimsArray(dename(a), perm), new_nameddimsindices)
601+
return nameddimsarraytype(a)(PermutedDimsArray(dename(a), perm), new_nameddimsindices)
576602
end
577603

578604
# Convenient constructors
@@ -634,7 +660,7 @@ end
634660
for f in [:zeros, :ones]
635661
@eval begin
636662
function Base.$f(
637-
elt::Type{<:Number}, ax::Tuple{NamedDimsIndices,Vararg{NamedDimsIndices}}
663+
elt::Type{<:Number}, inds::Tuple{NamedDimsIndices,Vararg{NamedDimsIndices}}
638664
)
639665
ax = to_nameddimsaxes(inds)
640666
a = $f(elt, dename.(ax))
@@ -760,6 +786,10 @@ end
760786
function Base.similar(bc::Broadcasted{<:AbstractNamedDimsArrayStyle}, elt::Type, ax)
761787
nameddimsindices = name.(ax)
762788
m′ = denamed(Mapped(bc), nameddimsindices)
789+
# TODO: Store the wrapper type in `AbstractNamedDimsArrayStyle` and use that
790+
# wrapper type rather than the generic `nameddims` constructor, which
791+
# can lose information.
792+
# Call it as `nameddimsarraytype(bc.style)`.
763793
return nameddims(similar(m′, elt, dename.(Tuple(ax))), nameddimsindices)
764794
end
765795

0 commit comments

Comments
 (0)