Skip to content

Commit 446a7f0

Browse files
committed
[NDTensors] Introduce RecursivePermutedDimsArrays submodule
1 parent 7faad33 commit 446a7f0

File tree

6 files changed

+290
-1
lines changed

6 files changed

+290
-1
lines changed

NDTensors/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NDTensors"
22
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
33
authors = ["Matthew Fishman <[email protected]>"]
4-
version = "0.3.67"
4+
version = "0.3.68"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

NDTensors/src/imports.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ for lib in [
4040
:SparseArrayInterface,
4141
:SparseArrayDOKs,
4242
:DiagonalArrays,
43+
:RecursivePermutedDimsArrays,
4344
:BlockSparseArrays,
4445
:NamedDimsArrays,
4546
:SmallVectors,
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
# Mostly copied from https://github.com/JuliaLang/julia/blob/master/base/permuteddimsarray.jl
2+
# Like `PermutedDimsArrays` but recursive, similar to `Adjoint` and `Transpose`.
3+
module RecursivePermutedDimsArrays
4+
5+
import Base: permutedims, permutedims!
6+
export RecursivePermutedDimsArray
7+
8+
# Some day we will want storage-order-aware iteration, so put perm in the parameters
9+
struct RecursivePermutedDimsArray{T,N,perm,iperm,AA<:AbstractArray} <: AbstractArray{T,N}
10+
parent::AA
11+
12+
function RecursivePermutedDimsArray{T,N,perm,iperm,AA}(
13+
data::AA
14+
) where {T,N,perm,iperm,AA<:AbstractArray}
15+
(isa(perm, NTuple{N,Int}) && isa(iperm, NTuple{N,Int})) ||
16+
error("perm and iperm must both be NTuple{$N,Int}")
17+
isperm(perm) ||
18+
throw(ArgumentError(string(perm, " is not a valid permutation of dimensions 1:", N)))
19+
all(d -> iperm[perm[d]] == d, 1:N) ||
20+
throw(ArgumentError(string(perm, " and ", iperm, " must be inverses")))
21+
return new(data)
22+
end
23+
end
24+
25+
"""
26+
RecursivePermutedDimsArray(A, perm) -> B
27+
28+
Given an AbstractArray `A`, create a view `B` such that the
29+
dimensions appear to be permuted. Similar to `permutedims`, except
30+
that no copying occurs (`B` shares storage with `A`).
31+
32+
See also [`permutedims`](@ref), [`invperm`](@ref).
33+
34+
# Examples
35+
```jldoctest
36+
julia> A = rand(3,5,4);
37+
38+
julia> B = RecursivePermutedDimsArray(A, (3,1,2));
39+
40+
julia> size(B)
41+
(4, 3, 5)
42+
43+
julia> B[3,1,2] == A[1,2,3]
44+
true
45+
```
46+
"""
47+
Base.@constprop :aggressive function RecursivePermutedDimsArray(
48+
data::AbstractArray{T,N}, perm
49+
) where {T,N}
50+
length(perm) == N ||
51+
throw(ArgumentError(string(perm, " is not a valid permutation of dimensions 1:", N)))
52+
iperm = invperm(perm)
53+
return RecursivePermutedDimsArray{
54+
recursivepermuteddimsarraytype(T, perm),N,(perm...,),(iperm...,),typeof(data)
55+
}(
56+
data
57+
)
58+
end
59+
60+
# Ideally we would use `Base.promote_op(recursivepermuteddimsarray, type, perm)` but
61+
# that doesn't seem to preserve the `perm`/`iperm` type parameters.
62+
function recursivepermuteddimsarraytype(type::Type{<:AbstractArray{<:AbstractArray}}, perm)
63+
return RecursivePermutedDimsArray{
64+
recursivepermuteddimsarraytype(eltype(type), perm),ndims(type),perm,invperm(perm),type
65+
}
66+
end
67+
function recursivepermuteddimsarraytype(type::Type{<:AbstractArray}, perm)
68+
return PermutedDimsArray{eltype(type),ndims(type),perm,invperm(perm),type}
69+
end
70+
recursivepermuteddimsarraytype(type::Type, perm) = type
71+
72+
function recursivepermuteddimsarray(A::AbstractArray{<:AbstractArray}, perm)
73+
return RecursivePermutedDimsArray(A, perm)
74+
end
75+
recursivepermuteddimsarray(A::AbstractArray, perm) = PermutedDimsArray(A, perm)
76+
# By default, assume scalar and don't permute.
77+
recursivepermuteddimsarray(x, perm) = x
78+
79+
Base.parent(A::RecursivePermutedDimsArray) = A.parent
80+
function Base.size(A::RecursivePermutedDimsArray{T,N,perm}) where {T,N,perm}
81+
return genperm(size(parent(A)), perm)
82+
end
83+
function Base.axes(A::RecursivePermutedDimsArray{T,N,perm}) where {T,N,perm}
84+
return genperm(axes(parent(A)), perm)
85+
end
86+
Base.has_offset_axes(A::RecursivePermutedDimsArray) = Base.has_offset_axes(A.parent)
87+
function Base.similar(A::RecursivePermutedDimsArray, T::Type, dims::Base.Dims)
88+
return similar(parent(A), T, dims)
89+
end
90+
function Base.cconvert(::Type{Ptr{T}}, A::RecursivePermutedDimsArray{T}) where {T}
91+
return Base.cconvert(Ptr{T}, parent(A))
92+
end
93+
94+
# It's OK to return a pointer to the first element, and indeed quite
95+
# useful for wrapping C routines that require a different storage
96+
# order than used by Julia. But for an array with unconventional
97+
# storage order, a linear offset is ambiguous---is it a memory offset
98+
# or a linear index?
99+
function Base.pointer(A::RecursivePermutedDimsArray, i::Integer)
100+
throw(
101+
ArgumentError(
102+
"pointer(A, i) is deliberately unsupported for RecursivePermutedDimsArray"
103+
),
104+
)
105+
end
106+
107+
function Base.strides(A::RecursivePermutedDimsArray{T,N,perm}) where {T,N,perm}
108+
s = strides(parent(A))
109+
return ntuple(d -> s[perm[d]], Val(N))
110+
end
111+
function Base.elsize(
112+
::Type{<:RecursivePermutedDimsArray{<:Any,<:Any,<:Any,<:Any,P}}
113+
) where {P}
114+
return Base.elsize(P)
115+
end
116+
117+
@inline function Base.getindex(
118+
A::RecursivePermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N}
119+
) where {T,N,perm,iperm}
120+
@boundscheck checkbounds(A, I...)
121+
@inbounds val = recursivepermuteddimsarray(getindex(A.parent, genperm(I, iperm)...), perm)
122+
return val
123+
end
124+
@inline function Base.setindex!(
125+
A::RecursivePermutedDimsArray{T,N,perm,iperm}, val, I::Vararg{Int,N}
126+
) where {T,N,perm,iperm}
127+
@boundscheck checkbounds(A, I...)
128+
@inbounds setindex!(A.parent, recursivepermuteddimsarray(val, perm), genperm(I, iperm)...)
129+
return val
130+
end
131+
132+
function Base.isassigned(
133+
A::RecursivePermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N}
134+
) where {T,N,perm,iperm}
135+
@boundscheck checkbounds(Bool, A, I...) || return false
136+
@inbounds x = isassigned(A.parent, genperm(I, iperm)...)
137+
return x
138+
end
139+
140+
@inline genperm(I::NTuple{N,Any}, perm::Dims{N}) where {N} = ntuple(d -> I[perm[d]], Val(N))
141+
@inline genperm(I, perm::AbstractVector{Int}) = genperm(I, (perm...,))
142+
143+
function Base.copyto!(
144+
dest::RecursivePermutedDimsArray{T,N}, src::AbstractArray{T,N}
145+
) where {T,N}
146+
checkbounds(dest, axes(src)...)
147+
return _copy!(dest, src)
148+
end
149+
Base.copyto!(dest::RecursivePermutedDimsArray, src::AbstractArray) = _copy!(dest, src)
150+
151+
function _copy!(P::RecursivePermutedDimsArray{T,N,perm}, src) where {T,N,perm}
152+
# If dest/src are "close to dense," then it pays to be cache-friendly.
153+
# Determine the first permuted dimension
154+
d = 0 # d+1 will hold the first permuted dimension of src
155+
while d < ndims(src) && perm[d + 1] == d + 1
156+
d += 1
157+
end
158+
if d == ndims(src)
159+
copyto!(parent(P), src) # it's not permuted
160+
else
161+
R1 = CartesianIndices(axes(src)[1:d])
162+
d1 = findfirst(isequal(d + 1), perm)::Int # first permuted dim of dest
163+
R2 = CartesianIndices(axes(src)[(d + 2):(d1 - 1)])
164+
R3 = CartesianIndices(axes(src)[(d1 + 1):end])
165+
_permutedims!(P, src, R1, R2, R3, d + 1, d1)
166+
end
167+
return P
168+
end
169+
170+
@noinline function _permutedims!(
171+
P::RecursivePermutedDimsArray, src, R1::CartesianIndices{0}, R2, R3, ds, dp
172+
)
173+
ip, is = axes(src, dp), axes(src, ds)
174+
for jo in first(ip):8:last(ip), io in first(is):8:last(is)
175+
for I3 in R3, I2 in R2
176+
for j in jo:min(jo + 7, last(ip))
177+
for i in io:min(io + 7, last(is))
178+
@inbounds P[i, I2, j, I3] = src[i, I2, j, I3]
179+
end
180+
end
181+
end
182+
end
183+
return P
184+
end
185+
186+
@noinline function _permutedims!(P::RecursivePermutedDimsArray, src, R1, R2, R3, ds, dp)
187+
ip, is = axes(src, dp), axes(src, ds)
188+
for jo in first(ip):8:last(ip), io in first(is):8:last(is)
189+
for I3 in R3, I2 in R2
190+
for j in jo:min(jo + 7, last(ip))
191+
for i in io:min(io + 7, last(is))
192+
for I1 in R1
193+
@inbounds P[I1, i, I2, j, I3] = src[I1, i, I2, j, I3]
194+
end
195+
end
196+
end
197+
end
198+
end
199+
return P
200+
end
201+
202+
const CommutativeOps = Union{
203+
typeof(+),
204+
typeof(Base.add_sum),
205+
typeof(min),
206+
typeof(max),
207+
typeof(Base._extrema_rf),
208+
typeof(|),
209+
typeof(&),
210+
}
211+
212+
function Base._mapreduce_dim(
213+
f,
214+
op::CommutativeOps,
215+
init::Base._InitialValue,
216+
A::RecursivePermutedDimsArray,
217+
dims::Colon,
218+
)
219+
return Base._mapreduce_dim(f, op, init, parent(A), dims)
220+
end
221+
function Base._mapreduce_dim(
222+
f::typeof(identity),
223+
op::Union{typeof(Base.mul_prod),typeof(*)},
224+
init::Base._InitialValue,
225+
A::RecursivePermutedDimsArray{<:Union{Real,Complex}},
226+
dims::Colon,
227+
)
228+
return Base._mapreduce_dim(f, op, init, parent(A), dims)
229+
end
230+
231+
function Base.mapreducedim!(
232+
f,
233+
op::CommutativeOps,
234+
B::AbstractArray{T,N},
235+
A::RecursivePermutedDimsArray{S,N,perm,iperm},
236+
) where {T,S,N,perm,iperm}
237+
C = RecursivePermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output
238+
Base.mapreducedim!(f, op, C, parent(A))
239+
return B
240+
end
241+
function Base.mapreducedim!(
242+
f::typeof(identity),
243+
op::Union{typeof(Base.mul_prod),typeof(*)},
244+
B::AbstractArray{T,N},
245+
A::RecursivePermutedDimsArray{<:Union{Real,Complex},N,perm,iperm},
246+
) where {T,N,perm,iperm}
247+
C = RecursivePermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output
248+
Base.mapreducedim!(f, op, C, parent(A))
249+
return B
250+
end
251+
252+
function Base.showarg(
253+
io::IO, A::RecursivePermutedDimsArray{T,N,perm}, toplevel
254+
) where {T,N,perm}
255+
print(io, "RecursivePermutedDimsArray(")
256+
Base.showarg(io, parent(A), false)
257+
print(io, ", ", perm, ')')
258+
toplevel && print(io, " with eltype ", eltype(A))
259+
return nothing
260+
end
261+
262+
end
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[deps]
2+
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
@eval module $(gensym())
2+
using NDTensors.RecursivePermutedDimsArrays: RecursivePermutedDimsArray
3+
using Test: @test, @testset
4+
@testset "RecursivePermutedDimsArrays" for elt in (
5+
Float32, Float64, Complex{Float32}, Complex{Float64}
6+
)
7+
a = map(_ -> randn(elt, 2, 3, 4), CartesianIndices((2, 3, 4)))
8+
perm = (3, 2, 1)
9+
p = RecursivePermutedDimsArray(a, perm)
10+
T = PermutedDimsArray{elt,3,perm,invperm(perm),eltype(a)}
11+
@test typeof(p) === RecursivePermutedDimsArray{T,3,perm,invperm(perm),typeof(a)}
12+
@test size(p) == (4, 3, 2)
13+
@test eltype(p) === T
14+
for I in eachindex(p)
15+
@test size(p[I]) == (4, 3, 2)
16+
@test p[I] == permutedims(a[CartesianIndex(reverse(Tuple(I)))], perm)
17+
end
18+
x = randn(elt, 4, 3, 2)
19+
p[3, 2, 1] = x
20+
@test p[3, 2, 1] == x
21+
@test a[1, 2, 3] == permutedims(x, perm)
22+
end
23+
end

NDTensors/test/lib/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using Test: @testset
1515
"LabelledNumbers",
1616
"MetalExtensions",
1717
"NamedDimsArrays",
18+
"RecursivePermutedDimsArrays",
1819
"SmallVectors",
1920
"SortedSets",
2021
"SparseArrayDOKs",

0 commit comments

Comments
 (0)