Skip to content

Commit 4c38a99

Browse files
authored
Don't forward dename to parent by default, more generic Array conversion and copyto! (#44)
1 parent 1fd398f commit 4c38a99

File tree

4 files changed

+73
-21
lines changed

4 files changed

+73
-21
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NamedDimsArrays"
22
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.4.5"
4+
version = "0.5.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractnameddimsarray.jl

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,20 @@ DerivableInterfaces.interface(::Type{<:AbstractNamedDimsArray}) = NamedDimsArray
2020

2121
# Output the dimension names.
2222
nameddimsindices(a::AbstractArray) = throw(MethodError(nameddimsindices, Tuple{typeof(a)}))
23-
# Unwrapping the names
24-
Base.parent(a::AbstractNamedDimsArray) = throw(MethodError(parent, Tuple{typeof(a)}))
23+
# Unwrapping the names (`NamedDimsArrays.jl` interface).
24+
# TODO: Use `IsNamed` trait?
25+
dename(a::AbstractNamedDimsArray) = throw(MethodError(dename, Tuple{typeof(a)}))
26+
function dename(a::AbstractNamedDimsArray, nameddimsindices)
27+
return dename(aligndims(a, nameddimsindices))
28+
end
29+
function denamed(a::AbstractNamedDimsArray, nameddimsindices)
30+
return dename(aligneddims(a, nameddimsindices))
31+
end
32+
33+
unname(a::AbstractArray, nameddimsindices) = dename(a, nameddimsindices)
34+
unnamed(a::AbstractArray, nameddimsindices) = denamed(a, nameddimsindices)
35+
36+
isnamed(::Type{<:AbstractNamedDimsArray}) = true
2537

2638
nameddimsindices(a::AbstractArray, dim::Int) = nameddimsindices(a)[dim]
2739

@@ -109,21 +121,6 @@ function to_nameddimsindices(a::AbstractNamedDimsArray, dims)
109121
return map(dim -> to_dimname(a, dim), dims)
110122
end
111123

112-
# Unwrapping the names (`NamedDimsArrays.jl` interface).
113-
# TODO: Use `IsNamed` trait?
114-
dename(a::AbstractNamedDimsArray) = parent(a)
115-
function dename(a::AbstractNamedDimsArray, nameddimsindices)
116-
return dename(aligndims(a, nameddimsindices))
117-
end
118-
function denamed(a::AbstractNamedDimsArray, nameddimsindices)
119-
return dename(aligneddims(a, nameddimsindices))
120-
end
121-
122-
unname(a::AbstractArray, nameddimsindices) = dename(a, nameddimsindices)
123-
unnamed(a::AbstractArray, nameddimsindices) = denamed(a, nameddimsindices)
124-
125-
isnamed(::Type{<:AbstractNamedDimsArray}) = true
126-
127124
# TODO: Move to `utils.jl` file.
128125
# TODO: Use `Base.indexin`?
129126
function getperm(x, y; isequal=isequal)
@@ -148,6 +145,36 @@ function Base.copy(a::AbstractNamedDimsArray)
148145
return constructorof(typeof(a))(copy(dename(a)), nameddimsindices(a))
149146
end
150147

148+
function Base.copyto!(a_dest::AbstractNamedDimsArray, a_src::AbstractNamedDimsArray)
149+
a′_dest = dename(a_dest)
150+
# TODO: Use `denamed` to do the permutations lazily.
151+
a′_src = dename(a_src, nameddimsindices(a_dest))
152+
copyto!(a′_dest, a′_src)
153+
return a_dest
154+
end
155+
156+
# Conversion
157+
158+
# Copied from `Base` (defined in abstractarray.jl).
159+
@noinline _checkaxs(axd, axs) =
160+
axd == axs || throw(DimensionMismatch("axes must agree, got $axd and $axs"))
161+
function copyto_axcheck!(dest, src)
162+
_checkaxs(axes(dest), axes(src))
163+
return copyto!(dest, src)
164+
end
165+
166+
# These are defined since the Base versions assume the eltype and ndims are known
167+
# at compile time, which isn't true for ITensors.
168+
Base.Array(a::AbstractNamedDimsArray) = Array(dename(a))
169+
Base.Array{T}(a::AbstractNamedDimsArray) where {T} = Array{T}(dename(a))
170+
Base.Array{T,N}(a::AbstractNamedDimsArray) where {T,N} = Array{T,N}(dename(a))
171+
Base.AbstractArray{T}(a::AbstractNamedDimsArray) where {T} = AbstractArray{T,ndims(a)}(a)
172+
function Base.AbstractArray{T,N}(a::AbstractNamedDimsArray) where {T,N}
173+
dest = similar(a, T)
174+
copyto_axcheck!(dename(dest), dename(a))
175+
return dest
176+
end
177+
151178
const NamedDimsIndices = Union{
152179
AbstractNamedUnitRange{<:Integer},AbstractNamedArray{<:Integer}
153180
}
@@ -380,7 +407,7 @@ function Base.getindex(a::NamedDimsCartesianIndices{N}, I::Vararg{Int,N}) where
380407
end
381408

382409
nameddimsindices(I::NamedDimsCartesianIndices) = name.(I.indices)
383-
function Base.parent(I::NamedDimsCartesianIndices)
410+
function dename(I::NamedDimsCartesianIndices)
384411
return CartesianIndices(dename.(I.indices))
385412
end
386413

src/nameddimsarray.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ function NamedDimsArray(a::AbstractNamedDimsArray)
3131
end
3232

3333
# Minimal interface.
34-
nameddimsindices(a::NamedDimsArray) = a.nameddimsindices
35-
Base.parent(a::NamedDimsArray) = a.parent
34+
nameddimsindices(a::NamedDimsArray) = getfield(a, :nameddimsindices)
35+
Base.parent(a::NamedDimsArray) = getfield(a, :parent)
36+
dename(a::NamedDimsArray) = parent(a)
3637

3738
function TypeParameterAccessors.position(
3839
::Type{<:AbstractNamedDimsArray}, ::typeof(parenttype)

test/test_basics.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,30 @@ using Test: @test, @test_throws, @testset
7171
@test_throws ErrorException NamedDimsArray(randn(4), namedoneto.((2, 2), ("i", "j")))
7272
@test_throws ErrorException NamedDimsArray(randn(2, 2), namedoneto.((2, 3), ("i", "j")))
7373

74+
a = randn(elt, 3, 4)
75+
na = nameddimsarray(a, ("i", "j"))
76+
a′ = Array(na)
77+
@test eltype(a′) === elt
78+
@test a′ isa Matrix{elt}
79+
@test a′ == a
80+
81+
a = randn(elt, 3, 4)
82+
na = nameddimsarray(a, ("i", "j"))
83+
for a′ in (Array{Float32}(na), Matrix{Float32}(na))
84+
@test eltype(a′) === Float32
85+
@test a′ isa Matrix{Float32}
86+
@test a′ == Float32.(a)
87+
end
88+
89+
a = randn(elt, 2, 2, 2)
90+
na = nameddimsarray(a, ("i", "j", "k"))
91+
b = randn(elt, 2, 2, 2)
92+
nb = nameddimsarray(b, ("k", "i", "j"))
93+
copyto!(na, nb)
94+
@test na == nb
95+
@test dename(na) == dename(nb, ("i", "j", "k"))
96+
@test dename(na) == permutedims(dename(nb), (2, 3, 1))
97+
7498
a = randn(elt, 3, 4)
7599
na = nameddimsarray(a, ("i", "j"))
76100
i = namedoneto(3, "i")

0 commit comments

Comments
 (0)