Skip to content

Commit e95839f

Browse files
Optimize reverse(::CartesianIndices, dims=...) (JuliaLang#49112)
* Optimize reverse(::CartesianIndices, dims=...) Optimize reverse(::CartesianIndices, dims=...) Correct tests and add stability test Remove whitespaces in reversed CartesianIndices Make funcs constpropagable and refactor Add tests for reverse(CartesianIndices; dims=:) * Typo * fix ambiguity and const-propagation * Fix empty `dims` --------- Co-authored-by: N5N3 <[email protected]>
1 parent f79fdf9 commit e95839f

File tree

2 files changed

+56
-1
lines changed

2 files changed

+56
-1
lines changed

base/multidimensional.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,8 +509,30 @@ module IteratorsMD
509509
end
510510

511511
# reversed CartesianIndices iteration
512+
@inline function Base._reverse(iter::CartesianIndices, ::Colon)
513+
CartesianIndices(reverse.(iter.indices))
514+
end
515+
516+
Base.@constprop :aggressive function Base._reverse(iter::CartesianIndices, dim::Integer)
517+
1 <= dim <= ndims(iter) || throw(ArgumentError(Base.LazyString("invalid dimension ", dim, " in reverse")))
518+
ndims(iter) == 1 && return Base._reverse(iter, :)
519+
indices = iter.indices
520+
return CartesianIndices(Base.setindex(indices, reverse(indices[dim]), dim))
521+
end
522+
523+
Base.@constprop :aggressive function Base._reverse(iter::CartesianIndices, dims::Tuple{Vararg{Integer}})
524+
indices = iter.indices
525+
# use `sum` to force const fold
526+
dimrev = ntuple(i -> sum(==(i), dims; init = 0) == 1, Val(length(indices)))
527+
length(dims) == sum(dimrev) || throw(ArgumentError(Base.LazyString("invalid dimensions ", dims, " in reverse")))
528+
length(dims) == length(indices) && return Base._reverse(iter, :)
529+
indices′ = map((i, f) -> f ? reverse(i) : i, indices, dimrev)
530+
return CartesianIndices(indices′)
531+
end
512532

513-
Base.reverse(iter::CartesianIndices) = CartesianIndices(reverse.(iter.indices))
533+
# fix ambiguity with array.jl:
534+
Base._reverse(iter::CartesianIndices{1}, dims::Tuple{Integer}) =
535+
Base._reverse(iter, first(dims))
514536

515537
@inline function iterate(r::Reverse{<:CartesianIndices})
516538
iterfirst = last(r.itr)

test/arrayops.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1705,6 +1705,39 @@ end
17051705
@test istriu([1 2 0; 0 4 1])
17061706
end
17071707

1708+
#issue 49021
1709+
@testset "reverse cartesian indices" begin
1710+
@test reverse(CartesianIndices((2, 3))) === CartesianIndices((2:-1:1, 3:-1:1))
1711+
@test reverse(CartesianIndices((2:5, 3:7))) === CartesianIndices((5:-1:2, 7:-1:3))
1712+
@test reverse(CartesianIndices((5:-1:2, 7:-1:3))) === CartesianIndices((2:1:5, 3:1:7))
1713+
end
1714+
1715+
@testset "reverse cartesian indices dim" begin
1716+
A = CartesianIndices((2, 3, 5:-1:1))
1717+
@test reverse(A, dims=1) === CartesianIndices((2:-1:1, 3, 5:-1:1))
1718+
@test reverse(A, dims=3) === CartesianIndices((2, 3, 1:1:5))
1719+
@test_throws ArgumentError reverse(A, dims=0)
1720+
@test_throws ArgumentError reverse(A, dims=4)
1721+
end
1722+
1723+
@testset "reverse cartesian indices multiple dims" begin
1724+
A = CartesianIndices((2, 3, 5:-1:1))
1725+
@test reverse(A, dims=(1, 3)) === CartesianIndices((2:-1:1, 3, 1:1:5))
1726+
@test reverse(A, dims=(3, 1)) === CartesianIndices((2:-1:1, 3, 1:1:5))
1727+
@test_throws ArgumentError reverse(A, dims=(1, 2, 4))
1728+
@test_throws ArgumentError reverse(A, dims=(0, 1, 2))
1729+
@test_throws ArgumentError reverse(A, dims=(1, 1))
1730+
end
1731+
1732+
@testset "stability of const propagation" begin
1733+
A = CartesianIndices((2, 3, 5:-1:1))
1734+
f1(x) = reverse(x; dims=1)
1735+
f2(x) = reverse(x; dims=(1, 3))
1736+
@test @inferred(f1(A)) === CartesianIndices((2:-1:1, 3, 5:-1:1))
1737+
@test @inferred(f2(A)) === CartesianIndices((2:-1:1, 3, 1:1:5))
1738+
@test @inferred(reverse(A; dims=())) === A
1739+
end
1740+
17081741
# issue 4228
17091742
let A = [[i i; i i] for i=1:2]
17101743
@test cumsum(A) == Any[[1 1; 1 1], [3 3; 3 3]]

0 commit comments

Comments
 (0)