Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions base/permuteddimsarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,61 @@ function Base.mapreducedim!(f::typeof(identity), op::Union{typeof(Base.mul_prod)
B
end

function Base._dropdims(A::PermutedDimsArray{T,N,perm}, dims::Base.Dims) where {T,N,perm}
for d in dims
1 <= d <= ndims(A) || throw(ArgumentError("dropped dims must be in range 1:ndims(A)"))
# dropdims also demands size(A,d)==1 and allunique(dims), checked by Base._dropdims below.
end
# Drop the appropriate dims of the parent array:
innerdims = map(d -> perm[d], dims)
inner = Base._dropdims(parent(A), innerdims)
# Change the permutation two ways: first account for dropdims(parent(A)), then skip entries at locations in dims.
innerperm = map(perm) do p
p - count(<=(p), innerdims)
end
newperm = ntuple(length(perm) - length(dims)) do d
i = d + count(<=(d), dims)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider

N = 4
perm = (2, 4, 1, 3)
dims = (1, 2) 

so we want to keep outer dims (3, 4), which will be keeping inner dims (1, 3)

but this would give

innerdims = (2, 4)
innerperm = (1, 2, 1, 2)
newperm = (2, 2)

an invalid permutation, manifesting in examples like this

julia> A = reshape(collect(1:6), 2,1,3,1);

julia> dropdims(PermutedDimsArray(A, (2,4,1,3)); dims=(1,2))

this could work instead I think

    kept = _sortedtuple_range_setdiff(1, N, sort(dims))
    newperm = map(k -> innerperm[k], kept)
    PermutedDimsArray(inner, newperm)

where I've defined _sortedtuple_range_setdiff like so as I guess we can avoid the collect into Vector from Base.setdiff

_sortedtuple_range_setdiff(a, b, ::Tuple{}) =
    b < a ? () : ntuple(i -> i + a - 1, b - a + 1)
function _sortedtuple_range_setdiff(a, b, t::NTuple{N, T}) where {N,T}
    x, rest = first(t), tail(t)
    return if x < a
        _sortedtuple_range_setdiff(a, b, rest)
    elseif x > b
        _sortedtuple_range_setdiff(a, b, ())
    else
        left = _sortedtuple_range_setdiff(a, x-1, ())
        right = _sortedtuple_range_setdiff(x+1, b, rest)
        (left..., right...)
    end
end

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for looking! I need another coffee to understand where my logic went wrong, but good catch.

innerperm[i]
end
PermutedDimsArray(inner, newperm)
end
# Drop 1 dim of a matrix and you must get a vector, no need to wrap it:
function Base._dropdims(A::PermutedDimsArray{T,2,perm}, dims::Tuple{Int}) where {T,perm}
1 <= only(dims) <= ndims(A) || throw(ArgumentError("dropped dims must be in range 1:ndims(A)"))
innerdim = perm[only(dims)]
Base._dropdims(parent(A), (innerdim,))
end
# Drop all dims
function Base._dropdims(A::PermutedDimsArray{T,N,perm}, dims::NTuple{N,Int}) where {T,N,perm}
for d in dims
1 <= d <= ndims(A) || throw(ArgumentError("dropped dims must be in range 1:ndims(A)"))
end
Base._dropdims(parent(A), dims)
end

function Base._insertdims(A::PermutedDimsArray{T,N,perm}, dims::NTuple{M,Int}) where {T,N,perm,M}
for i in eachindex(dims)
1 ≤ dims[i] || throw(ArgumentError("the smallest entry in dims must be ≥ 1."))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
1 dims[i] || throw(ArgumentError("the smallest entry in dims must be ≥ 1."))
1 <= dims[i] || throw(ArgumentError("the smallest entry in dims must be ≥ 1."))

not to be too picky but just for consistency I'd use <= and >= everywhere (else: and everywhere)

dims[i] ≤ N+M || throw(ArgumentError("the largest entry in dims must be not larger than the dimension of the array and the length of dims added"))
for j = 1:i-1
dims[j] == dims[i] && throw(ArgumentError("inserted dims must be unique"))
end
end
# We can choose where to insert dims into parent array, choose the end?
innerdims = ntuple(d -> ndims(A) + d, length(dims))
inner = Base._insertdims(parent(A), innerdims)
# With that choice, the new permutation just needs to insert higher numbers into sequence
newperm = ntuple(length(perm) + length(dims)) do d
c = count(<=(d), dims)
if d in dims
ndims(A) + c
else
perm[d - c]
end
end
PermutedDimsArray(inner, newperm)
end

function Base.showarg(io::IO, A::PermutedDimsArray{T,N,perm}, toplevel) where {T,N,perm}
print(io, "PermutedDimsArray(")
Base.showarg(io, parent(A), false)
Expand Down
19 changes: 19 additions & 0 deletions stdlib/LinearAlgebra/src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,25 @@ hcat(avs::Adjoint{T,Vector{T}}...) where {T} = _adjoint_hcat(avs...)
hcat(tvs::Transpose{T,Vector{T}}...) where {T} = _transpose_hcat(tvs...)
# TODO unify and allow mixed combinations

### dropdims
function Base._dropdims(A::Transpose{<:Number}, dims::Tuple{Int})
if only(dims) == 1
vec(parent(A))
elseif only(dims) == 2
Base._dropdims(parent(A), (1,))
else
throw(ArgumentError("dropped dims must be in range 1:ndims(A)"))
end
end
function Base._dropdims(A::Adjoint{<:Real}, dims::Tuple{Int})
if only(dims) == 1
vec(parent(A))
elseif only(dims) == 2
Base._dropdims(parent(A), (1,))
else
throw(ArgumentError("dropped dims must be in range 1:ndims(A)"))
end
end

### higher order functions
# preserve Adjoint/Transpose wrapper around vectors
Expand Down
14 changes: 14 additions & 0 deletions stdlib/LinearAlgebra/test/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,20 @@ end
@test hcat(Transpose(vecvec), Transpose(vecvec))::Transpose{Transpose{Complex{Int},Vector{Complex{Int}}},Vector{Vector{Complex{Int}}}} == hcat(tvecvec, tvecvec)
end

@testset "dropdims on Adjoint/Transpose-wrapped vectors & matrices" begin
intvec = [1, 2]
@test dropdims(Adjoint(intvec); dims=1) === intvec
@test dropdims(Transpose(intvec); dims=1) === intvec
cvec = [1.0 + 3im, 3.0 + 4im]
@test dropdims(Adjoint(cvec); dims=1) == conj(cvec)
@test dropdims(Transpose(cvec); dims=1) === cvec
intmat = [1 2 3]
@test dropdims(Adjoint(intmat); dims=2) == vec(intmat)
@test dropdims(Adjoint(intmat); dims=2) isa Vector
@test dropdims(Transpose(intmat); dims=2) == vec(intmat)
@test dropdims(Transpose(intmat); dims=2) isa Vector
end

@testset "map/broadcast over Adjoint/Transpose-wrapped vectors and Numbers" begin
# map and broadcast over Adjoint/Transpose-wrapped vectors and Numbers
# should preserve the Adjoint/Transpose-wrapper to preserve semantics downstream
Expand Down
56 changes: 56 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,62 @@ end

@test isequal(reshape(reshape(1:27, 3, 3, 3), Val(2))[1,:], [1, 4, 7, 10, 13, 16, 19, 22, 25])
end

@testset "dropdims/insertdims on PermutedDimsArray" begin
# Matrix
P1 = PermutedDimsArray(randn(5,1), (2,1))
M1 = collect(P1)
@test dropdims(P1; dims=1) == dropdims(M1; dims=1)
@test insertdims(P1; dims=1) == insertdims(M1; dims=1)

@test dropdims(P1; dims=1) isa Vector
@test insertdims(P1; dims=1) isa PermutedDimsArray

@test_throws ArgumentError dropdims(P1; dims=0)
@test_throws ArgumentError dropdims(P1; dims=2)
@test_throws ArgumentError dropdims(P1; dims=3)
@test_throws ArgumentError dropdims(P1; dims=(1,1))

@test_throws ArgumentError insertdims(P1; dims=0)
@test_throws ArgumentError insertdims(P1; dims=4)
@test_throws ArgumentError insertdims(P1; dims=(1,1))

# More dims
P2 = PermutedDimsArray(randn(3,1,4), (3,2,1))
M2 = collect(P2)
@test dropdims(P2; dims=2) == dropdims(M2; dims=2)
@test insertdims(P2; dims=2) == insertdims(M2; dims=2)

P13 = PermutedDimsArray(randn(2,1,3,1,4), (4,5,2,1,3))
A13 = collect(P13)
@test dropdims(P13; dims=3) == dropdims(A13; dims=3)
@test dropdims(P13; dims=(1,3)) == dropdims(A13; dims=(1,3))
@test dropdims(P13; dims=(3,1)) == dropdims(A13; dims=(3,1))

@test insertdims(P13; dims=2) == insertdims(A13; dims=2)
@test insertdims(P13; dims=(4,6)) == insertdims(A13; dims=(4,6))
@test insertdims(P13; dims=(4,1)) == insertdims(A13; dims=(4,1))

@test dropdims(P13; dims=(1,3)) isa PermutedDimsArray
@test insertdims(P13; dims=(4,6)) isa PermutedDimsArray

@test_throws ArgumentError dropdims(P13; dims=0)
@test_throws ArgumentError dropdims(P13; dims=2)
@test_throws ArgumentError dropdims(P13; dims=4)
@test_throws ArgumentError dropdims(P13; dims=(3,3))

# Zero-dim cases
p1 = PermutedDimsArray(rand(1), (1,))
@test dropdims(p1; dims=1) == dropdims(collect(p1); dims=1)
p12 = PermutedDimsArray(rand(1,1), (2,1))
@test dropdims(p12; dims=2) == dropdims(collect(p12); dims=2)
@test dropdims(p12; dims=(1,2)) == dropdims(collect(p12); dims=(1,2))
a = fill(rand())
p = PermutedDimsArray(a, ())
@test insertdims(p, dims=1) == insertdims(a, dims=1)
@test insertdims(p, dims=(1,2)) == insertdims(a, dims=(1,2))
end

@testset "find(in(b), a)" begin
# unsorted inputs
a = [3, 5, -7, 6]
Expand Down