Skip to content

Commit 2150d93

Browse files
committed
Rename RecursivePermutedDimsArrays to NestedPermutedDimsArrays, make singly nested
1 parent 446a7f0 commit 2150d93

File tree

5 files changed

+52
-68
lines changed

5 files changed

+52
-68
lines changed

NDTensors/src/imports.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ for lib in [
4040
:SparseArrayInterface,
4141
:SparseArrayDOKs,
4242
:DiagonalArrays,
43-
:RecursivePermutedDimsArrays,
43+
:NestedPermutedDimsArrays,
4444
:BlockSparseArrays,
4545
:NamedDimsArrays,
4646
:SmallVectors,

NDTensors/src/lib/RecursivePermutedDimsArrays/src/RecursivePermutedDimsArrays.jl renamed to NDTensors/src/lib/NestedPermutedDimsArrays/src/NestedPermutedDimsArrays.jl

Lines changed: 46 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
# Mostly copied from https://github.com/JuliaLang/julia/blob/master/base/permuteddimsarray.jl
2-
# Like `PermutedDimsArrays` but recursive, similar to `Adjoint` and `Transpose`.
3-
module RecursivePermutedDimsArrays
2+
# Like `PermutedDimsArrays` but singly nested, similar to `Adjoint` and `Transpose`
3+
# (though those are fully recursive).
4+
module NestedPermutedDimsArrays
45

56
import Base: permutedims, permutedims!
6-
export RecursivePermutedDimsArray
7+
export NestedPermutedDimsArray
78

89
# Some day we will want storage-order-aware iteration, so put perm in the parameters
9-
struct RecursivePermutedDimsArray{T,N,perm,iperm,AA<:AbstractArray} <: AbstractArray{T,N}
10+
struct NestedPermutedDimsArray{T,N,perm,iperm,AA<:AbstractArray} <: AbstractArray{T,N}
1011
parent::AA
1112

12-
function RecursivePermutedDimsArray{T,N,perm,iperm,AA}(
13+
function NestedPermutedDimsArray{T,N,perm,iperm,AA}(
1314
data::AA
1415
) where {T,N,perm,iperm,AA<:AbstractArray}
1516
(isa(perm, NTuple{N,Int}) && isa(iperm, NTuple{N,Int})) ||
@@ -23,7 +24,7 @@ struct RecursivePermutedDimsArray{T,N,perm,iperm,AA<:AbstractArray} <: AbstractA
2324
end
2425

2526
"""
26-
RecursivePermutedDimsArray(A, perm) -> B
27+
NestedPermutedDimsArray(A, perm) -> B
2728
2829
Given an AbstractArray `A`, create a view `B` such that the
2930
dimensions appear to be permuted. Similar to `permutedims`, except
@@ -35,7 +36,7 @@ See also [`permutedims`](@ref), [`invperm`](@ref).
3536
```jldoctest
3637
julia> A = rand(3,5,4);
3738
38-
julia> B = RecursivePermutedDimsArray(A, (3,1,2));
39+
julia> B = NestedPermutedDimsArray(A, (3,1,2));
3940
4041
julia> size(B)
4142
(4, 3, 5)
@@ -44,50 +45,44 @@ julia> B[3,1,2] == A[1,2,3]
4445
true
4546
```
4647
"""
47-
Base.@constprop :aggressive function RecursivePermutedDimsArray(
48+
Base.@constprop :aggressive function NestedPermutedDimsArray(
4849
data::AbstractArray{T,N}, perm
4950
) where {T,N}
5051
length(perm) == N ||
5152
throw(ArgumentError(string(perm, " is not a valid permutation of dimensions 1:", N)))
5253
iperm = invperm(perm)
53-
return RecursivePermutedDimsArray{
54-
recursivepermuteddimsarraytype(T, perm),N,(perm...,),(iperm...,),typeof(data)
54+
return NestedPermutedDimsArray{
55+
maybe_permuteddimsarraytype(T, perm),N,(perm...,),(iperm...,),typeof(data)
5556
}(
5657
data
5758
)
5859
end
5960

60-
# Ideally we would use `Base.promote_op(recursivepermuteddimsarray, type, perm)` but
61-
# that doesn't seem to preserve the `perm`/`iperm` type parameters.
62-
function recursivepermuteddimsarraytype(type::Type{<:AbstractArray{<:AbstractArray}}, perm)
63-
return RecursivePermutedDimsArray{
64-
recursivepermuteddimsarraytype(eltype(type), perm),ndims(type),perm,invperm(perm),type
65-
}
66-
end
67-
function recursivepermuteddimsarraytype(type::Type{<:AbstractArray}, perm)
61+
# Ideally would use `Base.promote_op(maybe_permuteddimsarraytype, type, perm)`
62+
# but it doesn't handle `perm` properly.
63+
function maybe_permuteddimsarraytype(type::Type{<:AbstractArray}, perm)
6864
return PermutedDimsArray{eltype(type),ndims(type),perm,invperm(perm),type}
6965
end
70-
recursivepermuteddimsarraytype(type::Type, perm) = type
66+
maybe_permuteddimsarraytype(type::Type, perm) = type
7167

72-
function recursivepermuteddimsarray(A::AbstractArray{<:AbstractArray}, perm)
73-
return RecursivePermutedDimsArray(A, perm)
68+
function maybe_permuteddimsarray(A::AbstractArray, perm)
69+
return PermutedDimsArray(A, perm)
7470
end
75-
recursivepermuteddimsarray(A::AbstractArray, perm) = PermutedDimsArray(A, perm)
7671
# By default, assume scalar and don't permute.
77-
recursivepermuteddimsarray(x, perm) = x
72+
maybe_permuteddimsarray(x, perm) = x
7873

79-
Base.parent(A::RecursivePermutedDimsArray) = A.parent
80-
function Base.size(A::RecursivePermutedDimsArray{T,N,perm}) where {T,N,perm}
74+
Base.parent(A::NestedPermutedDimsArray) = A.parent
75+
function Base.size(A::NestedPermutedDimsArray{T,N,perm}) where {T,N,perm}
8176
return genperm(size(parent(A)), perm)
8277
end
83-
function Base.axes(A::RecursivePermutedDimsArray{T,N,perm}) where {T,N,perm}
78+
function Base.axes(A::NestedPermutedDimsArray{T,N,perm}) where {T,N,perm}
8479
return genperm(axes(parent(A)), perm)
8580
end
86-
Base.has_offset_axes(A::RecursivePermutedDimsArray) = Base.has_offset_axes(A.parent)
87-
function Base.similar(A::RecursivePermutedDimsArray, T::Type, dims::Base.Dims)
81+
Base.has_offset_axes(A::NestedPermutedDimsArray) = Base.has_offset_axes(A.parent)
82+
function Base.similar(A::NestedPermutedDimsArray, T::Type, dims::Base.Dims)
8883
return similar(parent(A), T, dims)
8984
end
90-
function Base.cconvert(::Type{Ptr{T}}, A::RecursivePermutedDimsArray{T}) where {T}
85+
function Base.cconvert(::Type{Ptr{T}}, A::NestedPermutedDimsArray{T}) where {T}
9186
return Base.cconvert(Ptr{T}, parent(A))
9287
end
9388

@@ -96,41 +91,37 @@ end
9691
# order than used by Julia. But for an array with unconventional
9792
# storage order, a linear offset is ambiguous---is it a memory offset
9893
# or a linear index?
99-
function Base.pointer(A::RecursivePermutedDimsArray, i::Integer)
94+
function Base.pointer(A::NestedPermutedDimsArray, i::Integer)
10095
throw(
101-
ArgumentError(
102-
"pointer(A, i) is deliberately unsupported for RecursivePermutedDimsArray"
103-
),
96+
ArgumentError("pointer(A, i) is deliberately unsupported for NestedPermutedDimsArray")
10497
)
10598
end
10699

107-
function Base.strides(A::RecursivePermutedDimsArray{T,N,perm}) where {T,N,perm}
100+
function Base.strides(A::NestedPermutedDimsArray{T,N,perm}) where {T,N,perm}
108101
s = strides(parent(A))
109102
return ntuple(d -> s[perm[d]], Val(N))
110103
end
111-
function Base.elsize(
112-
::Type{<:RecursivePermutedDimsArray{<:Any,<:Any,<:Any,<:Any,P}}
113-
) where {P}
104+
function Base.elsize(::Type{<:NestedPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,P}}) where {P}
114105
return Base.elsize(P)
115106
end
116107

117108
@inline function Base.getindex(
118-
A::RecursivePermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N}
109+
A::NestedPermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N}
119110
) where {T,N,perm,iperm}
120111
@boundscheck checkbounds(A, I...)
121-
@inbounds val = recursivepermuteddimsarray(getindex(A.parent, genperm(I, iperm)...), perm)
112+
@inbounds val = maybe_permuteddimsarray(getindex(A.parent, genperm(I, iperm)...), perm)
122113
return val
123114
end
124115
@inline function Base.setindex!(
125-
A::RecursivePermutedDimsArray{T,N,perm,iperm}, val, I::Vararg{Int,N}
116+
A::NestedPermutedDimsArray{T,N,perm,iperm}, val, I::Vararg{Int,N}
126117
) where {T,N,perm,iperm}
127118
@boundscheck checkbounds(A, I...)
128-
@inbounds setindex!(A.parent, recursivepermuteddimsarray(val, perm), genperm(I, iperm)...)
119+
@inbounds setindex!(A.parent, maybe_permuteddimsarray(val, perm), genperm(I, iperm)...)
129120
return val
130121
end
131122

132123
function Base.isassigned(
133-
A::RecursivePermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N}
124+
A::NestedPermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N}
134125
) where {T,N,perm,iperm}
135126
@boundscheck checkbounds(Bool, A, I...) || return false
136127
@inbounds x = isassigned(A.parent, genperm(I, iperm)...)
@@ -141,14 +132,14 @@ end
141132
@inline genperm(I, perm::AbstractVector{Int}) = genperm(I, (perm...,))
142133

