Skip to content

Commit 2a5cfee

Browse files
N5N3mbauman
andauthored
Some fix for broadcast with offset axes. (#43414)
Co-authored-by: Matt Bauman <[email protected]>
1 parent 9f9d3ac commit 2a5cfee

File tree

2 files changed

+42
-44
lines changed

2 files changed

+42
-44
lines changed

base/broadcast.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -585,15 +585,17 @@ an `Int`.
585585
"""
586586
Base.@propagate_inbounds newindex(arg, I::CartesianIndex) = to_index(_newindex(axes(arg), I.I))
587587
Base.@propagate_inbounds newindex(arg, I::Integer) = to_index(_newindex(axes(arg), (I,)))
588-
Base.@propagate_inbounds _newindex(ax::Tuple, I::Tuple) = (ifelse(length(ax[1]) == 1, ax[1][1], I[1]), _newindex(tail(ax), tail(I))...)
588+
Base.@propagate_inbounds _newindex(ax::Tuple, I::Tuple) = (ifelse(length(ax[1]) == 1, ax[1][begin], I[1]), _newindex(tail(ax), tail(I))...)
589589
Base.@propagate_inbounds _newindex(ax::Tuple{}, I::Tuple) = ()
590-
Base.@propagate_inbounds _newindex(ax::Tuple, I::Tuple{}) = (ax[1][1], _newindex(tail(ax), ())...)
590+
Base.@propagate_inbounds _newindex(ax::Tuple, I::Tuple{}) = (ax[1][begin], _newindex(tail(ax), ())...)
591591
Base.@propagate_inbounds _newindex(ax::Tuple{}, I::Tuple{}) = ()
592592

593593
# If dot-broadcasting were already defined, this would be `ifelse.(keep, I, Idefault)`.
594594
@inline newindex(I::CartesianIndex, keep, Idefault) = to_index(_newindex(I.I, keep, Idefault))
595-
@inline newindex(i::Integer, keep::Tuple, idefault) = ifelse(keep[1], i, idefault[1])
596-
@inline newindex(i::Integer, keep::Tuple{}, idefault) = CartesianIndex(())
595+
@inline newindex(I::CartesianIndex{1}, keep, Idefault) = newindex(I.I[1], keep, Idefault)
596+
@inline newindex(i::Integer, keep::Tuple, idefault) = CartesianIndex(ifelse(keep[1], Int(i), Int(idefault[1])), idefault[2])
597+
@inline newindex(i::Integer, keep::Tuple{Bool}, idefault) = ifelse(keep[1], i, idefault[1])
598+
@inline newindex(i::Integer, keep::Tuple{}, idefault) = CartesianIndex()
597599
@inline _newindex(I, keep, Idefault) =
598600
(ifelse(keep[1], I[1], Idefault[1]), _newindex(tail(I), tail(keep), tail(Idefault))...)
599601
@inline _newindex(I, keep::Tuple{}, Idefault) = () # truncate if keep is shorter than I

test/broadcast.jl

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ ci(x) = CartesianIndex(x)
5353
@test @inferred(newindex(ci((2,2)), (true,), (-1,))) == 2
5454
@test @inferred(newindex(ci((2,2)), (false,), (-1,))) == -1
5555
@test @inferred(newindex(ci((2,2)), (), ())) == ci(())
56+
@test @inferred(newindex(ci((2,)), (true, false, false), (-1, -1, -1))) == ci((2, -1))
5657

5758
end
5859

@@ -853,29 +854,34 @@ let
853854
@test Dict(c .=> d) == Dict("foo" => 1, "bar" => 2)
854855
end
855856

856-
# Broadcasted iterable/indexable APIs
857-
let
858-
bc = Broadcast.instantiate(Broadcast.broadcasted(+, zeros(5), 5))
859-
@test IndexStyle(bc) == IndexLinear()
860-
@test eachindex(bc) === Base.OneTo(5)
861-
@test length(bc) === 5
862-
@test ndims(bc) === 1
863-
@test ndims(typeof(bc)) === 1
864-
@test bc[1] === bc[CartesianIndex((1,))] === 5.0
865-
@test copy(bc) == [v for v in bc] == collect(bc)
866-
@test eltype(copy(bc)) == eltype([v for v in bc]) == eltype(collect(bc))
867-
@test ndims(copy(bc)) == ndims([v for v in bc]) == ndims(collect(bc)) == ndims(bc)
868-
869-
bc = Broadcast.instantiate(Broadcast.broadcasted(+, zeros(5), 5*ones(1, 4)))
870-
@test IndexStyle(bc) == IndexCartesian()
871-
@test eachindex(bc) === CartesianIndices((Base.OneTo(5), Base.OneTo(4)))
872-
@test length(bc) === 20
873-
@test ndims(bc) === 2
874-
@test ndims(typeof(bc)) === 2
875-
@test bc[1,1] == bc[CartesianIndex((1,1))] === 5.0
876-
@test copy(bc) == [v for v in bc] == collect(bc)
877-
@test eltype(copy(bc)) == eltype([v for v in bc]) == eltype(collect(bc))
878-
@test ndims(copy(bc)) == ndims([v for v in bc]) == ndims(collect(bc)) == ndims(bc)
857+
isdefined(Main, :OffsetArrays) || @eval Main include("testhelpers/OffsetArrays.jl")
858+
using .Main.OffsetArrays
859+
@testset "Broadcasted iterable/indexable APIs" begin
860+
for f in (identity, x -> OffsetArray(x, ntuple(Returns(-1), ndims(x))))
861+
a = f(zeros(5))
862+
bc = Broadcast.instantiate(Broadcast.broadcasted(+, a, 5))
863+
@test IndexStyle(bc) == IndexLinear()
864+
@test eachindex(bc) === eachindex(a)
865+
@test length(bc) === 5
866+
@test ndims(bc) === 1
867+
@test ndims(typeof(bc)) === 1
868+
@test bc[1] === bc[CartesianIndex((1,))] === 5.0
869+
@test copy(bc) == [v for v in bc] == collect(bc)
870+
@test eltype(copy(bc)) == eltype([v for v in bc]) == eltype(collect(bc))
871+
@test ndims(copy(bc)) == ndims([v for v in bc]) == ndims(collect(bc)) == ndims(bc)
872+
873+
b = f(5*ones(1, 4))
874+
bc = Broadcast.instantiate(Broadcast.broadcasted(+, a, b))
875+
@test IndexStyle(bc) == IndexCartesian()
876+
@test eachindex(bc) === CartesianIndices((axes(a, 1), axes(b, 2)))
877+
@test length(bc) === 20
878+
@test ndims(bc) === 2
879+
@test ndims(typeof(bc)) === 2
880+
@test bc[1,1] == bc[CartesianIndex((1,1))] === 5.0
881+
@test copy(bc) == [v for v in bc] == collect(bc)
882+
@test eltype(copy(bc)) == eltype([v for v in bc]) == eltype(collect(bc))
883+
@test ndims(copy(bc)) == ndims([v for v in bc]) == ndims(collect(bc)) == ndims(bc)
884+
end
879885

880886
struct MyFill{T,N} <: AbstractArray{T,N}
881887
val :: T
@@ -1118,24 +1124,14 @@ end
11181124
end
11191125

11201126
@testset "inplace broadcast with trailing singleton dims" begin
1121-
for (a, b, c) in (([1, 2], reshape([3 4], :, 1), reshape([5, 6], :, 1, 1)),
1127+
for (a_, b_, c_) in (([1, 2], reshape([3 4], :, 1), reshape([5, 6], :, 1, 1)),
11221128
([1 2; 3 4], reshape([5 6; 7 8], 2, 2, 1), reshape([9 10; 11 12], 2, 2, 1, 1)))
1123-
1124-
a_ = copy(a)
1125-
a_ .= b
1126-
@test a_ == dropdims(b, dims=(findall(==(1), size(b))...,))
1127-
1128-
a_ = copy(a)
1129-
a_ .= b
1130-
@test a_ == dropdims(b, dims=(findall(==(1), size(b))...,))
1131-
1132-
a_ = copy(a)
1133-
a_ .= b .+ c
1134-
@test a_ == dropdims(b .+ c, dims=(findall(==(1), size(c))...,))
1135-
1136-
a_ = copy(a)
1137-
a_ .*= c
1138-
@test a_ == dropdims(a .* c, dims=(findall(==(1), size(c))...,))
1129+
for fun in (x -> OffsetArray(x, ntuple(Returns(1), ndims(x))), identity)
1130+
a, b, c = fun(a_), fun(b_), fun(c_)
1131+
@test (deepcopy(a) .= b) == dropdims(b, dims=(findall(==(1), size(b))...,))
1132+
@test (deepcopy(a) .= b .+ c) == dropdims(b .+ c, dims=(findall(==(1), size(c))...,))
1133+
@test (deepcopy(a) .*= c) == dropdims(a .* c, dims=(findall(==(1), size(c))...,))
1134+
end
11391135
end
11401136
end
11411137

0 commit comments

Comments
 (0)