@@ -31,19 +31,6 @@ Base.mapreduce(f, op, A::Broadcast.Broadcasted{<:AbstractGPUArrayStyle}, As::Abs
3131 dims= :, init= nothing ) = _mapreduce(f, op, A, As... ; dims= dims, init= init)
3232
3333function _mapreduce(f:: F , op:: OP , As:: Vararg{Any,N} ; dims:: D , init) where {F,OP,N,D}
34- # mapreduce should apply `f` like `map` does, consuming elements like iterators
35- bc = if allequal(size.(As). .. )
36- Broadcast. instantiate(Broadcast. broadcasted(f, As... ))
37- else
38- # TODO : can we avoid the reshape + view?
39- indices = LinearIndices.(As)
40- common_length = minimum(length.(indices))
41- Bs = map(As) do A
42- view(reshape(A, length(A)), 1 : common_length)
43- end
44- Broadcast. instantiate(Broadcast. broadcasted(f, Bs... ))
45- end
46-
4734 # figure out the destination container type by looking at the initializer element,
4835 # or by relying on inference to reason through the map and reduce functions
4936 if init === nothing
@@ -57,16 +44,39 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
5744 ET = typeof(init)
5845 end
5946
60- sz = size(bc)
47+ # apply the mapping function to the input arrays
48+ if N == 1
49+ # ... with only a single input, we can defer this to the reduce step
50+ A = only(As)
51+ else
52+ # mapreduce should apply `f` like `map` does, consuming elements like iterators
53+ A = if allequal(size.(As). .. )
54+ Broadcast. instantiate(Broadcast. broadcasted(f, As... ))
55+ else
56+ # TODO : can we avoid the reshape + view?
57+ indices = LinearIndices.(As)
58+ common_length = minimum(length.(indices))
59+ Bs = map(As) do A
60+ view(reshape(A, length(A)), 1 : common_length)
61+ end
62+ Broadcast. instantiate(Broadcast. broadcasted(f, Bs... ))
63+ end
64+ f = identity
65+ end
66+
67+ # allocate an output container
68+ sz = size(A)
6169 red = ntuple(i-> (dims== Colon() || i in dims) ? 1 : sz[i], length(sz))
62- R = similar(bc , ET, red)
70+ R = similar(A , ET, red)
6371
72+ # perform the reduction
6473 if prod(sz) == 0
6574 fill!(R, init)
6675 else
67- mapreducedim!(identity , op, R, bc; init = init)
76+ mapreducedim!(f , op, R, A; init)
6877 end
6978
79+ # return the result
7080 if dims === Colon()
7181 @allowscalar R[]
7282 else
0 commit comments