Skip to content

Commit 4b6d78d

Browse files
committed
code clean
make `_ind2sub_rs` fallback to existing `Base` index transformation
1 parent dda0aa8 commit 4b6d78d

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

base/multidimensional.jl

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -610,28 +610,27 @@ module IteratorsMD
610610
return I, (I, n+1)
611611
end
612612

613-
# make _ind2sub faster when possible
614-
_ind2sub_rs(axp, mi::Tuple, i) = Base.ind2sub_rs(axp, mi, i)
615-
_ind2sub_rs(axp, mi::Tuple{}, i) = Base._ind2sub(axp, i)
616-
_ind2sub_rs(axp::Base.Indices{1}, mi::Tuple{}, i) = (first(axp[1]) + i - 1,)
617-
@inline simd_outer_range(iter::CartesianPartition{<:CartesianIndex{0}}) = ((0,1,CartesianIndex()),)
613+
# use to ReshapedArray's index machinery when possible
614+
@inline _ind2sub_rs(ax::Tuple, mi::Tuple, i::Int) = Base.ind2sub_rs(ax, mi, Base.offset_if_vec(i, ax))
615+
@inline _ind2sub_rs(ax::Tuple{Any,Any,Vararg{Any}}, mi::Tuple{}, i::Int) = Base._ind2sub(ax, i)
616+
@inline function _splitlinear(rs::ReshapedArray, i::Int)
617+
axp = axes(rs.parent)
618+
ci = _ind2sub_rs(axp, rs.mi, i)
619+
length(ci) == 2 ? ci : ci[1], Base._sub2ind(tail(axp), tail(ci)...)
620+
end
621+
simd_outer_range(iter::CartesianPartition{CartesianIndex{0}}) = ((0,1,iter[]),)
618622
@inline function simd_outer_range(iter::CartesianPartition)
619623
# CartesianPartition might start and stop in the middle of the outer
620624
# dimensions, thus the outer range itself is a CartesianPartition.
621-
piter = iter.parent.parent
622-
ax1, oiter = split(piter, Val(1))
623-
vindʷ = only(iter.indices)
624-
function _splitlinear(i::Int)
625-
ci = _ind2sub_rs(axes(piter), iter.parent.mi, i)
626-
ci[1], Base._to_linear_index(oiter, tail(ci)...)
625+
rs = iter.parent
626+
ax1, oiter = split(rs.parent, Val(1))
627+
fl, vl = _splitlinear(rs, first(iter.indices[1]))
628+
fr, vr = _splitlinear(rs, last(iter.indices[1]))
629+
outer = @inbounds if ndims(oiter) == 1
630+
CartesianIndices((x.indices[1][vl:vr],))
631+
else
632+
view(ReshapedArray(x, (length(x),), ()), vl:vr)
627633
end
628-
fl, vl = _splitlinear(first(vindʷ))
629-
fr, vr = _splitlinear(last(vindʷ))
630-
# we dont have #40344 for 1.6 and 1.7, force this return a CartesianIndices{1}
631-
_view(x::CartesianIndices{1}, ind::UnitRange) = @inbounds CartesianIndices((x.indices[1][ind],))
632-
# there's no need to make outer range fast-indexable
633-
_view(x::CartesianIndices, ind::UnitRange) = @inbounds view(ReshapedArray(x, (length(x),), ()), ind)
634-
outer = _view(oiter, vl:vr)
635634
# Use Generator to make inner loop branchless
636635
@inline function genouter(i::Int, I::CartesianIndex)
637636
l = i == 1 ? fl : firstindex(ax1)

test/iterators.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,7 @@ end
556556
@test collect(I) == iterate_elements(I) == simd_iterate_elements(I) == index_elements(I)
557557
end
558558
@test all(Base.splat(==), zip(Iterators.flatten(map(collect, P)), iter))
559+
# test for slowindex cases
559560
iter′ = Base.ReshapedArray(iter, (length(iter),), ())
560561
P′ = partition(iter′, part)
561562
for I in P′

0 commit comments

Comments
 (0)