Skip to content

Commit 875c29c

Browse files
mcabbottadienes
andauthored
Fix stack(; dims) on containers with HasLength eltype & HasShape elements (#56777)
Fixes #56771 While fixing this I saw #56776, will comment there on differences. I think the reason this line existed was probably to handle `Tuple`, for which `Base.IteratorSize(Tuple) == Base.HasLength()` although they behave like `HasShape{1}`. I'd like to check a little more for unintended consequences, hence mark this draft for now. Co-authored-by: Andy Dienes <[email protected]>
1 parent 777c5e3 commit 875c29c

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

base/abstractarray.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2939,8 +2939,6 @@ _iterator_axes(x, ::IteratorSize) = axes(x)
29392939
# For some dims values, stack(A; dims) == stack(vec(A)), and the : path will be faster
29402940
_typed_stack(dims::Integer, ::Type{T}, ::Type{S}, A) where {T,S} =
29412941
_typed_stack(dims, T, S, IteratorSize(S), A)
2942-
_typed_stack(dims::Integer, ::Type{T}, ::Type{S}, ::HasLength, A) where {T,S} =
2943-
_typed_stack(dims, T, S, HasShape{1}(), A)
29442942
function _typed_stack(dims::Integer, ::Type{T}, ::Type{S}, ::HasShape{N}, A) where {T,S,N}
29452943
if dims == N+1
29462944
_typed_stack(:, T, S, A, (_vec_axis(A),))

test/abstractarray.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1880,6 +1880,33 @@ end
18801880
end
18811881
end
18821882

1883+
@testset "issue 56771, stack(; dims) on containers with HasLength eltype & HasShape elements" begin
1884+
for T in (Matrix, Array, Any)
1885+
xs = T[rand(2,3) for _ in 1:4]
1886+
@test size(stack(xs; dims=1)) == (4,2,3)
1887+
@test size(stack(xs; dims=2)) == (2,4,3) # this was the problem case, for T=Array
1888+
@test size(stack(xs; dims=3)) == (2,3,4)
1889+
@test size(stack(identity, xs; dims=2)) == (2,4,3)
1890+
@test size(stack(x for x in xs if true; dims=2)) == (2,4,3)
1891+
1892+
xmat = T[rand(2,3) for _ in 1:4, _ in 1:5]
1893+
@test size(stack(xmat; dims=1)) == (20,2,3)
1894+
@test size(stack(xmat; dims=2)) == (2,20,3)
1895+
@test size(stack(xmat; dims=3)) == (2,3,20)
1896+
end
1897+
1898+
it = Iterators.product(1:2, 3:5)
1899+
@test size(it) == (2,3)
1900+
@test Base.IteratorSize(typeof(it)) == Base.HasShape{2}()
1901+
@test Base.IteratorSize(Iterators.ProductIterator) == Base.HasLength()
1902+
for T in (typeof(it), Iterators.ProductIterator, Any)
1903+
ys = T[it for _ in 1:4]
1904+
@test size(stack(ys; dims=2)) == (2,4,3)
1905+
@test size(stack(identity, ys; dims=2)) == (2,4,3)
1906+
@test size(stack(y for y in ys if true; dims=2)) == (2,4,3)
1907+
end
1908+
end
1909+
18831910
@testset "keepat!" begin
18841911
a = [1:6;]
18851912
@test a === keepat!(a, 1:5)

0 commit comments

Comments
 (0)