143134
function Base.copyto!(
144-
dest::RecursivePermutedDimsArray{T,N}, src::AbstractArray{T,N}
135+
dest::NestedPermutedDimsArray{T,N}, src::AbstractArray{T,N}
145136
) where {T,N}
146137
checkbounds(dest, axes(src)...)
147138
return _copy!(dest, src)
148139
end
149-
Base.copyto!(dest::RecursivePermutedDimsArray, src::AbstractArray) = _copy!(dest, src)
140+
Base.copyto!(dest::NestedPermutedDimsArray, src::AbstractArray) = _copy!(dest, src)
150141

151-
function _copy!(P::RecursivePermutedDimsArray{T,N,perm}, src) where {T,N,perm}
142+
function _copy!(P::NestedPermutedDimsArray{T,N,perm}, src) where {T,N,perm}
152143
# If dest/src are "close to dense," then it pays to be cache-friendly.
153144
# Determine the first permuted dimension
154145
d = 0 # d+1 will hold the first permuted dimension of src
@@ -168,7 +159,7 @@ function _copy!(P::RecursivePermutedDimsArray{T,N,perm}, src) where {T,N,perm}
168159
end
169160

170161
@noinline function _permutedims!(
171-
P::RecursivePermutedDimsArray, src, R1::CartesianIndices{0}, R2, R3, ds, dp
162+
P::NestedPermutedDimsArray, src, R1::CartesianIndices{0}, R2, R3, ds, dp
172163
)
173164
ip, is = axes(src, dp), axes(src, ds)
174165
for jo in first(ip):8:last(ip), io in first(is):8:last(is)
@@ -183,7 +174,7 @@ end
183174
return P
184175
end
185176

