Skip to content

Commit 96f8156

Browse files
authored
Simplify state in SubBlockIterator (#353)
* Simplify state in SubBlockIterator * Simplify SubBlockIterator docstring
1 parent 772c307 commit 96f8156

File tree

1 file changed

+17
-37
lines changed

1 file changed

+17
-37
lines changed

src/blockbroadcast.jl

Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ similar(bc::Broadcasted{PseudoBlockStyle{N}}, ::Type{T}) where {T,N} =
4949
SubBlockIterator(subblock_lasts::Vector{Int}, block_lasts::Vector{Int})
5050
SubBlockIterator(A::AbstractArray, bs::NTuple{N,AbstractUnitRange{Int}} where N, dim::Integer)
5151
52-
An iterator for iterating `BlockIndexRange` of the blocks specified by
52+
Return an iterator over the `BlockIndexRange`s of the blocks specified by
5353
`subblock_lasts`. The `Block` index part of `BlockIndexRange` is
5454
determined by `subblock_lasts`. That is to say, the `Block` index first
5555
specifies one of the block represented by `subblock_lasts` and then the
@@ -63,41 +63,27 @@ be ensured by the caller.
6363
```jldoctest
6464
julia> using BlockArrays
6565
66-
julia> import BlockArrays: SubBlockIterator, BlockIndexRange
66+
julia> import BlockArrays: SubBlockIterator
6767
6868
julia> A = BlockArray(1:6, 1:3);
6969
70-
julia> subblock_lasts = blocklasts(axes(A, 1));
71-
72-
julia> @assert subblock_lasts == [1, 3, 6];
70+
julia> subblock_lasts = blocklasts(axes(A, 1))
71+
3-element ArrayLayouts.RangeCumsum{Int64, UnitRange{Int64}}:
72+
1
73+
3
74+
6
7375
7476
julia> block_lasts = [1, 3, 4, 6];
7577
76-
julia> for idx in SubBlockIterator(subblock_lasts, block_lasts)
77-
B = @show view(A, idx)
78-
@assert !(parent(B) isa BlockArray)
79-
idx :: BlockIndexRange
80-
idx.block :: Block{1}
81-
idx.indices :: Tuple{UnitRange}
82-
end
83-
view(A, idx) = 1:1
84-
view(A, idx) = 2:3
85-
view(A, idx) = 4:4
86-
view(A, idx) = 5:6
87-
88-
julia> [idx.block.n[1] for idx in SubBlockIterator(subblock_lasts, block_lasts)]
89-
4-element Vector{Int64}:
90-
1
91-
2
92-
3
93-
3
78+
julia> itr = SubBlockIterator(subblock_lasts, block_lasts)
79+
SubBlockIterator([1, 3, 6], [1, 3, 4, 6])
9480
95-
julia> [idx.indices[1] for idx in SubBlockIterator(subblock_lasts, block_lasts)]
96-
4-element Vector{UnitRange{Int64}}:
97-
1:1
98-
1:2
99-
1:1
100-
2:3
81+
julia> collect(itr)
82+
4-element Vector{BlockArrays.BlockIndexRange{1, Tuple{UnitRange{Int64}}}}:
83+
Block(1)[1:1]
84+
Block(2)[1:2]
85+
Block(3)[1:1]
86+
Block(3)[2:3]
10187
```
10288
"""
10389
struct SubBlockIterator
@@ -114,15 +100,9 @@ Base.length(it::SubBlockIterator) = length(it.block_lasts)
114100
SubBlockIterator(arr::AbstractArray, bs::NTuple{N,AbstractUnitRange{Int}}, dim::Integer) where N =
115101
SubBlockIterator(blocklasts(axes(arr, dim)), blocklasts(bs[dim]))
116102

117-
function Base.iterate(it::SubBlockIterator, state=nothing)
118-
if state === nothing
119-
i,j = 1,1
120-
else
121-
i, j = state
122-
end
123-
length(it.block_lasts)+1 == i && return nothing
103+
function Base.iterate(it::SubBlockIterator, (i, j) = (1,1))
104+
i > length(it.block_lasts) && return nothing
124105
idx = i == 1 ? (1:it.block_lasts[i]) : (it.block_lasts[i-1]+1:it.block_lasts[i])
125-
126106
bir = Block(j)[j == 1 ? idx : idx .- it.subblock_lasts[j-1]]
127107
if it.subblock_lasts[j] == it.block_lasts[i]
128108
j += 1

0 commit comments

Comments
 (0)