2020@inline  function  reduce_group (op, val:: T , neutral, shuffle:: Val{true} , :: Val{maxthreads} ) where  {T, maxthreads}
2121    #  shared mem for partial sums
2222    assume (threads_per_simdgroup () ==  32 )
23-     shared =  MtlThreadGroupArray (T, 32 )
23+     shared =  KI . localmemory (T, 32 )
2424
2525    wid  =  simdgroup_index_in_threadgroup ()
2626    lane =  thread_index_in_simdgroup ()
3434    end 
3535
3636    #  wait for all partial reductions
37-     threadgroup_barrier (MemoryFlagThreadGroup )
37+     KI . barrier ( )
3838
3939    #  read from shared memory only if that warp existed
40-     val =  if  thread_index_in_threadgroup ()  <=  fld1 (threads_per_threadgroup (). x, 32 )
40+     val =  if  KI . get_local_id () . x  <=  fld1 (KI . get_local_size (). x, 32 )
4141        @inbounds  shared[lane]
4242    else 
4343        neutral
5252
5353#  Reduce a value across a group, using local memory for communication
5454@inline  function  reduce_group (op, val:: T , neutral, shuffle:: Val{false} , :: Val{maxthreads} ) where  {T, maxthreads}
55-     threads =  threads_per_threadgroup (). x
56-     thread =  thread_position_in_threadgroup (). x
55+     threads =  KI . get_local_size (). x
56+     thread =  KI . get_local_id (). x
5757
5858    #  local mem for a complete reduction
59-     shared =  MtlThreadGroupArray (T, (maxthreads,))
59+     shared =  KI . localmemory (T, (maxthreads,))
6060    @inbounds  shared[thread] =  val
6161
6262    #  perform a reduction
6363    d =  1 
6464    while  d <  threads
65-         threadgroup_barrier (MemoryFlagThreadGroup )
65+         KI . barrier ( )
6666        index =  2  *  d *  (thread- 1 ) +  1 
6767        @inbounds  if  index <=  threads
6868            other_val =  if  index +  d <=  threads
@@ -94,9 +94,9 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
9494    :: Val{Rother} , :: Val{Rlen} , :: Val{grain} , shuffle, R, As... ) where  {Rreduce, Rother, Rlen, grain}
9595    #  decompose the 1D hardware indices into separate ones for reduction (across items
9696    #  and possibly groups if it doesn't fit) and other elements (remaining groups)
97-     localIdx_reduce =  thread_position_in_threadgroup (). x
98-     localDim_reduce =  threads_per_threadgroup (). x *  grain
99-     groupIdx_reduce, groupIdx_other =  fldmod1 (threadgroup_position_in_grid (). x, Rlen)
97+     localIdx_reduce =  KI . get_local_id (). x
98+     localDim_reduce =  KI . get_local_size (). x *  grain
99+     groupIdx_reduce, groupIdx_other =  fldmod1 (KI . get_group_id (). x, Rlen)
100100
101101    #  group-based indexing into the values outside of the reduction dimension
102102    #  (that means we can safely synchronize items within this group)
@@ -141,7 +141,7 @@ function partial_mapreduce_device(f, op, neutral, maxthreads, ::Val{Rreduce},
141141end 
142142
143143function  serial_mapreduce_kernel (f, op, neutral, :: Val{Rreduce} , :: Val{Rother} , R, As) where  {Rreduce, Rother}
144-     grid_idx =  thread_position_in_grid (). x
144+     grid_idx =  KI . get_global_id (). x
145145
146146    @inbounds  if  grid_idx <=  length (Rother)
147147        Iother =  Rother[grid_idx]
@@ -166,11 +166,12 @@ end
166166
167167# # COV_EXCL_STOP
168168
# Threshold that `mapreducedim!` compares against `length(Rother)`: when `length(Rother)`
# is at or above this value, the serial kernel is launched instead of partial reductions.
# NOTE(review): the `+` version no longer reads `dev`; it queries `MetalBackend()` for both
# factors instead -- confirm that ignoring the argument is intended.
169- serial_mapreduce_threshold (dev) =  dev . maxThreadsPerThreadgroup . width  *  num_gpu_cores ( )
169+ serial_mapreduce_threshold (dev) =  KI . max_work_group_size ( MetalBackend ())  *  KI . multiprocessor_count ( MetalBackend () )
170170
171171function  GPUArrays. mapreducedim! (f:: F , op:: OP , R:: WrappedMtlArray{T} ,
172172                                 A:: Union{AbstractArray,Broadcast.Broadcasted} ;
173173                                 init= nothing ) where  {F, OP, T}
174+     backend =  MetalBackend ()
174175    Base. check_reducedims (R, A)
175176    length (A) ==  0  &&  return  R #  isempty(::Broadcasted) iterates
176177
@@ -195,10 +196,10 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
195196
196197    #  If `Rother` is large enough, then a naive loop is more efficient than partial reductions.
197198    if  length (Rother) >=  serial_mapreduce_threshold (device (R))
198-         kernel =  @metal  launch = false   serial_mapreduce_kernel ( f, op, init, Val (Rreduce), Val (Rother), R, A)
199-         threads =  min ( length (Rother), kernel . pipeline . maxTotalThreadsPerThreadgroup )
199+         kernel =  KI . KIKernel (backend,  serial_mapreduce_kernel,  f, op, init, Val (Rreduce), Val (Rother), R, A)
200+         threads =  KI . kernel_max_work_group_size (backend, kernel; max_work_items = length (Rother))
200201        groups =  cld (length (Rother), threads)
201-         kernel (f, op, init, Val (Rreduce), Val (Rother), R, A; threads, groups )
202+         kernel (f, op, init, Val (Rreduce), Val (Rother), R, A; numworkgroups = groups, workgroupsize = threads )
202203        return  R
203204    end 
204205
@@ -223,17 +224,17 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
223224    #  we might not be able to launch all those threads to reduce each slice in one go.
224225    #  that's why each thread also loops across its inputs, processing multiple values
225226    #  so that we can span the entire reduction dimension using a single item group.
226-     kernel =  @metal  launch = false   partial_mapreduce_device ( f, op, init, Val (maxthreads), Val (Rreduce), Val (Rother),
227+     kernel =  KI . KIKernel (backend,  partial_mapreduce_device,  f, op, init, Val (maxthreads), Val (Rreduce), Val (Rother),
227228                                                          Val (UInt64 (length (Rother))), Val (grain), Val (shuffle), R, A)
228229
229230    #  how many threads do we want?
230231    # 
231232    #  threads in a group work together to reduce values across the reduction dimensions;
232233    #  we want as many as possible to improve algorithm efficiency and execution occupancy.
233-     wanted_threads =  shuffle ?  nextwarp (kernel. pipeline, length (Rreduce)) :  length (Rreduce)
234+     wanted_threads =  shuffle ?  nextwarp (kernel. kern . pipeline, length (Rreduce)) :  length (Rreduce)
234235    function  compute_threads (max_threads)
235236        if  wanted_threads >  max_threads
236-             shuffle ?  prevwarp (kernel. pipeline, max_threads) :  max_threads
237+             shuffle ?  prevwarp (kernel. kern . pipeline, max_threads) :  max_threads
237238        else 
238239            wanted_threads
239240        end 
@@ -243,7 +244,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
243244    #          kernel above may be greater than the maxTotalThreadsPerThreadgroup of the eventually launched
244245    #          kernel below, causing errors
245246    #  reduce_threads = compute_threads(kernel.pipeline.maxTotalThreadsPerThreadgroup)
246-     reduce_threads =  compute_threads (512 )
247+     reduce_threads =  compute_threads (KI . kernel_max_work_group_size (backend, kernel) )
247248
248249    #  how many groups should we launch?
249250    # 
@@ -262,7 +263,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
262263        #  we can cover the dimensions to reduce using a single group
263264        kernel (f, op, init, Val (maxthreads), Val (Rreduce), Val (Rother),
264265               Val (UInt64 (length (Rother))), Val (grain), Val (shuffle), R, A;
265-                threads,  groups)
266+                 numworkgroups = groups, workgroupsize = threads )
266267    else 
267268        #  we need multiple steps to cover all values to reduce
268269        partial =  similar (R, (size (R)... , reduce_groups))
@@ -273,9 +274,12 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::WrappedMtlArray{T},
273274        end 
274275        #  NOTE: we can't use the previously-compiled kernel, since the type of `partial`
275276        #        might not match the original output container (e.g. if that was a view).
276-         @metal  threads groups  partial_mapreduce_device ( 
277+         KI . KIKernel (backend,  partial_mapreduce_device, 
277278            f, op, init, Val (threads), Val (Rreduce), Val (Rother),
278-             Val (UInt64 (length (Rother))), Val (grain), Val (shuffle), partial, A)
279+             Val (UInt64 (length (Rother))), Val (grain), Val (shuffle), partial, A)(
280+             f, op, init, Val (threads), Val (Rreduce), Val (Rother),
281+             Val (UInt64 (length (Rother))), Val (grain), Val (shuffle), partial, A;
282+             numworkgroups= groups, workgroupsize= threads)
279283
280284        GPUArrays. mapreducedim! (identity, op, R, partial; init= init)
281285    end 
0 commit comments