@@ -140,8 +140,8 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
     return
 end
 
-function big_mapreduce_kernel(f, op, neutral, ::Val{Rreduce}, ::Val{Rother}, R, As) where {Rreduce, Rother}
-    grid_idx = thread_position_in_threadgroup_1d() + (threadgroup_position_in_grid_1d() - 1u32) * threadgroups_per_grid_1d()
+function serial_mapreduce_kernel(f, op, neutral, ::Val{Rreduce}, ::Val{Rother}, R, As) where {Rreduce, Rother}
+    grid_idx = thread_position_in_grid_1d()
 
     @inbounds if grid_idx <= length(Rother)
         Iother = Rother[grid_idx]
@@ -166,7 +166,7 @@
 
 ## COV_EXCL_STOP
 
-_big_mapreduce_threshold(dev) = dev.maxThreadsPerThreadgroup.width * num_gpu_cores()
+serial_mapreduce_threshold(dev) = dev.maxThreadsPerThreadgroup.width * num_gpu_cores()
 
 function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
                                  A::Union{AbstractArray,Broadcast.Broadcasted};
@@ -194,10 +194,11 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
     @assert length(Rother) > 0
 
     # If `Rother` is large enough, then a naive loop is more efficient than partial reductions.
-    if length(Rother) >= _big_mapreduce_threshold(device(R))
-        threads = min(length(Rreduce), 512)
+    if length(Rother) >= serial_mapreduce_threshold(device(R))
+        kernel = @metal launch=false serial_mapreduce_kernel(f, op, init, Val(Rreduce), Val(Rother), R, A)
+        threads = min(length(Rother), kernel.pipeline.maxTotalThreadsPerThreadgroup)
         groups = cld(length(Rother), threads)
-        kernel = @metal threads groups big_mapreduce_kernel(f, op, init, Val(Rreduce), Val(Rother), R, A)
+        kernel(f, op, init, Val(Rreduce), Val(Rother), R, A; threads, groups)
         return R
     end
 
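A minimal sketch of the launch pattern adopted above, assuming a hypothetical `scale_kernel`: compile with `@metal launch=false`, size the threadgroup from the pipeline's `maxTotalThreadsPerThreadgroup`, then launch the compiled kernel with explicit `threads`/`groups` keywords.

using Metal

# Hypothetical kernel, used only to illustrate the launch pattern.
function scale_kernel(a, factor)
    i = thread_position_in_grid_1d()
    if i <= length(a)
        @inbounds a[i] *= factor
    end
    return
end

a = MtlArray(rand(Float32, 10_000))

# Compile without launching, then derive the launch configuration
# from the pipeline's occupancy limit, as in the diff above.
kernel = @metal launch=false scale_kernel(a, 2f0)
threads = min(length(a), kernel.pipeline.maxTotalThreadsPerThreadgroup)
groups = cld(length(a), threads)
kernel(a, 2f0; threads, groups)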