Skip to content

Conversation

@mcabbott
Copy link
Contributor

@mcabbott mcabbott commented Aug 5, 2024

Before, these wrap the wrapper in a ReshapedArray:

julia> dropdims([1,2,3]'; dims=1)
3-element reshape(adjoint(::Vector{Int64}), 3) with eltype Int64:
 1
 2
 3

julia> insertdims(PermutedDimsArray([1 2; 3 4], (2,1)); dims=2)
2×1×2 reshape(PermutedDimsArray(::Matrix{Int64}, (2, 1)), 2, 1, 2) with eltype Int64:
[:, :, 1] =
 1
 2

[:, :, 2] =
 3
 4

After:

julia> dropdims([1,2,3]'; dims=1)
3-element Vector{Int64}:
 1
 2
 3

julia> insertdims(PermutedDimsArray([1 2; 3 4], (2,1)); dims=2)
2×1×2 PermutedDimsArray(::Array{Int64, 3}, (2, 3, 1)) with eltype Int64:
[:, :, 1] =
 1
 2

[:, :, 2] =
 3
 4

This is not always faster, alone, but the hope is that the simpler return type will usually be faster downstream. (Especially with e.g. CuArrays, where wrapping twice causes them to miss specialised methods.)

julia> P13 = PermutedDimsArray(randn(2,1,3,1,4), (4,5,2,1,3));

julia> @btime dropdims($P13; dims=(1,3));
  189.628 ns (10 allocations: 192 bytes)  # before, ReshapedArray(PermutedDimsArray(Array{Float64, 5}))
  272.895 ns (10 allocations: 256 bytes)  # after, PermutedDimsArray(Array{Float64, 3})

julia> A13 = collect(P13);

julia> @btime dropdims($A13; dims=(1,3));
  172.333 ns (11 allocations: 240 bytes)

It appears that the compiler can predict the type when dims is constant. More stringent tests (or ways to do better) would be welcome:

julia> drop1(x) = dropdims(x; dims=(1,));

julia> @code_warntype drop1(P13)  # before
MethodInstance for drop1(::PermutedDimsArray{Float64, 5, (4, 5, 2, 1, 3), (4, 3, 5, 1, 2), Array{Float64, 5}})
  from drop1(x) @ Main REPL[12]:1
Arguments
  #self#::Core.Const(drop1)
  x::PermutedDimsArray{Float64, 5, (4, 5, 2, 1, 3), (4, 3, 5, 1, 2), Array{Float64, 5}}
Body::Base.ReshapedArray{Float64, 4, PermutedDimsArray{Float64, 5, (4, 5, 2, 1, 3), (4, 3, 5, 1, 2), Array{Float64, 5}}, NTuple{4, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}
1%1 = (:dims,)::Core.Const((:dims,))
...

julia> @code_warntype drop1(P13)  # after
MethodInstance for drop1(::PermutedDimsArray{Float64, 5, (4, 5, 2, 1, 3), (4, 3, 5, 1, 2), Array{Float64, 5}})
  from drop1(x) @ Main REPL[240]:1
Arguments
  #self#::Core.Const(drop1)
  x::PermutedDimsArray{Float64, 5, (4, 5, 2, 1, 3), (4, 3, 5, 1, 2), Array{Float64, 5}}
Body::PermutedDimsArray{Float64, 4, (4, 2, 1, 3), (3, 2, 4, 1), Array{Float64, 4}}
1%1 = (:dims,)::Core.Const((:dims,))
...

@mcabbott mcabbott added the arrays [a, r, r, a, y, s] label Aug 5, 2024
@mcabbott mcabbott requested a review from nsajko August 23, 2024 17:08
@mcabbott mcabbott requested a review from mbauman August 23, 2024 17:10
@mcabbott
Copy link
Contributor Author

Xref #45793 which added insertdims. And I requested review from a few of those who commented on that PR.

@mcabbott
Copy link
Contributor Author

mcabbott commented Jan 6, 2025

Pre-0.12 bump?

After #56637 this would need to be split into two, but I won't bother if nobody will review it.

@mcabbott mcabbott added the linear algebra Linear algebra label Jan 6, 2025
@DilumAluthge
Copy link
Member

We have moved the LinearAlgebra stdlib to an external repo: https://github.com/JuliaLang/LinearAlgebra.jl

@mcabbott If you think that this PR is still relevant, please open a new PR on the LinearAlgebra.jl repo.

@mcabbott
Copy link
Contributor Author

No, only half of this PR is LinearAlgebra, the other half is Base. It can stay open for discussion?

@mcabbott mcabbott reopened this Jan 13, 2025
@DilumAluthge
Copy link
Member

Yes, makes sense to re-open this PR for the parts that touch Base.

p - count(<=(p), innerdims)
end
newperm = ntuple(length(perm) - length(dims)) do d
i = d + count(<=(d), dims)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider

N = 4
perm = (2, 4, 1, 3)
dims = (1, 2) 

so we want to keep outer dims (3, 4), which will be keeping inner dims (1, 3)

but this would give

innerdims = (2, 4)
innerperm = (1, 2, 1, 2)
newperm = (2, 2)

an invalid permutation, manifesting in examples like this

julia> A = reshape(collect(1:6), 2,1,3,1);

julia> dropdims(PermutedDimsArray(A, (2,4,1,3)); dims=(1,2))

this could work instead I think

    kept = _sortedtuple_range_setdiff(1, N, sort(dims))
    newperm = map(k -> innerperm[k], kept)
    PermutedDimsArray(inner, newperm)

where I've defined _sortedtuple_range_setdiff like so as I guess we can avoid the collect into Vector from Base.setdiff

_sortedtuple_range_setdiff(a, b, ::Tuple{}) =
    b < a ? () : ntuple(i -> i + a - 1, b - a + 1)
function _sortedtuple_range_setdiff(a, b, t::NTuple{N, T}) where {N,T}
    x, rest = first(t), tail(t)
    return if x < a
        _sortedtuple_range_setdiff(a, b, rest)
    elseif x > b
        _sortedtuple_range_setdiff(a, b, ())
    else
        left = _sortedtuple_range_setdiff(a, x-1, ())
        right = _sortedtuple_range_setdiff(x+1, b, rest)
        (left..., right...)
    end
end

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for looking! I need another coffee to understand where my logic went wrong, but good catch.


function Base._insertdims(A::PermutedDimsArray{T,N,perm}, dims::NTuple{M,Int}) where {T,N,perm,M}
for i in eachindex(dims)
1 dims[i] || throw(ArgumentError("the smallest entry in dims must be ≥ 1."))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
1 dims[i] || throw(ArgumentError("the smallest entry in dims must be ≥ 1."))
1 <= dims[i] || throw(ArgumentError("the smallest entry in dims must be ≥ 1."))

not to be too picky but just for consistency I'd use <= and >= everywhere (else: and everywhere)

@adienes
Copy link
Member

adienes commented Aug 17, 2025

besides the needed bugfix (in comment), this PR looks like a good idea to me and I'd be happy to continue to review it

@adienes
Copy link
Member

adienes commented Sep 3, 2025

just spitballing: the fact that both dropdims and insertdims forward to reshape suggest that maybe reshape(::PermutedDimsArray) should be specialized rather than each of dropdims and insertdims. also reshape is not required to return an array of the same wrapper type as its input, so maybe we could get away with actually just returning a reshape of a view of the parent

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

arrays [a, r, r, a, y, s] linear algebra Linear algebra

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants