Skip to content

Commit ca62af9

Browse files
committed
add test
1. add test for non-1 step case 2. fix compatibility with 1.6
1 parent 1a2650f commit ca62af9

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

base/multidimensional.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -621,16 +621,15 @@ module IteratorsMD
621621
size1 = size(piter, 1)
622622
vindʷ = only(iter.indices)
623623
vindᵒ = cld(first(vindʷ), size1) : cld(last(vindʷ), size1)
624-
outer = @inbounds view(ci, vindᵒ)
625624
# Use Generator to make inner loop branchless
626-
Base.Generator(Iterators.enumerate(outer)) do (i, I)
627-
@inline
625+
@inline function genouter(i::Int, I::CartesianIndex)
628626
l, r = first(vindᵒ), last(vindᵒ)
629627
skip = i == 1 ? first(vindʷ) - 1 - (l - 1) * size1 : 0
630628
len = i == length(vindᵒ) ? last(vindʷ) - (r - 1) * size1 : size1
631-
len -= skip
632-
skip, len, I
629+
skip, len - skip, I
633630
end
631+
outer = @inbounds view(ci, vindᵒ)
632+
(genouter(i, I) for (i, I) in Iterators.enumerate(outer))
634633
end
635634
@inline simd_inner_length(iter::CartesianPartition, (skip, len, I)::Tuple{Int,Int,CartesianIndex}) = len
636635
@inline function simd_index(iter::CartesianPartition, (skip, len, I)::Tuple{Int,Int,CartesianIndex}, n::Int)

test/iterators.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -548,12 +548,14 @@ end
548548
(1,1), (8,8), (11, 13),
549549
(1,1,1), (8, 4, 2), (11, 13, 17)),
550550
part in (1, 7, 8, 11, 63, 64, 65, 142, 143, 144)
551-
P = partition(CartesianIndices(dims), part)
552-
for I in P
553-
@test length(I) == iterate_length(I) == simd_iterate_length(I) == simd_trip_count(I)
554-
@test collect(I) == iterate_elements(I) == simd_iterate_elements(I) == index_elements(I)
551+
for iter in (CartesianIndices(dims), CartesianIndices(map(d -> 1:2:2d, dims)))
552+
P = partition(iter, part)
553+
for I in P
554+
@test length(I) == iterate_length(I) == simd_iterate_length(I) == simd_trip_count(I)
555+
@test collect(I) == iterate_elements(I) == simd_iterate_elements(I) == index_elements(I)
556+
end
557+
@test all(Base.splat(==), zip(Iterators.flatten(map(collect, P)), iter))
555558
end
556-
@test all(Base.splat(==), zip(Iterators.flatten(map(collect, P)), CartesianIndices(dims)))
557559
end
558560
@testset "empty/invalid partitions" begin
559561
@test_throws ArgumentError partition(1:10, 0)

0 commit comments

Comments
 (0)