@@ -209,17 +209,32 @@ end
209209
210210function JACC. parallel_reduce (
211211 :: AMDGPUBackend , N:: Integer , op, f:: Function , x... ; init)
212- numThreads = 512
213- threads = numThreads
214- blocks = ceil (Int, N / threads)
215- ret = fill! (AMDGPU. ROCArray {typeof(init)} (undef, blocks), init)
212+ ret_inst = AMDGPU. ROCArray {typeof(init)} (undef, 0 )
213+ kernel1 = @roc launch= false _parallel_reduce_amdgpu (
214+ N, op, ret_inst, f, x... )
215+ config1 = AMDGPU. launch_configuration (kernel1)
216+ threads1 = config1. groupsize
217+
216218 rret = AMDGPU. ROCArray ([init])
217- @roc groupsize= threads gridsize= blocks _parallel_reduce_amdgpu (
218- N, op, ret, f, x... )
219+ kernel2 = @roc launch= false reduce_kernel_amdgpu (1 , op, ret_inst, rret)
220+ config2 = AMDGPU. launch_configuration (kernel2)
221+ threads2 = config2. groupsize
222+
223+ threads = min (threads1, threads2, 512 )
224+ blocks = cld (N, threads)
225+
226+ shmem_size = threads * sizeof (Float64)
227+
228+ ret = fill! (AMDGPU. ROCArray {typeof(init)} (undef, blocks), init)
229+
230+ kernel1 (
231+ N, op, ret, f, x... ; groupsize = threads,
232+ gridsize = blocks, shmem = shmem_size)
219233 AMDGPU. synchronize ()
220- @roc groupsize= threads gridsize = 1 reduce_kernel_amdgpu (
221- blocks, op, ret, rret )
234+ kernel2 (blocks, op, ret, rret; groupsize = threads,
235+ gridsize = 1 , shmem = shmem_size )
222236 AMDGPU. synchronize ()
237+
223238 return Base. Array (rret)[]
224239end
225240
@@ -268,7 +283,8 @@ function _parallel_for_amdgpu_LMN((L, M, N), f, x...)
268283end
269284
270285function _parallel_reduce_amdgpu (N, op, ret, f, x... )
271- shared_mem = @ROCStaticLocalArray (eltype (ret), 512 )
286+ shmem_length = workgroupDim (). x
287+ shared_mem = @ROCDynamicLocalArray (eltype (ret), shmem_length)
272288 i = (workgroupIdx (). x - 1 ) * workgroupDim (). x + workitemIdx (). x
273289 ti = workitemIdx (). x
274290 shared_mem[ti] = ret[workgroupIdx (). x]
@@ -278,97 +294,53 @@ function _parallel_reduce_amdgpu(N, op, ret, f, x...)
278294 shared_mem[ti] = tmp
279295 end
280296 AMDGPU. sync_workgroup ()
281- if (ti <= 256 )
282- shared_mem[ti] = op (shared_mem[ti], shared_mem[ti + 256 ])
283- end
284- AMDGPU. sync_workgroup ()
285- if (ti <= 128 )
286- shared_mem[ti] = op (shared_mem[ti], shared_mem[ti + 128 ])
287- end
288- AMDGPU. sync_workgroup ()
289- if (ti <= 64 )
290- shared_mem[ti] = op (shared_mem[ti], shared_mem[ti + 64 ])
291- end
292- AMDGPU. sync_workgroup ()
293- if (ti <= 32 )
294- shared_mem[ti] = op (shared_mem[ti], shared_mem[ti + 32 ])
295- end
296- AMDGPU. sync_workgroup ()
297- if (ti <= 16 )
298- shared_mem[ti] = op (shared_mem[ti], shared_mem[ti + 16 ])
299- end
300- AMDGPU. sync_workgroup ()
301- if (ti <= 8 )
302- shared_mem[ti] = op (shared_mem[ti], shared_mem[ti + 8 ])
303- end
304- AMDGPU. sync_workgroup ()
305- if (ti <= 4 )
306- shared_mem[ti] = op (shared_mem[ti], shared_mem[ti + 4 ])
307- end
308- AMDGPU. sync_workgroup ()
309- if (ti <= 2 )
310- shared_mem[ti] = op (shared_mem[ti], shared_mem[ti + 2 ])
297+ tn = div (shmem_length, 2 )
298+ while tn > 1
299+ if ti <= tn
300+ shared_mem[ti] = op (shared_mem[ti], shared_mem[ti + tn])
301+ end
302+ AMDGPU. sync_workgroup ()
303+ tn = div (tn, 2 )
311304 end
312- AMDGPU. sync_workgroup ()
313- if (ti == 1 )
305+ if ti == 1
314306 shared_mem[ti] = op (shared_mem[ti], shared_mem[ti + 1 ])
315307 ret[workgroupIdx (). x] = shared_mem[ti]
316308 end
317- AMDGPU . sync_workgroup ()
309+
318310 return nothing
319311end
320312
321313function reduce_kernel_amdgpu (N, op, red, ret)
322- shared_mem = @ROCStaticLocalArray (eltype (ret), 512 )
314+ shmem_length = workgroupDim (). x
315+ shared_mem = @ROCDynamicLocalArray (eltype (ret), shmem_length)
323316 i = workitemIdx (). x
324317 ii = i
325318 tmp = ret[1 ]
326- if N > 512
319+ if N > shmem_length
327320 while ii <= N
328321 tmp = op (tmp, @inbounds red[ii])
329- ii += 512
322+ ii += shmem_length
330323 end
331324 elseif (i <= N)
332325 tmp = @inbounds red[i]
333326 end
334327 shared_mem[i] = tmp
335328 AMDGPU. sync_workgroup ()
336- if (i <= 256 )
337- shared_mem[i] = op (shared_mem[i], shared_mem[i + 256 ])
338- end
339- AMDGPU. sync_workgroup ()
340- if (i <= 128 )
341- shared_mem[i] = op (shared_mem[i], shared_mem[i + 128 ])
342- end
343- AMDGPU. sync_workgroup ()
344- if (i <= 64 )
345- shared_mem[i] = op (shared_mem[i], shared_mem[i + 64 ])
346- end
347- AMDGPU. sync_workgroup ()
348- if (i <= 32 )
349- shared_mem[i] = op (shared_mem[i], shared_mem[i + 32 ])
350- end
351- AMDGPU. sync_workgroup ()
352- if (i <= 16 )
353- shared_mem[i] = op (shared_mem[i], shared_mem[i + 16 ])
354- end
355- AMDGPU. sync_workgroup ()
356- if (i <= 8 )
357- shared_mem[i] = op (shared_mem[i], shared_mem[i + 8 ])
358- end
359- AMDGPU. sync_workgroup ()
360- if (i <= 4 )
361- shared_mem[i] = op (shared_mem[i], shared_mem[i + 4 ])
362- end
363- AMDGPU. sync_workgroup ()
364- if (i <= 2 )
365- shared_mem[i] = op (shared_mem[i], shared_mem[i + 2 ])
329+
330+ tn = div (shmem_length, 2 )
331+ while tn > 1
332+ if i <= tn
333+ shared_mem[i] = op (shared_mem[i], shared_mem[i + tn])
334+ end
335+ AMDGPU. sync_workgroup ()
336+ tn = div (tn, 2 )
366337 end
367- AMDGPU . sync_workgroup ()
368- if ( i == 1 )
338+
339+ if i == 1
369340 shared_mem[i] = op (shared_mem[i], shared_mem[i + 1 ])
370341 ret[1 ] = shared_mem[1 ]
371342 end
343+
372344 return nothing
373345end
374346
0 commit comments