186-
@noinline function _permutedims!(P::RecursivePermutedDimsArray, src, R1, R2, R3, ds, dp)
177+
@noinline function _permutedims!(P::NestedPermutedDimsArray, src, R1, R2, R3, ds, dp)
187178
ip, is = axes(src, dp), axes(src, ds)
188179
for jo in first(ip):8:last(ip), io in first(is):8:last(is)
189180
for I3 in R3, I2 in R2
@@ -210,49 +201,42 @@ const CommutativeOps = Union{
210201
}
211202

212203
function Base._mapreduce_dim(
213-
f,
214-
op::CommutativeOps,
215-
init::Base._InitialValue,
216-
A::RecursivePermutedDimsArray,
217-
dims::Colon,
204+
f, op::CommutativeOps, init::Base._InitialValue, A::NestedPermutedDimsArray, dims::Colon
218205
)
219206
return Base._mapreduce_dim(f, op, init, parent(A), dims)
220207
end
221208
function Base._mapreduce_dim(
222209
f::typeof(identity),
223210
op::Union{typeof(Base.mul_prod),typeof(*)},
224211
init::Base._InitialValue,
225-
A::RecursivePermutedDimsArray{<:Union{Real,Complex}},
212+
A::NestedPermutedDimsArray{<:Union{Real,Complex}},
226213
dims::Colon,
227214
)
228215
return Base._mapreduce_dim(f, op, init, parent(A), dims)
229216
end
230217

231218
function Base.mapreducedim!(
232-
f,
233-
op::CommutativeOps,
234-
B::AbstractArray{T,N},
235-
A::RecursivePermutedDimsArray{S,N,perm,iperm},
219+
f, op::CommutativeOps, B::AbstractArray{T,N}, A::NestedPermutedDimsArray{S,N,perm,iperm}
236220
) where {T,S,N,perm,iperm}
237-
C = RecursivePermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output
221+
C = NestedPermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output
238222
Base.mapreducedim!(f, op, C, parent(A))
239223
return B
240224
end
241225
function Base.mapreducedim!(
242226
f::typeof(identity),
243227
op::Union{typeof(Base.mul_prod),typeof(*)},
244228
B::AbstractArray{T,N},
245-
A::RecursivePermutedDimsArray{<:Union{Real,Complex},N,perm,iperm},
229+
A::NestedPermutedDimsArray{<:Union{Real,Complex},N,perm,iperm},
246230
) where {T,N,perm,iperm}
247-
C = RecursivePermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output
231+
C = NestedPermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output
248232
Base.mapreducedim!(f, op, C, parent(A))
249233
return B
250234
end
251235

252236
function Base.showarg(
253-
io::IO, A::RecursivePermutedDimsArray{T,N,perm}, toplevel
237+
io::IO, A::NestedPermutedDimsArray{T,N,perm}, toplevel
254238
) where {T,N,perm}
255-
print(io, "RecursivePermutedDimsArray(")
239+
print(io, "NestedPermutedDimsArray(")
256240
Base.showarg(io, parent(A), false)
257241
print(io, ", ", perm, ')')
258242
toplevel && print(io, " with eltype ", eltype(A))

NDTensors/src/lib/RecursivePermutedDimsArrays/test/runtests.jl renamed to NDTensors/src/lib/NestedPermutedDimsArrays/test/runtests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
@eval module $(gensym())
2-
using NDTensors.RecursivePermutedDimsArrays: RecursivePermutedDimsArray
2+
using NDTensors.NestedPermutedDimsArrays: NestedPermutedDimsArray
33
using Test: @test, @testset
4-
@testset "RecursivePermutedDimsArrays" for elt in (
4+
@testset "NestedPermutedDimsArrays" for elt in (
55
Float32, Float64, Complex{Float32}, Complex{Float64}
66
)
77
a = map(_ -> randn(elt, 2, 3, 4), CartesianIndices((2, 3, 4)))
88
perm = (3, 2, 1)
9-
p = RecursivePermutedDimsArray(a, perm)
9+
p = NestedPermutedDimsArray(a, perm)
1010
T = PermutedDimsArray{elt,3,perm,invperm(perm),eltype(a)}
11-
@test typeof(p) === RecursivePermutedDimsArray{T,3,perm,invperm(perm),typeof(a)}
11+
@test typeof(p) === NestedPermutedDimsArray{T,3,perm,invperm(perm),typeof(a)}
1212
@test size(p) == (4, 3, 2)
1313
@test eltype(p) === T
1414
for I in eachindex(p)

NDTensors/test/lib/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using Test: @testset
1515
"LabelledNumbers",
1616
"MetalExtensions",
1717
"NamedDimsArrays",
18-
"RecursivePermutedDimsArrays",
18+
"NestedPermutedDimsArrays",
1919
"SmallVectors",
2020
"SortedSets",
2121
"SparseArrayDOKs",

0 commit comments

Comments
 (0)