Skip to content

Commit dda0aa8

Browse files
committed
add test & optimize performance
1. add more test cases for widen definition 2. make ind2sub faster via proper fallback
1 parent b1fdb73 commit dda0aa8

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

base/multidimensional.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -610,25 +610,27 @@ module IteratorsMD
610610
return I, (I, n+1)
611611
end
612612

613+
# make _ind2sub faster when possible
614+
_ind2sub_rs(axp, mi::Tuple, i) = Base.ind2sub_rs(axp, mi, i)
615+
_ind2sub_rs(axp, mi::Tuple{}, i) = Base._ind2sub(axp, i)
616+
_ind2sub_rs(axp::Base.Indices{1}, mi::Tuple{}, i) = (first(axp[1]) + i - 1,)
617+
@inline simd_outer_range(iter::CartesianPartition{<:CartesianIndex{0}}) = ((0,1,CartesianIndex()),)
613618
@inline function simd_outer_range(iter::CartesianPartition)
614619
# CartesianPartition might start and stop in the middle of the outer
615620
# dimensions, thus the outer range itself is a CartesianPartition.
616621
piter = iter.parent.parent
617622
ax1, oiter = split(piter, Val(1))
618623
vindʷ = only(iter.indices)
619-
@inline function _splitlinear(i::Int)
620-
ci = Base._to_subscript_indices(piter, i)
624+
function _splitlinear(i::Int)
625+
ci = _ind2sub_rs(axes(piter), iter.parent.mi, i)
621626
ci[1], Base._to_linear_index(oiter, tail(ci)...)
622627
end
623628
fl, vl = _splitlinear(first(vindʷ))
624629
fr, vr = _splitlinear(last(vindʷ))
625-
@inline function _view(oiter::CartesianIndices, ind::UnitRange)
626-
# we dont have #40344 for 1.6 and 1.7, force this return a CartesianIndices{1}
627-
ndims(oiter) == 1 && return @inbounds CartesianIndices((oiter.indices[1][ind],))
628-
# there's no need to make outer range fast-indexable
629-
oiter′ = ReshapedArray(oiter, (length(oiter),), ())
630-
@inbounds view(oiter′, ind)
631-
end
630+
# we dont have #40344 for 1.6 and 1.7, force this return a CartesianIndices{1}
631+
_view(x::CartesianIndices{1}, ind::UnitRange) = @inbounds CartesianIndices((x.indices[1][ind],))
632+
# there's no need to make outer range fast-indexable
633+
_view(x::CartesianIndices, ind::UnitRange) = @inbounds view(ReshapedArray(x, (length(x),), ()), ind)
632634
outer = _view(oiter, vl:vr)
633635
# Use Generator to make inner loop branchless
634636
@inline function genouter(i::Int, I::CartesianIndex)

test/iterators.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ function index_elements(iter)
544544
return vals
545545
end
546546

547-
@testset "CartesianPartition optimizations" for dims in ((1,), (64,), (101,),
547+
@testset "CartesianPartition optimizations" for dims in ((), (1,), (64,), (101,),
548548
(1,1), (8,8), (11, 13),
549549
(1,1,1), (8, 4, 2), (11, 13, 17)),
550550
part in (1, 7, 8, 11, 63, 64, 65, 142, 143, 144)
@@ -556,6 +556,13 @@ end
556556
@test collect(I) == iterate_elements(I) == simd_iterate_elements(I) == index_elements(I)
557557
end
558558
@test all(Base.splat(==), zip(Iterators.flatten(map(collect, P)), iter))
559+
iter′ = Base.ReshapedArray(iter, (length(iter),), ())
560+
P′ = partition(iter′, part)
561+
for I in P′
562+
@test length(I) == iterate_length(I) == simd_iterate_length(I) == simd_trip_count(I)
563+
@test collect(I) == iterate_elements(I) == simd_iterate_elements(I) == index_elements(I)
564+
end
565+
@test all(Base.splat(==), zip(Iterators.flatten(map(collect, P′)), iter′))
559566
end
560567
end
561568
@testset "empty/invalid partitions" begin

0 commit comments

Comments
 (0)