Skip to content

Commit 3594216

Browse files
authored
[NDTensors] Introduce NestedPermutedDimsArrays submodule (#1589)
1 parent dbec36b commit 3594216

File tree

5 files changed

+264
-0
lines changed

5 files changed

+264
-0
lines changed

NDTensors/src/imports.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ for lib in [
3737
:GradedAxes,
3838
:SymmetrySectors,
3939
:TensorAlgebra,
40+
:NestedPermutedDimsArrays,
4041
:SparseArrayInterface,
4142
:SparseArrayDOKs,
4243
:DiagonalArrays,
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
# Mostly copied from https://github.com/JuliaLang/julia/blob/master/base/permuteddimsarray.jl
2+
# Like `PermutedDimsArrays` but singly nested, similar to `Adjoint` and `Transpose`
3+
# (though those are fully recursive).
4+
module NestedPermutedDimsArrays
5+
6+
import Base: permutedims, permutedims!
7+
export NestedPermutedDimsArray
8+
9+
# Some day we will want storage-order-aware iteration, so put perm in the parameters
10+
struct NestedPermutedDimsArray{T,N,perm,iperm,AA<:AbstractArray} <: AbstractArray{T,N}
11+
parent::AA
12+
13+
function NestedPermutedDimsArray{T,N,perm,iperm,AA}(
14+
data::AA
15+
) where {T,N,perm,iperm,AA<:AbstractArray}
16+
(isa(perm, NTuple{N,Int}) && isa(iperm, NTuple{N,Int})) ||
17+
error("perm and iperm must both be NTuple{$N,Int}")
18+
isperm(perm) ||
19+
throw(ArgumentError(string(perm, " is not a valid permutation of dimensions 1:", N)))
20+
all(d -> iperm[perm[d]] == d, 1:N) ||
21+
throw(ArgumentError(string(perm, " and ", iperm, " must be inverses")))
22+
return new(data)
23+
end
24+
end
25+
26+
"""
27+
NestedPermutedDimsArray(A, perm) -> B
28+
29+
Given an AbstractArray `A`, create a view `B` such that the
30+
dimensions appear to be permuted. Similar to `permutedims`, except
31+
that no copying occurs (`B` shares storage with `A`).
32+
33+
See also [`permutedims`](@ref), [`invperm`](@ref).
34+
35+
# Examples
36+
```jldoctest
37+
julia> A = rand(3,5,4);
38+
39+
julia> B = NestedPermutedDimsArray(A, (3,1,2));
40+
41+
julia> size(B)
42+
(4, 3, 5)
43+
44+
julia> B[3,1,2] == A[1,2,3]
45+
true
46+
```
47+
"""
48+
Base.@constprop :aggressive function NestedPermutedDimsArray(
49+
data::AbstractArray{T,N}, perm
50+
) where {T,N}
51+
length(perm) == N ||
52+
throw(ArgumentError(string(perm, " is not a valid permutation of dimensions 1:", N)))
53+
iperm = invperm(perm)
54+
return NestedPermutedDimsArray{
55+
PermutedDimsArray{eltype(T),N,(perm...,),(iperm...,),T},
56+
N,
57+
(perm...,),
58+
(iperm...,),
59+
typeof(data),
60+
}(
61+
data
62+
)
63+
end
64+
65+
Base.parent(A::NestedPermutedDimsArray) = A.parent
66+
function Base.size(A::NestedPermutedDimsArray{T,N,perm}) where {T,N,perm}
67+
return genperm(size(parent(A)), perm)
68+
end
69+
function Base.axes(A::NestedPermutedDimsArray{T,N,perm}) where {T,N,perm}
70+
return genperm(axes(parent(A)), perm)
71+
end
72+
Base.has_offset_axes(A::NestedPermutedDimsArray) = Base.has_offset_axes(A.parent)
73+
function Base.similar(A::NestedPermutedDimsArray, T::Type, dims::Base.Dims)
74+
return similar(parent(A), T, dims)
75+
end
76+
function Base.cconvert(::Type{Ptr{T}}, A::NestedPermutedDimsArray{T}) where {T}
77+
return Base.cconvert(Ptr{T}, parent(A))
78+
end
79+
80+
# It's OK to return a pointer to the first element, and indeed quite
81+
# useful for wrapping C routines that require a different storage
82+
# order than used by Julia. But for an array with unconventional
83+
# storage order, a linear offset is ambiguous---is it a memory offset
84+
# or a linear index?
85+
function Base.pointer(A::NestedPermutedDimsArray, i::Integer)
86+
throw(
87+
ArgumentError("pointer(A, i) is deliberately unsupported for NestedPermutedDimsArray")
88+
)
89+
end
90+
91+
function Base.strides(A::NestedPermutedDimsArray{T,N,perm}) where {T,N,perm}
92+
s = strides(parent(A))
93+
return ntuple(d -> s[perm[d]], Val(N))
94+
end
95+
function Base.elsize(::Type{<:NestedPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,P}}) where {P}
96+
return Base.elsize(P)
97+
end
98+
99+
@inline function Base.getindex(
100+
A::NestedPermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N}
101+
) where {T,N,perm,iperm}
102+
@boundscheck checkbounds(A, I...)
103+
@inbounds val = PermutedDimsArray(getindex(A.parent, genperm(I, iperm)...), perm)
104+
return val
105+
end
106+
@inline function Base.setindex!(
107+
A::NestedPermutedDimsArray{T,N,perm,iperm}, val, I::Vararg{Int,N}
108+
) where {T,N,perm,iperm}
109+
@boundscheck checkbounds(A, I...)
110+
@inbounds setindex!(A.parent, PermutedDimsArray(val, perm), genperm(I, iperm)...)
111+
return val
112+
end
113+
114+
function Base.isassigned(
115+
A::NestedPermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N}
116+
) where {T,N,perm,iperm}
117+
@boundscheck checkbounds(Bool, A, I...) || return false
118+
@inbounds x = isassigned(A.parent, genperm(I, iperm)...)
119+
return x
120+
end
121+
122+
@inline genperm(I::NTuple{N,Any}, perm::Dims{N}) where {N} = ntuple(d -> I[perm[d]], Val(N))
123+
@inline genperm(I, perm::AbstractVector{Int}) = genperm(I, (perm...,))
124+
125+
function Base.copyto!(
126+
dest::NestedPermutedDimsArray{T,N}, src::AbstractArray{T,N}
127+
) where {T,N}
128+
checkbounds(dest, axes(src)...)
129+
return _copy!(dest, src)
130+
end
131+
Base.copyto!(dest::NestedPermutedDimsArray, src::AbstractArray) = _copy!(dest, src)
132+
133+
function _copy!(P::NestedPermutedDimsArray{T,N,perm}, src) where {T,N,perm}
134+
# If dest/src are "close to dense," then it pays to be cache-friendly.
135+
# Determine the first permuted dimension
136+
d = 0 # d+1 will hold the first permuted dimension of src
137+
while d < ndims(src) && perm[d + 1] == d + 1
138+
d += 1
139+
end
140+
if d == ndims(src)
141+
copyto!(parent(P), src) # it's not permuted
142+
else
143+
R1 = CartesianIndices(axes(src)[1:d])
144+
d1 = findfirst(isequal(d + 1), perm)::Int # first permuted dim of dest
145+
R2 = CartesianIndices(axes(src)[(d + 2):(d1 - 1)])
146+
R3 = CartesianIndices(axes(src)[(d1 + 1):end])
147+
_permutedims!(P, src, R1, R2, R3, d + 1, d1)
148+
end
149+
return P
150+
end
151+
152+
@noinline function _permutedims!(
153+
P::NestedPermutedDimsArray, src, R1::CartesianIndices{0}, R2, R3, ds, dp
154+
)
155+
ip, is = axes(src, dp), axes(src, ds)
156+
for jo in first(ip):8:last(ip), io in first(is):8:last(is)
157+
for I3 in R3, I2 in R2
158+
for j in jo:min(jo + 7, last(ip))
159+
for i in io:min(io + 7, last(is))
160+
@inbounds P[i, I2, j, I3] = src[i, I2, j, I3]
161+
end
162+
end
163+
end
164+
end
165+
return P
166+
end
167+
168+
@noinline function _permutedims!(P::NestedPermutedDimsArray, src, R1, R2, R3, ds, dp)
169+
ip, is = axes(src, dp), axes(src, ds)
170+
for jo in first(ip):8:last(ip), io in first(is):8:last(is)
171+
for I3 in R3, I2 in R2
172+
for j in jo:min(jo + 7, last(ip))
173+
for i in io:min(io + 7, last(is))
174+
for I1 in R1
175+
@inbounds P[I1, i, I2, j, I3] = src[I1, i, I2, j, I3]
176+
end
177+
end
178+
end
179+
end
180+
end
181+
return P
182+
end
183+
184+
const CommutativeOps = Union{
185+
typeof(+),
186+
typeof(Base.add_sum),
187+
typeof(min),
188+
typeof(max),
189+
typeof(Base._extrema_rf),
190+
typeof(|),
191+
typeof(&),
192+
}
193+
194+
function Base._mapreduce_dim(
195+
f, op::CommutativeOps, init::Base._InitialValue, A::NestedPermutedDimsArray, dims::Colon
196+
)
197+
return Base._mapreduce_dim(f, op, init, parent(A), dims)
198+
end
199+
function Base._mapreduce_dim(
200+
f::typeof(identity),
201+
op::Union{typeof(Base.mul_prod),typeof(*)},
202+
init::Base._InitialValue,
203+
A::NestedPermutedDimsArray{<:Union{Real,Complex}},
204+
dims::Colon,
205+
)
206+
return Base._mapreduce_dim(f, op, init, parent(A), dims)
207+
end
208+
209+
function Base.mapreducedim!(
210+
f, op::CommutativeOps, B::AbstractArray{T,N}, A::NestedPermutedDimsArray{S,N,perm,iperm}
211+
) where {T,S,N,perm,iperm}
212+
C = NestedPermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output
213+
Base.mapreducedim!(f, op, C, parent(A))
214+
return B
215+
end
216+
function Base.mapreducedim!(
217+
f::typeof(identity),
218+
op::Union{typeof(Base.mul_prod),typeof(*)},
219+
B::AbstractArray{T,N},
220+
A::NestedPermutedDimsArray{<:Union{Real,Complex},N,perm,iperm},
221+
) where {T,N,perm,iperm}
222+
C = NestedPermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output
223+
Base.mapreducedim!(f, op, C, parent(A))
224+
return B
225+
end
226+
227+
function Base.showarg(
228+
io::IO, A::NestedPermutedDimsArray{T,N,perm}, toplevel
229+
) where {T,N,perm}
230+
print(io, "NestedPermutedDimsArray(")
231+
Base.showarg(io, parent(A), false)
232+
print(io, ", ", perm, ')')
233+
toplevel && print(io, " with eltype ", eltype(A))
234+
return nothing
235+
end
236+
237+
end
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[deps]
2+
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
@eval module $(gensym())
2+
using NDTensors.NestedPermutedDimsArrays: NestedPermutedDimsArray
3+
using Test: @test, @testset
4+
@testset "NestedPermutedDimsArrays" for elt in (
5+
Float32, Float64, Complex{Float32}, Complex{Float64}
6+
)
7+
a = map(_ -> randn(elt, 2, 3, 4), CartesianIndices((2, 3, 4)))
8+
perm = (3, 2, 1)
9+
p = NestedPermutedDimsArray(a, perm)
10+
T = PermutedDimsArray{elt,3,perm,invperm(perm),eltype(a)}
11+
@test typeof(p) === NestedPermutedDimsArray{T,3,perm,invperm(perm),typeof(a)}
12+
@test size(p) == (4, 3, 2)
13+
@test eltype(p) === T
14+
for I in eachindex(p)
15+
@test size(p[I]) == (4, 3, 2)
16+
@test p[I] == permutedims(a[CartesianIndex(reverse(Tuple(I)))], perm)
17+
end
18+
x = randn(elt, 4, 3, 2)
19+
p[3, 2, 1] = x
20+
@test p[3, 2, 1] == x
21+
@test a[1, 2, 3] == permutedims(x, perm)
22+
end
23+
end

NDTensors/test/lib/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using Test: @testset
1515
"LabelledNumbers",
1616
"MetalExtensions",
1717
"NamedDimsArrays",
18+
"NestedPermutedDimsArrays",
1819
"SmallVectors",
1920
"SortedSets",
2021
"SparseArrayDOKs",

0 commit comments

Comments
 (0)