Skip to content

Commit ff3f0e9

Browse files
committed
Use computed occupancy for amdgpu parallel_reduce
1 parent 43ef00f commit ff3f0e9

File tree

1 file changed

+49
-77
lines changed

1 file changed

+49
-77
lines changed

ext/JACCAMDGPU/JACCAMDGPU.jl

Lines changed: 49 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -209,17 +209,32 @@ end
209209

210210
function JACC.parallel_reduce(
        ::AMDGPUBackend, N::Integer, op, f::Function, x...; init)
    # Two-stage reduction: `_parallel_reduce_amdgpu` produces one partial
    # result per workgroup, then `reduce_kernel_amdgpu` folds those partials
    # (plus `init`, pre-loaded into `rret`) into a single value using one
    # workgroup. Workgroup sizes come from the occupancy reported by
    # `AMDGPU.launch_configuration` for each compiled kernel.
    ret_inst = AMDGPU.ROCArray{typeof(init)}(undef, 0)
    kernel1 = @roc launch=false _parallel_reduce_amdgpu(
        N, op, ret_inst, f, x...)
    config1 = AMDGPU.launch_configuration(kernel1)
    threads1 = config1.groupsize

    rret = AMDGPU.ROCArray([init])
    kernel2 = @roc launch=false reduce_kernel_amdgpu(1, op, ret_inst, rret)
    config2 = AMDGPU.launch_configuration(kernel2)
    threads2 = config2.groupsize

    # Both kernels perform a binary-tree fold in local memory, which is only
    # correct when the workgroup size is a power of two. The computed
    # occupancy is not guaranteed to be one, so round down to the nearest
    # power of two after clamping to the 512 upper bound.
    threads = prevpow(2, min(threads1, threads2, 512))
    blocks = cld(N, threads)

    # Dynamic local memory: one element per work-item. Size it by the actual
    # reduction element type rather than a hard-coded Float64, so reductions
    # over wider element types do not under-allocate local memory.
    shmem_size = threads * sizeof(typeof(init))

    ret = fill!(AMDGPU.ROCArray{typeof(init)}(undef, blocks), init)

    kernel1(N, op, ret, f, x...; groupsize = threads,
        gridsize = blocks, shmem = shmem_size)
    AMDGPU.synchronize()
    kernel2(blocks, op, ret, rret; groupsize = threads,
        gridsize = 1, shmem = shmem_size)
    AMDGPU.synchronize()

    # Copy the single-element device result back and unwrap it.
    return Base.Array(rret)[]
end
225240

@@ -268,7 +283,8 @@ function _parallel_for_amdgpu_LMN((L, M, N), f, x...)
268283
end
269284

270285
function _parallel_reduce_amdgpu(N, op, ret, f, x...)
271-
shared_mem = @ROCStaticLocalArray(eltype(ret), 512)
286+
shmem_length = workgroupDim().x
287+
shared_mem = @ROCDynamicLocalArray(eltype(ret), shmem_length)
272288
i = (workgroupIdx().x - 1) * workgroupDim().x + workitemIdx().x
273289
ti = workitemIdx().x
274290
shared_mem[ti] = ret[workgroupIdx().x]
@@ -278,97 +294,53 @@ function _parallel_reduce_amdgpu(N, op, ret, f, x...)
278294
shared_mem[ti] = tmp
279295
end
280296
AMDGPU.sync_workgroup()
281-
if (ti <= 256)
282-
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 256])
283-
end
284-
AMDGPU.sync_workgroup()
285-
if (ti <= 128)
286-
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 128])
287-
end
288-
AMDGPU.sync_workgroup()
289-
if (ti <= 64)
290-
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 64])
291-
end
292-
AMDGPU.sync_workgroup()
293-
if (ti <= 32)
294-
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 32])
295-
end
296-
AMDGPU.sync_workgroup()
297-
if (ti <= 16)
298-
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 16])
299-
end
300-
AMDGPU.sync_workgroup()
301-
if (ti <= 8)
302-
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 8])
303-
end
304-
AMDGPU.sync_workgroup()
305-
if (ti <= 4)
306-
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 4])
307-
end
308-
AMDGPU.sync_workgroup()
309-
if (ti <= 2)
310-
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 2])
297+
tn = div(shmem_length, 2)
298+
while tn > 1
299+
if ti <= tn
300+
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + tn])
301+
end
302+
AMDGPU.sync_workgroup()
303+
tn = div(tn, 2)
311304
end
312-
AMDGPU.sync_workgroup()
313-
if (ti == 1)
305+
if ti == 1
314306
shared_mem[ti] = op(shared_mem[ti], shared_mem[ti + 1])
315307
ret[workgroupIdx().x] = shared_mem[ti]
316308
end
317-
AMDGPU.sync_workgroup()
309+
318310
return nothing
319311
end
320312

321313
function reduce_kernel_amdgpu(N, op, red, ret)
    # Final reduction stage: a single workgroup folds the N per-workgroup
    # partials in `red` — together with the initial value already stored in
    # `ret[1]` — back into `ret[1]`.
    wg = workgroupDim().x
    shared = @ROCDynamicLocalArray(eltype(ret), wg)
    tid = workitemIdx().x

    # Serial accumulation phase: when there are more partials than
    # work-items, each work-item strides through `red` with step `wg`,
    # starting from the value in `ret[1]` (the caller-supplied init).
    acc = ret[1]
    if N > wg
        idx = tid
        while idx <= N
            acc = op(acc, @inbounds red[idx])
            idx += wg
        end
    elseif tid <= N
        acc = @inbounds red[tid]
    end
    shared[tid] = acc
    AMDGPU.sync_workgroup()

    # Binary-tree fold in local memory. NOTE(review): correct only when `wg`
    # is a power of two — with any other size, `div(wg, 2)` strides skip
    # elements; the host launch is expected to guarantee this.
    stride = div(wg, 2)
    while stride > 1
        if tid <= stride
            shared[tid] = op(shared[tid], shared[tid + stride])
        end
        AMDGPU.sync_workgroup()
        stride = div(stride, 2)
    end

    # Work-item 1 performs the last pairwise fold and publishes the result.
    if tid == 1
        shared[tid] = op(shared[tid], shared[tid + 1])
        ret[1] = shared[1]
    end

    return nothing
end
374346

0 commit comments

Comments
 (0)