-
-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Simplify dropdims(::Transpose) and insertdims(::PermutedDimsArray) etc.
#55381
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
|
Xref #45793 which added |
|
Pre-0.12 bump? After #56637 this would need to be split into two, but I won't bother if nobody will review it. |
|
We have moved the LinearAlgebra stdlib to an external repo: https://github.com/JuliaLang/LinearAlgebra.jl @mcabbott If you think that this PR is still relevant, please open a new PR on the LinearAlgebra.jl repo. |
|
No, only half of this PR is LinearAlgebra, the other half is Base. It can stay open for discussion? |
|
Yes, makes sense to re-open this PR for the parts that touch Base. |
| p - count(<=(p), innerdims) | ||
| end | ||
| newperm = ntuple(length(perm) - length(dims)) do d | ||
| i = d + count(<=(d), dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
consider
N = 4
perm = (2, 4, 1, 3)
dims = (1, 2)
so we want to keep outer dims (3, 4), which will be keeping inner dims (1, 3)
but this would give
innerdims = (2, 4)
innerperm = (1, 2, 1, 2)
newperm = (2, 2)
an invalid permutation, manifesting in examples like this
julia> A = reshape(collect(1:6), 2,1,3,1);
julia> dropdims(PermutedDimsArray(A, (2,4,1,3)); dims=(1,2))
this could work instead I think
kept = _sortedtuple_range_setdiff(1, N, sort(dims))
newperm = map(k -> innerperm[k], kept)
PermutedDimsArray(inner, newperm)
where I've defined _sortedtuple_range_setdiff like so as I guess we can avoid the collect into Vector from Base.setdiff
_sortedtuple_range_setdiff(a, b, ::Tuple{}) =
b < a ? () : ntuple(i -> i + a - 1, b - a + 1)
function _sortedtuple_range_setdiff(a, b, t::NTuple{N, T}) where {N,T}
x, rest = first(t), tail(t)
return if x < a
_sortedtuple_range_setdiff(a, b, rest)
elseif x > b
_sortedtuple_range_setdiff(a, b, ())
else
left = _sortedtuple_range_setdiff(a, x-1, ())
right = _sortedtuple_range_setdiff(x+1, b, rest)
(left..., right...)
end
end
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for looking! I need another coffee to understand where my logic went wrong, but good catch.
|
|
||
| function Base._insertdims(A::PermutedDimsArray{T,N,perm}, dims::NTuple{M,Int}) where {T,N,perm,M} | ||
| for i in eachindex(dims) | ||
| 1 ≤ dims[i] || throw(ArgumentError("the smallest entry in dims must be ≥ 1.")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| 1 ≤ dims[i] || throw(ArgumentError("the smallest entry in dims must be ≥ 1.")) | |
| 1 <= dims[i] || throw(ArgumentError("the smallest entry in dims must be ≥ 1.")) |
not to be too picky but just for consistency I'd use <= and >= everywhere (else: ≤ and ≥ everywhere)
|
besides the needed bugfix (in comment), this PR looks like a good idea to me and I'd be happy to continue to review it |
|
just spitballing: the fact that both |
Before, these wrap the wrapper in a ReshapedArray:
After:
This is not always faster, alone, but the hope is that the simpler return type will usually be faster downstream. (Especially with e.g.
CuArrays, where wrapping twice causes them to miss specialised methods.)It appears that the compiler can predict the type when
dimsis constant. More stringent tests (or ways to do better) would be welcome: