Skip to content

Commit 8364a4c

Browse files
authored
Fix collect on stateful generator (#41919)
Previously this code would drop 1 from the length of some generators. Fixes #35530
1 parent bb8e77e commit 8364a4c

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

base/array.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -758,21 +758,25 @@ else
758758
end
759759
end
760760

761-
_array_for(::Type{T}, itr, ::HasLength) where {T} = Vector{T}(undef, Int(length(itr)::Integer))
762-
_array_for(::Type{T}, itr, ::HasShape{N}) where {T,N} = similar(Array{T,N}, axes(itr))
761+
_array_for(::Type{T}, itr, isz::HasLength) where {T} = _array_for(T, itr, isz, length(itr))
762+
_array_for(::Type{T}, itr, isz::HasShape{N}) where {T,N} = _array_for(T, itr, isz, axes(itr))
763+
_array_for(::Type{T}, itr, ::HasLength, len) where {T} = Vector{T}(undef, len)
764+
_array_for(::Type{T}, itr, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs)
763765

764766
function collect(itr::Generator)
765767
isz = IteratorSize(itr.iter)
766768
et = @default_eltype(itr)
767769
if isa(isz, SizeUnknown)
768770
return grow_to!(Vector{et}(), itr)
769771
else
772+
shape = isz isa HasLength ? length(itr) : axes(itr)
770773
y = iterate(itr)
771774
if y === nothing
772775
return _array_for(et, itr.iter, isz)
773776
end
774777
v1, st = y
775-
collect_to_with_first!(_array_for(typeof(v1), itr.iter, isz), v1, itr, st)
778+
arr = _array_for(typeof(v1), itr.iter, isz, shape)
779+
return collect_to_with_first!(arr, v1, itr, st)
776780
end
777781
end
778782

test/iterators.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,15 @@ let (a, b) = (1:3, [4 6;
292292
end
293293
end
294294

295+
# collect stateful iterator
296+
let
297+
itr = (i+1 for i in Base.Stateful([1,2,3]))
298+
@test collect(itr) == [2, 3, 4]
299+
A = zeros(Int, 0, 0)
300+
itr = (i-1 for i in Base.Stateful(A))
301+
@test collect(itr) == Int[] # Stateful do not preserve shape
302+
end
303+
295304
# with 1D inputs
296305
let a = 1:2,
297306
b = 1.0:10.0,
@@ -860,4 +869,4 @@ end
860869
@test Iterators.peel(1:10)[2] |> collect == 2:10
861870
@test Iterators.peel(x^2 for x in 2:4)[1] == 4
862871
@test Iterators.peel(x^2 for x in 2:4)[2] |> collect == [9, 16]
863-
end
872+
end

0 commit comments

Comments
 (0)