Skip to content

Commit f979ee9

Browse files
authored
Fix sum(bc::Broadcasted; dims = 1, init = 0) (#43736)
This PR make `has_fast_linear_indexing` rely on `IndexStyle`/`ndims` to fix `mapreduce` for `Broadcasted` with `dim > 1`. Before: ```julia julia> a = randn(100,100); julia> bc = Broadcast.instantiate(Base.broadcasted(+,a,a)); julia> sum(bc,dims = 1,init = 0.0) == sum(collect(bc), dims = 1) ERROR: MethodError: no method matching LinearIndices(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(+), Tuple{Matrix{Float64}, Matrix{Float64}}}) ``` After: ```julia julia> sum(bc,dims = 1,init = 0.0) == sum(collect(bc), dims = 1) true ``` This should extend the optimized fallback to more `AbstractArray`. (e.g. `SubArray`) Test added.
1 parent 20b3af3 commit f979ee9

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

base/reducedim.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,8 @@ end
196196

197197
## generic (map)reduction
198198

199-
has_fast_linear_indexing(a::AbstractArrayOrBroadcasted) = false
200-
has_fast_linear_indexing(a::Array) = true
201-
has_fast_linear_indexing(::Union{Number,Ref,AbstractChar}) = true # 0d objects, for Broadcasted
202-
has_fast_linear_indexing(bc::Broadcast.Broadcasted) =
203-
all(has_fast_linear_indexing, bc.args)
199+
has_fast_linear_indexing(a::AbstractArrayOrBroadcasted) = IndexStyle(a) === IndexLinear()
200+
has_fast_linear_indexing(a::AbstractVector) = true
204201

205202
function check_reducedims(R, A)
206203
# Check whether R has compatible dimensions w.r.t. A for reduction

test/broadcast.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -979,6 +979,10 @@ end
979979
@test sum(bc, dims=1, init=0) == [5]
980980
bc = Broadcast.instantiate(Broadcast.broadcasted(*, ['a','b'], 'c'))
981981
@test prod(bc, dims=1, init="") == ["acbc"]
982+
983+
a = rand(-10:10,32,4); b = rand(-10:10,32,4)
984+
bc = Broadcast.instantiate(Broadcast.broadcasted(+,a,b))
985+
@test sum(bc; dims = 1, init = 0.0) == sum(collect(bc); dims = 1, init = 0.0)
982986
end
983987

984988
# treat Pair as scalar:

0 commit comments

Comments
 (0)