@@ -140,8 +140,34 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
     return
 end
 
+function big_mapreduce_kernel(f, op, neutral, ::Val{Rreduce}, ::Val{Rother}, R, As) where {Rreduce, Rother}
+    # linear index of this thread across the whole grid; one thread per element of `Rother`
+    grid_idx = thread_position_in_threadgroup_1d() + (threadgroup_position_in_grid_1d() - 1u32) * threads_per_threadgroup_1d()
+
+    @inbounds if grid_idx <= length(Rother)
+        Iother = Rother[grid_idx]
+
+        # load the neutral value
+        neutral = if neutral === nothing
+            R[Iother]
+        else
+            neutral
+        end
+
+        val = op(neutral, neutral)
+
+        Ibegin = Rreduce[1]
+        for Ireduce in Rreduce
+            val = op(val, f(As[Iother + Ireduce - Ibegin]))
+        end
+        R[Iother] = val
+    end
+    return
+end
+
 ## COV_EXCL_STOP
 
+_big_mapreduce_threshold(dev) = dev.maxThreadsPerThreadgroup.width * num_gpu_cores()
+
 function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
                                  A::Union{AbstractArray,Broadcast.Broadcasted};
                                  init=nothing) where {F, OP, T}
@@ -165,6 +191,15 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     # NOTE: we hard-code `OneTo` (`first.(axes(A))` would work too) or we get a
     # CartesianIndices object with UnitRanges that behave badly on the GPU.
     @assert length(Rall) == length(Rother) * length(Rreduce)
+    @assert length(Rother) > 0
+
+    # If `Rother` is large enough, then a naive loop is more efficient than partial reductions.
+    if length(Rother) >= _big_mapreduce_threshold(device(R))
+        threads = min(length(Rreduce), 512)
+        groups = cld(length(Rother), threads)
+        kernel = @metal threads groups big_mapreduce_kernel(f, op, init, Val(Rreduce), Val(Rother), R, A)
+        return R
+    end
 
     # when the reduction dimension is contiguous in memory, we can improve performance
     # by having each thread read multiple consecutive elements. based on experiments,
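
As a rough illustration (not part of the diff above), here is a minimal sketch of the kind of reduction the new single-pass path targets, assuming the standard Metal.jl array API and a device whose `maxThreadsPerThreadgroup.width * num_gpu_cores()` threshold is on the order of a few thousand: many independent output elements (`Rother`), each reducing over only a handful of values (`Rreduce`).

    using Metal

    # Rreduce spans the 4 rows, Rother the 2 million columns.
    A = MtlArray(rand(Float32, 4, 2_000_000))

    # length(Rother) far exceeds the threshold, so GPUArrays.mapreducedim!
    # takes the naive-loop branch and launches big_mapreduce_kernel instead
    # of the partial-reduction kernel.
    R = sum(A; dims=1)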