|
| 1 | +# Mostly copied from https://github.com/JuliaLang/julia/blob/master/base/permuteddimsarray.jl |
| 2 | +# Like `PermutedDimsArrays` but singly nested, similar to `Adjoint` and `Transpose` |
| 3 | +# (though those are fully recursive). |
| 4 | +module NestedPermutedDimsArrays |
| 5 | + |
| 6 | +import Base: permutedims, permutedims! |
| 7 | +export NestedPermutedDimsArray |
| 8 | + |
| 9 | +# Some day we will want storage-order-aware iteration, so put perm in the parameters |
| 10 | +struct NestedPermutedDimsArray{T,N,perm,iperm,AA<:AbstractArray} <: AbstractArray{T,N} |
| 11 | + parent::AA |
| 12 | + |
| 13 | + function NestedPermutedDimsArray{T,N,perm,iperm,AA}( |
| 14 | + data::AA |
| 15 | + ) where {T,N,perm,iperm,AA<:AbstractArray} |
| 16 | + (isa(perm, NTuple{N,Int}) && isa(iperm, NTuple{N,Int})) || |
| 17 | + error("perm and iperm must both be NTuple{$N,Int}") |
| 18 | + isperm(perm) || |
| 19 | + throw(ArgumentError(string(perm, " is not a valid permutation of dimensions 1:", N))) |
| 20 | + all(d -> iperm[perm[d]] == d, 1:N) || |
| 21 | + throw(ArgumentError(string(perm, " and ", iperm, " must be inverses"))) |
| 22 | + return new(data) |
| 23 | + end |
| 24 | +end |
| 25 | + |
| 26 | +""" |
| 27 | + NestedPermutedDimsArray(A, perm) -> B |
| 28 | +
|
| 29 | +Given an AbstractArray `A`, create a view `B` such that the |
| 30 | +dimensions appear to be permuted. Similar to `permutedims`, except |
| 31 | +that no copying occurs (`B` shares storage with `A`). |
| 32 | +
|
| 33 | +See also [`permutedims`](@ref), [`invperm`](@ref). |
| 34 | +
|
| 35 | +# Examples |
| 36 | +```jldoctest |
| 37 | +julia> A = rand(3,5,4); |
| 38 | +
|
| 39 | +julia> B = NestedPermutedDimsArray(A, (3,1,2)); |
| 40 | +
|
| 41 | +julia> size(B) |
| 42 | +(4, 3, 5) |
| 43 | +
|
| 44 | +julia> B[3,1,2] == A[1,2,3] |
| 45 | +true |
| 46 | +``` |
| 47 | +""" |
| 48 | +Base.@constprop :aggressive function NestedPermutedDimsArray( |
| 49 | + data::AbstractArray{T,N}, perm |
| 50 | +) where {T,N} |
| 51 | + length(perm) == N || |
| 52 | + throw(ArgumentError(string(perm, " is not a valid permutation of dimensions 1:", N))) |
| 53 | + iperm = invperm(perm) |
| 54 | + return NestedPermutedDimsArray{ |
| 55 | + PermutedDimsArray{eltype(T),N,(perm...,),(iperm...,),T}, |
| 56 | + N, |
| 57 | + (perm...,), |
| 58 | + (iperm...,), |
| 59 | + typeof(data), |
| 60 | + }( |
| 61 | + data |
| 62 | + ) |
| 63 | +end |
| 64 | + |
| 65 | +Base.parent(A::NestedPermutedDimsArray) = A.parent |
| 66 | +function Base.size(A::NestedPermutedDimsArray{T,N,perm}) where {T,N,perm} |
| 67 | + return genperm(size(parent(A)), perm) |
| 68 | +end |
| 69 | +function Base.axes(A::NestedPermutedDimsArray{T,N,perm}) where {T,N,perm} |
| 70 | + return genperm(axes(parent(A)), perm) |
| 71 | +end |
| 72 | +Base.has_offset_axes(A::NestedPermutedDimsArray) = Base.has_offset_axes(A.parent) |
| 73 | +function Base.similar(A::NestedPermutedDimsArray, T::Type, dims::Base.Dims) |
| 74 | + return similar(parent(A), T, dims) |
| 75 | +end |
| 76 | +function Base.cconvert(::Type{Ptr{T}}, A::NestedPermutedDimsArray{T}) where {T} |
| 77 | + return Base.cconvert(Ptr{T}, parent(A)) |
| 78 | +end |
| 79 | + |
| 80 | +# It's OK to return a pointer to the first element, and indeed quite |
| 81 | +# useful for wrapping C routines that require a different storage |
| 82 | +# order than used by Julia. But for an array with unconventional |
| 83 | +# storage order, a linear offset is ambiguous---is it a memory offset |
| 84 | +# or a linear index? |
| 85 | +function Base.pointer(A::NestedPermutedDimsArray, i::Integer) |
| 86 | + throw( |
| 87 | + ArgumentError("pointer(A, i) is deliberately unsupported for NestedPermutedDimsArray") |
| 88 | + ) |
| 89 | +end |
| 90 | + |
| 91 | +function Base.strides(A::NestedPermutedDimsArray{T,N,perm}) where {T,N,perm} |
| 92 | + s = strides(parent(A)) |
| 93 | + return ntuple(d -> s[perm[d]], Val(N)) |
| 94 | +end |
| 95 | +function Base.elsize(::Type{<:NestedPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,P}}) where {P} |
| 96 | + return Base.elsize(P) |
| 97 | +end |
| 98 | + |
| 99 | +@inline function Base.getindex( |
| 100 | + A::NestedPermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N} |
| 101 | +) where {T,N,perm,iperm} |
| 102 | + @boundscheck checkbounds(A, I...) |
| 103 | + @inbounds val = PermutedDimsArray(getindex(A.parent, genperm(I, iperm)...), perm) |
| 104 | + return val |
| 105 | +end |
| 106 | +@inline function Base.setindex!( |
| 107 | + A::NestedPermutedDimsArray{T,N,perm,iperm}, val, I::Vararg{Int,N} |
| 108 | +) where {T,N,perm,iperm} |
| 109 | + @boundscheck checkbounds(A, I...) |
| 110 | + @inbounds setindex!(A.parent, PermutedDimsArray(val, perm), genperm(I, iperm)...) |
| 111 | + return val |
| 112 | +end |
| 113 | + |
| 114 | +function Base.isassigned( |
| 115 | + A::NestedPermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N} |
| 116 | +) where {T,N,perm,iperm} |
| 117 | + @boundscheck checkbounds(Bool, A, I...) || return false |
| 118 | + @inbounds x = isassigned(A.parent, genperm(I, iperm)...) |
| 119 | + return x |
| 120 | +end |
| 121 | + |
| 122 | +@inline genperm(I::NTuple{N,Any}, perm::Dims{N}) where {N} = ntuple(d -> I[perm[d]], Val(N)) |
| 123 | +@inline genperm(I, perm::AbstractVector{Int}) = genperm(I, (perm...,)) |
| 124 | + |
| 125 | +function Base.copyto!( |
| 126 | + dest::NestedPermutedDimsArray{T,N}, src::AbstractArray{T,N} |
| 127 | +) where {T,N} |
| 128 | + checkbounds(dest, axes(src)...) |
| 129 | + return _copy!(dest, src) |
| 130 | +end |
| 131 | +Base.copyto!(dest::NestedPermutedDimsArray, src::AbstractArray) = _copy!(dest, src) |
| 132 | + |
| 133 | +function _copy!(P::NestedPermutedDimsArray{T,N,perm}, src) where {T,N,perm} |
| 134 | + # If dest/src are "close to dense," then it pays to be cache-friendly. |
| 135 | + # Determine the first permuted dimension |
| 136 | + d = 0 # d+1 will hold the first permuted dimension of src |
| 137 | + while d < ndims(src) && perm[d + 1] == d + 1 |
| 138 | + d += 1 |
| 139 | + end |
| 140 | + if d == ndims(src) |
| 141 | + copyto!(parent(P), src) # it's not permuted |
| 142 | + else |
| 143 | + R1 = CartesianIndices(axes(src)[1:d]) |
| 144 | + d1 = findfirst(isequal(d + 1), perm)::Int # first permuted dim of dest |
| 145 | + R2 = CartesianIndices(axes(src)[(d + 2):(d1 - 1)]) |
| 146 | + R3 = CartesianIndices(axes(src)[(d1 + 1):end]) |
| 147 | + _permutedims!(P, src, R1, R2, R3, d + 1, d1) |
| 148 | + end |
| 149 | + return P |
| 150 | +end |
| 151 | + |
| 152 | +@noinline function _permutedims!( |
| 153 | + P::NestedPermutedDimsArray, src, R1::CartesianIndices{0}, R2, R3, ds, dp |
| 154 | +) |
| 155 | + ip, is = axes(src, dp), axes(src, ds) |
| 156 | + for jo in first(ip):8:last(ip), io in first(is):8:last(is) |
| 157 | + for I3 in R3, I2 in R2 |
| 158 | + for j in jo:min(jo + 7, last(ip)) |
| 159 | + for i in io:min(io + 7, last(is)) |
| 160 | + @inbounds P[i, I2, j, I3] = src[i, I2, j, I3] |
| 161 | + end |
| 162 | + end |
| 163 | + end |
| 164 | + end |
| 165 | + return P |
| 166 | +end |
| 167 | + |
| 168 | +@noinline function _permutedims!(P::NestedPermutedDimsArray, src, R1, R2, R3, ds, dp) |
| 169 | + ip, is = axes(src, dp), axes(src, ds) |
| 170 | + for jo in first(ip):8:last(ip), io in first(is):8:last(is) |
| 171 | + for I3 in R3, I2 in R2 |
| 172 | + for j in jo:min(jo + 7, last(ip)) |
| 173 | + for i in io:min(io + 7, last(is)) |
| 174 | + for I1 in R1 |
| 175 | + @inbounds P[I1, i, I2, j, I3] = src[I1, i, I2, j, I3] |
| 176 | + end |
| 177 | + end |
| 178 | + end |
| 179 | + end |
| 180 | + end |
| 181 | + return P |
| 182 | +end |
| 183 | + |
| 184 | +const CommutativeOps = Union{ |
| 185 | + typeof(+), |
| 186 | + typeof(Base.add_sum), |
| 187 | + typeof(min), |
| 188 | + typeof(max), |
| 189 | + typeof(Base._extrema_rf), |
| 190 | + typeof(|), |
| 191 | + typeof(&), |
| 192 | +} |
| 193 | + |
| 194 | +function Base._mapreduce_dim( |
| 195 | + f, op::CommutativeOps, init::Base._InitialValue, A::NestedPermutedDimsArray, dims::Colon |
| 196 | +) |
| 197 | + return Base._mapreduce_dim(f, op, init, parent(A), dims) |
| 198 | +end |
| 199 | +function Base._mapreduce_dim( |
| 200 | + f::typeof(identity), |
| 201 | + op::Union{typeof(Base.mul_prod),typeof(*)}, |
| 202 | + init::Base._InitialValue, |
| 203 | + A::NestedPermutedDimsArray{<:Union{Real,Complex}}, |
| 204 | + dims::Colon, |
| 205 | +) |
| 206 | + return Base._mapreduce_dim(f, op, init, parent(A), dims) |
| 207 | +end |
| 208 | + |
| 209 | +function Base.mapreducedim!( |
| 210 | + f, op::CommutativeOps, B::AbstractArray{T,N}, A::NestedPermutedDimsArray{S,N,perm,iperm} |
| 211 | +) where {T,S,N,perm,iperm} |
| 212 | + C = NestedPermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output |
| 213 | + Base.mapreducedim!(f, op, C, parent(A)) |
| 214 | + return B |
| 215 | +end |
| 216 | +function Base.mapreducedim!( |
| 217 | + f::typeof(identity), |
| 218 | + op::Union{typeof(Base.mul_prod),typeof(*)}, |
| 219 | + B::AbstractArray{T,N}, |
| 220 | + A::NestedPermutedDimsArray{<:Union{Real,Complex},N,perm,iperm}, |
| 221 | +) where {T,N,perm,iperm} |
| 222 | + C = NestedPermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output |
| 223 | + Base.mapreducedim!(f, op, C, parent(A)) |
| 224 | + return B |
| 225 | +end |
| 226 | + |
| 227 | +function Base.showarg( |
| 228 | + io::IO, A::NestedPermutedDimsArray{T,N,perm}, toplevel |
| 229 | +) where {T,N,perm} |
| 230 | + print(io, "NestedPermutedDimsArray(") |
| 231 | + Base.showarg(io, parent(A), false) |
| 232 | + print(io, ", ", perm, ')') |
| 233 | + toplevel && print(io, " with eltype ", eltype(A)) |
| 234 | + return nothing |
| 235 | +end |
| 236 | + |
| 237 | +end |
0 commit comments