11# Mostly copied from https://github.com/JuliaLang/julia/blob/master/base/permuteddimsarray.jl
2- # Like `PermutedDimsArrays` but recursive, similar to `Adjoint` and `Transpose`.
3- module RecursivePermutedDimsArrays
2+ # Like `PermutedDimsArrays` but singly nested, similar to `Adjoint` and `Transpose`
3+ # (though those are fully recursive).
4+ module NestedPermutedDimsArrays
45
56import Base: permutedims, permutedims!
6- export RecursivePermutedDimsArray
7+ export NestedPermutedDimsArray
78
89# Some day we will want storage-order-aware iteration, so put perm in the parameters
9- struct RecursivePermutedDimsArray {T,N,perm,iperm,AA<: AbstractArray } <: AbstractArray{T,N}
10+ struct NestedPermutedDimsArray {T,N,perm,iperm,AA<: AbstractArray } <: AbstractArray{T,N}
1011 parent:: AA
1112
12- function RecursivePermutedDimsArray {T,N,perm,iperm,AA} (
13+ function NestedPermutedDimsArray {T,N,perm,iperm,AA} (
1314 data:: AA
1415 ) where {T,N,perm,iperm,AA<: AbstractArray }
1516 (isa (perm, NTuple{N,Int}) && isa (iperm, NTuple{N,Int})) ||
@@ -23,7 +24,7 @@ struct RecursivePermutedDimsArray{T,N,perm,iperm,AA<:AbstractArray} <: AbstractA
2324end
2425
2526"""
26- RecursivePermutedDimsArray (A, perm) -> B
27+ NestedPermutedDimsArray (A, perm) -> B
2728
2829Given an AbstractArray `A`, create a view `B` such that the
2930dimensions appear to be permuted. Similar to `permutedims`, except
@@ -35,7 +36,7 @@ See also [`permutedims`](@ref), [`invperm`](@ref).
3536```jldoctest
3637julia> A = rand(3,5,4);
3738
38- julia> B = RecursivePermutedDimsArray (A, (3,1,2));
39+ julia> B = NestedPermutedDimsArray (A, (3,1,2));
3940
4041julia> size(B)
4142(4, 3, 5)
@@ -44,50 +45,44 @@ julia> B[3,1,2] == A[1,2,3]
4445true
4546```
4647"""
47- Base. @constprop :aggressive function RecursivePermutedDimsArray (
48+ Base. @constprop :aggressive function NestedPermutedDimsArray (
4849 data:: AbstractArray{T,N} , perm
4950) where {T,N}
5051 length (perm) == N ||
5152 throw (ArgumentError (string (perm, " is not a valid permutation of dimensions 1:" , N)))
5253 iperm = invperm (perm)
53- return RecursivePermutedDimsArray {
54- recursivepermuteddimsarraytype (T, perm),N,(perm... ,),(iperm... ,),typeof (data)
54+ return NestedPermutedDimsArray {
55+ maybe_permuteddimsarraytype (T, perm),N,(perm... ,),(iperm... ,),typeof (data)
5556 }(
5657 data
5758 )
5859end
5960
60- # Ideally we would use `Base.promote_op(recursivepermuteddimsarray, type, perm)` but
61- # that doesn't seem to preserve the `perm`/`iperm` type parameters.
62- function recursivepermuteddimsarraytype (type:: Type{<:AbstractArray{<:AbstractArray}} , perm)
63- return RecursivePermutedDimsArray{
64- recursivepermuteddimsarraytype (eltype (type), perm),ndims (type),perm,invperm (perm),type
65- }
66- end
67- function recursivepermuteddimsarraytype (type:: Type{<:AbstractArray} , perm)
61+ # Ideally would use `Base.promote_op(maybe_permuteddimsarraytype, type, perm)`
62+ # but it doesn't handle `perm` properly.
63+ function maybe_permuteddimsarraytype (type:: Type{<:AbstractArray} , perm)
6864 return PermutedDimsArray{eltype (type),ndims (type),perm,invperm (perm),type}
6965end
70- recursivepermuteddimsarraytype (type:: Type , perm) = type
66+ maybe_permuteddimsarraytype (type:: Type , perm) = type
7167
72- function recursivepermuteddimsarray (A:: AbstractArray{<:AbstractArray} , perm)
73- return RecursivePermutedDimsArray (A, perm)
68+ function maybe_permuteddimsarray (A:: AbstractArray , perm)
69+ return PermutedDimsArray (A, perm)
7470end
75- recursivepermuteddimsarray (A:: AbstractArray , perm) = PermutedDimsArray (A, perm)
7671# By default, assume scalar and don't permute.
77- recursivepermuteddimsarray (x, perm) = x
72+ maybe_permuteddimsarray (x, perm) = x
7873
79- Base. parent (A:: RecursivePermutedDimsArray ) = A. parent
80- function Base. size (A:: RecursivePermutedDimsArray {T,N,perm} ) where {T,N,perm}
74+ Base. parent (A:: NestedPermutedDimsArray ) = A. parent
75+ function Base. size (A:: NestedPermutedDimsArray {T,N,perm} ) where {T,N,perm}
8176 return genperm (size (parent (A)), perm)
8277end
83- function Base. axes (A:: RecursivePermutedDimsArray {T,N,perm} ) where {T,N,perm}
78+ function Base. axes (A:: NestedPermutedDimsArray {T,N,perm} ) where {T,N,perm}
8479 return genperm (axes (parent (A)), perm)
8580end
86- Base. has_offset_axes (A:: RecursivePermutedDimsArray ) = Base. has_offset_axes (A. parent)
87- function Base. similar (A:: RecursivePermutedDimsArray , T:: Type , dims:: Base.Dims )
81+ Base. has_offset_axes (A:: NestedPermutedDimsArray ) = Base. has_offset_axes (A. parent)
82+ function Base. similar (A:: NestedPermutedDimsArray , T:: Type , dims:: Base.Dims )
8883 return similar (parent (A), T, dims)
8984end
90- function Base. cconvert (:: Type{Ptr{T}} , A:: RecursivePermutedDimsArray {T} ) where {T}
85+ function Base. cconvert (:: Type{Ptr{T}} , A:: NestedPermutedDimsArray {T} ) where {T}
9186 return Base. cconvert (Ptr{T}, parent (A))
9287end
9388
9691# order than used by Julia. But for an array with unconventional
9792# storage order, a linear offset is ambiguous---is it a memory offset
9893# or a linear index?
99- function Base. pointer (A:: RecursivePermutedDimsArray , i:: Integer )
94+ function Base. pointer (A:: NestedPermutedDimsArray , i:: Integer )
10095 throw (
101- ArgumentError (
102- " pointer(A, i) is deliberately unsupported for RecursivePermutedDimsArray"
103- ),
96+ ArgumentError (" pointer(A, i) is deliberately unsupported for NestedPermutedDimsArray" )
10497 )
10598end
10699
107- function Base. strides (A:: RecursivePermutedDimsArray {T,N,perm} ) where {T,N,perm}
100+ function Base. strides (A:: NestedPermutedDimsArray {T,N,perm} ) where {T,N,perm}
108101 s = strides (parent (A))
109102 return ntuple (d -> s[perm[d]], Val (N))
110103end
111- function Base. elsize (
112- :: Type{<:RecursivePermutedDimsArray{<:Any,<:Any,<:Any,<:Any,P}}
113- ) where {P}
104+ function Base. elsize (:: Type{<:NestedPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,P}} ) where {P}
114105 return Base. elsize (P)
115106end
116107
117108@inline function Base. getindex (
118- A:: RecursivePermutedDimsArray {T,N,perm,iperm} , I:: Vararg{Int,N}
109+ A:: NestedPermutedDimsArray {T,N,perm,iperm} , I:: Vararg{Int,N}
119110) where {T,N,perm,iperm}
120111 @boundscheck checkbounds (A, I... )
121- @inbounds val = recursivepermuteddimsarray (getindex (A. parent, genperm (I, iperm)... ), perm)
112+ @inbounds val = maybe_permuteddimsarray (getindex (A. parent, genperm (I, iperm)... ), perm)
122113 return val
123114end
124115@inline function Base. setindex! (
125- A:: RecursivePermutedDimsArray {T,N,perm,iperm} , val, I:: Vararg{Int,N}
116+ A:: NestedPermutedDimsArray {T,N,perm,iperm} , val, I:: Vararg{Int,N}
126117) where {T,N,perm,iperm}
127118 @boundscheck checkbounds (A, I... )
128- @inbounds setindex! (A. parent, recursivepermuteddimsarray (val, perm), genperm (I, iperm)... )
119+ @inbounds setindex! (A. parent, maybe_permuteddimsarray (val, perm), genperm (I, iperm)... )
129120 return val
130121end
131122
132123function Base. isassigned (
133- A:: RecursivePermutedDimsArray {T,N,perm,iperm} , I:: Vararg{Int,N}
124+ A:: NestedPermutedDimsArray {T,N,perm,iperm} , I:: Vararg{Int,N}
134125) where {T,N,perm,iperm}
135126 @boundscheck checkbounds (Bool, A, I... ) || return false
136127 @inbounds x = isassigned (A. parent, genperm (I, iperm)... )
@@ -141,14 +132,14 @@ end
141132@inline genperm (I, perm:: AbstractVector{Int} ) = genperm (I, (perm... ,))
142133
143134function Base. copyto! (
144- dest:: RecursivePermutedDimsArray {T,N} , src:: AbstractArray{T,N}
135+ dest:: NestedPermutedDimsArray {T,N} , src:: AbstractArray{T,N}
145136) where {T,N}
146137 checkbounds (dest, axes (src)... )
147138 return _copy! (dest, src)
148139end
149- Base. copyto! (dest:: RecursivePermutedDimsArray , src:: AbstractArray ) = _copy! (dest, src)
140+ Base. copyto! (dest:: NestedPermutedDimsArray , src:: AbstractArray ) = _copy! (dest, src)
150141
151- function _copy! (P:: RecursivePermutedDimsArray {T,N,perm} , src) where {T,N,perm}
142+ function _copy! (P:: NestedPermutedDimsArray {T,N,perm} , src) where {T,N,perm}
152143 # If dest/src are "close to dense," then it pays to be cache-friendly.
153144 # Determine the first permuted dimension
154145 d = 0 # d+1 will hold the first permuted dimension of src
@@ -168,7 +159,7 @@ function _copy!(P::RecursivePermutedDimsArray{T,N,perm}, src) where {T,N,perm}
168159end
169160
170161@noinline function _permutedims! (
171- P:: RecursivePermutedDimsArray , src, R1:: CartesianIndices{0} , R2, R3, ds, dp
162+ P:: NestedPermutedDimsArray , src, R1:: CartesianIndices{0} , R2, R3, ds, dp
172163)
173164 ip, is = axes (src, dp), axes (src, ds)
174165 for jo in first (ip): 8 : last (ip), io in first (is): 8 : last (is)
183174 return P
184175end
185176
186- @noinline function _permutedims! (P:: RecursivePermutedDimsArray , src, R1, R2, R3, ds, dp)
177+ @noinline function _permutedims! (P:: NestedPermutedDimsArray , src, R1, R2, R3, ds, dp)
187178 ip, is = axes (src, dp), axes (src, ds)
188179 for jo in first (ip): 8 : last (ip), io in first (is): 8 : last (is)
189180 for I3 in R3, I2 in R2
@@ -210,49 +201,42 @@ const CommutativeOps = Union{
210201}
211202
212203function Base. _mapreduce_dim (
213- f,
214- op:: CommutativeOps ,
215- init:: Base._InitialValue ,
216- A:: RecursivePermutedDimsArray ,
217- dims:: Colon ,
204+ f, op:: CommutativeOps , init:: Base._InitialValue , A:: NestedPermutedDimsArray , dims:: Colon
218205)
219206 return Base. _mapreduce_dim (f, op, init, parent (A), dims)
220207end
221208function Base. _mapreduce_dim (
222209 f:: typeof (identity),
223210 op:: Union{typeof(Base.mul_prod),typeof(*)} ,
224211 init:: Base._InitialValue ,
225- A:: RecursivePermutedDimsArray {<:Union{Real,Complex}} ,
212+ A:: NestedPermutedDimsArray {<:Union{Real,Complex}} ,
226213 dims:: Colon ,
227214)
228215 return Base. _mapreduce_dim (f, op, init, parent (A), dims)
229216end
230217
231218function Base. mapreducedim! (
232- f,
233- op:: CommutativeOps ,
234- B:: AbstractArray{T,N} ,
235- A:: RecursivePermutedDimsArray{S,N,perm,iperm} ,
219+ f, op:: CommutativeOps , B:: AbstractArray{T,N} , A:: NestedPermutedDimsArray{S,N,perm,iperm}
236220) where {T,S,N,perm,iperm}
237- C = RecursivePermutedDimsArray {T,N,iperm,perm,typeof(B)} (B) # make the inverse permutation for the output
221+ C = NestedPermutedDimsArray {T,N,iperm,perm,typeof(B)} (B) # make the inverse permutation for the output
238222 Base. mapreducedim! (f, op, C, parent (A))
239223 return B
240224end
241225function Base. mapreducedim! (
242226 f:: typeof (identity),
243227 op:: Union{typeof(Base.mul_prod),typeof(*)} ,
244228 B:: AbstractArray{T,N} ,
245- A:: RecursivePermutedDimsArray {<:Union{Real,Complex},N,perm,iperm} ,
229+ A:: NestedPermutedDimsArray {<:Union{Real,Complex},N,perm,iperm} ,
246230) where {T,N,perm,iperm}
247- C = RecursivePermutedDimsArray {T,N,iperm,perm,typeof(B)} (B) # make the inverse permutation for the output
231+ C = NestedPermutedDimsArray {T,N,iperm,perm,typeof(B)} (B) # make the inverse permutation for the output
248232 Base. mapreducedim! (f, op, C, parent (A))
249233 return B
250234end
251235
252236function Base. showarg (
253- io:: IO , A:: RecursivePermutedDimsArray {T,N,perm} , toplevel
237+ io:: IO , A:: NestedPermutedDimsArray {T,N,perm} , toplevel
254238) where {T,N,perm}
255- print (io, " RecursivePermutedDimsArray (" )
239+ print (io, " NestedPermutedDimsArray (" )
256240 Base. showarg (io, parent (A), false )
257241 print (io, " , " , perm, ' )' )
258242 toplevel && print (io, " with eltype " , eltype (A))
0 commit comments