diff --git a/base/permuteddimsarray.jl b/base/permuteddimsarray.jl index 4e77d6b13ce21..476a74da250ed 100644 --- a/base/permuteddimsarray.jl +++ b/base/permuteddimsarray.jl @@ -363,6 +363,61 @@ function Base.mapreducedim!(f::typeof(identity), op::Union{typeof(Base.mul_prod) B end +function Base._dropdims(A::PermutedDimsArray{T,N,perm}, dims::Base.Dims) where {T,N,perm} + for d in dims + 1 <= d <= ndims(A) || throw(ArgumentError("dropped dims must be in range 1:ndims(A)")) + # dropdims also demands size(A,d)==1 and allunique(dims), checked by Base._dropdims below. + end + # Drop the appropriate dims of the parent array: + innerdims = map(d -> perm[d], dims) + inner = Base._dropdims(parent(A), innerdims) + # Change the permutation two ways: first account for dropdims(parent(A)), then skip entries at locations in dims. + innerperm = map(perm) do p + p - count(<=(p), innerdims) + end + newperm = ntuple(length(perm) - length(dims)) do d + i = d + count(<=(d), dims) + innerperm[i] + end + PermutedDimsArray(inner, newperm) +end +# Drop 1 dim of a matrix and you must get a vector, no need to wrap it: +function Base._dropdims(A::PermutedDimsArray{T,2,perm}, dims::Tuple{Int}) where {T,perm} + 1 <= only(dims) <= ndims(A) || throw(ArgumentError("dropped dims must be in range 1:ndims(A)")) + innerdim = perm[only(dims)] + Base._dropdims(parent(A), (innerdim,)) +end +# Drop all dims +function Base._dropdims(A::PermutedDimsArray{T,N,perm}, dims::NTuple{N,Int}) where {T,N,perm} + for d in dims + 1 <= d <= ndims(A) || throw(ArgumentError("dropped dims must be in range 1:ndims(A)")) + end + Base._dropdims(parent(A), dims) +end + +function Base._insertdims(A::PermutedDimsArray{T,N,perm}, dims::NTuple{M,Int}) where {T,N,perm,M} + for i in eachindex(dims) + 1 ≤ dims[i] || throw(ArgumentError("the smallest entry in dims must be ≥ 1.")) + dims[i] ≤ N+M || throw(ArgumentError("the largest entry in dims must be not larger than the dimension of the array and the length of dims added")) + for j = 1:i-1 + dims[j] == dims[i] && throw(ArgumentError("inserted dims must be unique")) + end + end + # We can choose where to insert dims into parent array, choose the end? + innerdims = ntuple(d -> ndims(A) + d, length(dims)) + inner = Base._insertdims(parent(A), innerdims) + # With that choice, the new permutation just needs to insert higher numbers into sequence + newperm = ntuple(length(perm) + length(dims)) do d + c = count(<=(d), dims) + if d in dims + ndims(A) + c + else + perm[d - c] + end + end + PermutedDimsArray(inner, newperm) +end + function Base.showarg(io::IO, A::PermutedDimsArray{T,N,perm}, toplevel) where {T,N,perm} print(io, "PermutedDimsArray(") Base.showarg(io, parent(A), false) diff --git a/stdlib/LinearAlgebra/src/adjtrans.jl b/stdlib/LinearAlgebra/src/adjtrans.jl index daee587b82835..f6d1f97493a63 100644 --- a/stdlib/LinearAlgebra/src/adjtrans.jl +++ b/stdlib/LinearAlgebra/src/adjtrans.jl @@ -388,6 +388,25 @@ hcat(avs::Adjoint{T,Vector{T}}...) where {T} = _adjoint_hcat(avs...) hcat(tvs::Transpose{T,Vector{T}}...) where {T} = _transpose_hcat(tvs...) # TODO unify and allow mixed combinations +### dropdims +function Base._dropdims(A::Transpose{<:Number}, dims::Tuple{Int}) + if only(dims) == 1 + vec(parent(A)) + elseif only(dims) == 2 + Base._dropdims(parent(A), (1,)) + else + throw(ArgumentError("dropped dims must be in range 1:ndims(A)")) + end +end +function Base._dropdims(A::Adjoint{<:Real}, dims::Tuple{Int}) + if only(dims) == 1 + vec(parent(A)) + elseif only(dims) == 2 + Base._dropdims(parent(A), (1,)) + else + throw(ArgumentError("dropped dims must be in range 1:ndims(A)")) + end +end ### higher order functions # preserve Adjoint/Transpose wrapper around vectors diff --git a/stdlib/LinearAlgebra/test/adjtrans.jl b/stdlib/LinearAlgebra/test/adjtrans.jl index 1a66c7430723e..34e279104e1f7 100644 --- a/stdlib/LinearAlgebra/test/adjtrans.jl +++ b/stdlib/LinearAlgebra/test/adjtrans.jl @@ -325,6 +325,20 @@ end @test hcat(Transpose(vecvec), Transpose(vecvec))::Transpose{Transpose{Complex{Int},Vector{Complex{Int}}},Vector{Vector{Complex{Int}}}} == hcat(tvecvec, tvecvec) end +@testset "dropdims on Adjoint/Transpose-wrapped vectors & matrices" begin + intvec = [1, 2] + @test dropdims(Adjoint(intvec); dims=1) === intvec + @test dropdims(Transpose(intvec); dims=1) === intvec + cvec = [1.0 + 3im, 3.0 + 4im] + @test dropdims(Adjoint(cvec); dims=1) == conj(cvec) + @test dropdims(Transpose(cvec); dims=1) === cvec + intmat = [1 2 3] + @test dropdims(Adjoint(intmat); dims=2) == vec(intmat) + @test dropdims(Adjoint(intmat); dims=2) isa Vector + @test dropdims(Transpose(intmat); dims=2) == vec(intmat) + @test dropdims(Transpose(intmat); dims=2) isa Vector +end + @testset "map/broadcast over Adjoint/Transpose-wrapped vectors and Numbers" begin # map and broadcast over Adjoint/Transpose-wrapped vectors and Numbers # should preserve the Adjoint/Transpose-wrapper to preserve semantics downstream diff --git a/test/arrayops.jl b/test/arrayops.jl index f58fdb36942a2..3a8ad3081f707 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -378,6 +378,62 @@ end @test isequal(reshape(reshape(1:27, 3, 3, 3), Val(2))[1,:], [1, 4, 7, 10, 13, 16, 19, 22, 25]) end + +@testset "dropdims/insertdims on PermutedDimsArray" begin + # Matrix + P1 = PermutedDimsArray(randn(5,1), (2,1)) + M1 = collect(P1) + @test dropdims(P1; dims=1) == dropdims(M1; dims=1) + @test insertdims(P1; dims=1) == insertdims(M1; dims=1) + + @test dropdims(P1; dims=1) isa Vector + @test insertdims(P1; dims=1) isa PermutedDimsArray + + @test_throws ArgumentError dropdims(P1; dims=0) + @test_throws ArgumentError dropdims(P1; dims=2) + @test_throws ArgumentError dropdims(P1; dims=3) + @test_throws ArgumentError dropdims(P1; dims=(1,1)) + + @test_throws ArgumentError insertdims(P1; dims=0) + @test_throws ArgumentError insertdims(P1; dims=4) + @test_throws ArgumentError insertdims(P1; dims=(1,1)) + + # More dims + P2 = PermutedDimsArray(randn(3,1,4), (3,2,1)) + M2 = collect(P2) + @test dropdims(P2; dims=2) == dropdims(M2; dims=2) + @test insertdims(P2; dims=2) == insertdims(M2; dims=2) + + P13 = PermutedDimsArray(randn(2,1,3,1,4), (4,5,2,1,3)) + A13 = collect(P13) + @test dropdims(P13; dims=3) == dropdims(A13; dims=3) + @test dropdims(P13; dims=(1,3)) == dropdims(A13; dims=(1,3)) + @test dropdims(P13; dims=(3,1)) == dropdims(A13; dims=(3,1)) + + @test insertdims(P13; dims=2) == insertdims(A13; dims=2) + @test insertdims(P13; dims=(4,6)) == insertdims(A13; dims=(4,6)) + @test insertdims(P13; dims=(4,1)) == insertdims(A13; dims=(4,1)) + + @test dropdims(P13; dims=(1,3)) isa PermutedDimsArray + @test insertdims(P13; dims=(4,6)) isa PermutedDimsArray + + @test_throws ArgumentError dropdims(P13; dims=0) + @test_throws ArgumentError dropdims(P13; dims=2) + @test_throws ArgumentError dropdims(P13; dims=4) + @test_throws ArgumentError dropdims(P13; dims=(3,3)) + + # Zero-dim cases + p1 = PermutedDimsArray(rand(1), (1,)) + @test dropdims(p1; dims=1) == dropdims(collect(p1); dims=1) + p12 = PermutedDimsArray(rand(1,1), (2,1)) + @test dropdims(p12; dims=2) == dropdims(collect(p12); dims=2) + @test dropdims(p12; dims=(1,2)) == dropdims(collect(p12); dims=(1,2)) + a = fill(rand()) + p = PermutedDimsArray(a, ()) + @test insertdims(p, dims=1) == insertdims(a, dims=1) + @test insertdims(p, dims=(1,2)) == insertdims(a, dims=(1,2)) +end + @testset "find(in(b), a)" begin # unsorted inputs a = [3, 5, -7, 6]