Skip to content

Commit c7c29e0

Browse files
committed
Update iterator protocol
1 parent 8728d49 commit c7c29e0

File tree

4 files changed

+14
-21
lines changed

4 files changed

+14
-21
lines changed

src/array_partition.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,8 @@ recursive_eltype(A::ArrayPartition) = recursive_eltype(first(A.x))
216216

217217
## iteration
218218

219-
Base.start(A::ArrayPartition) = start(Chain(A.x))
220-
Base.next(A::ArrayPartition,state) = next(Chain(A.x),state)
221-
Base.done(A::ArrayPartition,state) = done(Chain(A.x),state)
219+
Base.iterate(A::ArrayPartition) = iterate(Chain(A.x))
220+
Base.iterate(A::ArrayPartition,state) = iterate(Chain(A.x),state)
222221

223222
Base.length(A::ArrayPartition) = sum((length(x) for x in A.x))
224223
Base.size(A::ArrayPartition) = (length(A),)

src/utils.jl

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -145,30 +145,24 @@ Base.length(it::Chain) = sum(length, it.xss)
145145

146146
Base.eltype(::Type{Chain{T}}) where {T} = typejoin([eltype(t) for t in T.parameters]...)
147147

148-
function Base.start(it::Chain)
148+
function Base.iterate(it::Chain)
149149
i = 1
150150
xs_state = nothing
151151
while i <= length(it.xss)
152-
xs_state = start(it.xss[i])
153-
if !done(it.xss[i], xs_state)
154-
break
155-
end
152+
xs_state = iterate(it.xss[i])
153+
xs_state !== nothing && return xs_state[1], (i, xs_state[2])
156154
i += 1
157155
end
158-
return i, xs_state
156+
return nothing
159157
end
160158

161-
function Base.next(it::Chain, state)
159+
function Base.iterate(it::Chain, state)
162160
i, xs_state = state
163-
v, xs_state = next(it.xss[i], xs_state)
164-
while done(it.xss[i], xs_state)
161+
xs_state = iterate(it.xss[i], xs_state)
162+
while xs_state == nothing
165163
i += 1
166-
if i > length(it.xss)
167-
break
168-
end
169-
xs_state = start(it.xss[i])
164+
i > length(it.xss) && return nothing
165+
xs_state = iterate(it.xss[i])
170166
end
171-
return v, (i, xs_state)
167+
return xs_state[1], (i, xs_state[2])
172168
end
173-
174-
Base.done(it::Chain, state) = state[1] > length(it.xss)

src/vector_of_array.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ DiffEqArray(vec::AbstractVector,ts::AbstractVector) = DiffEqArray(vec, ts, (size
2323

2424
@inline Base.length(VA::AbstractVectorOfArray) = length(VA.u)
2525
@inline Base.eachindex(VA::AbstractVectorOfArray) = Base.OneTo(length(VA.u))
26-
@inline Base.iteratorsize(VA::AbstractVectorOfArray) = Base.HasLength()
26+
@inline Base.IteratorSize(VA::AbstractVectorOfArray) = Base.HasLength()
2727
# Linear indexing will be over the container elements, not the individual elements
2828
# unlike an true AbstractArray
2929
@inline Base.getindex(VA::AbstractVectorOfArray{T, N}, I::Int) where {T, N} = VA.u[I]

test/interface_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ testva2 = similar(testva)
5454
testval = 3.0
5555
fill!(testva2, testval)
5656
@test all(x->(x==testval), testva2)
57-
testts = rand(size(testva.u))
57+
testts = rand(Float64, size(testva.u))
5858
testda = DiffEqArray(recursivecopy(testva.u), testts)
5959
fill!(testda, testval)
6060
@test all(x->(x==testval), testda)

0 commit comments

Comments
 (0)