Skip to content

Commit 8e78498

Browse files
author
Diptorup Deb
committed
Port sum_reduction_recursive to kernel API.
1 parent ddeefba commit 8e78498

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

numba_dpex/examples/kernel/sum_reduction_recursive.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818

1919
@ndpx.kernel
2020
def sum_reduction_kernel(nditem: kapi.NdItem, A, input_size, partial_sums, slm):
21-
local_id = ndpx.get_local_id(0)
22-
global_id = ndpx.get_global_id(0)
23-
group_size = ndpx.get_local_size(0)
24-
group_id = ndpx.get_group_id(0)
25-
21+
local_id = nditem.get_local_id(0)
22+
global_id = nditem.get_global_id(0)
23+
group_size = nditem.get_local_range(0)
24+
gr = nditem.get_group()
25+
group_id = gr.get_group_id(0)
2626
slm[local_id] = 0
2727

2828
if global_id < input_size:
@@ -32,7 +32,7 @@ def sum_reduction_kernel(nditem: kapi.NdItem, A, input_size, partial_sums, slm):
3232
stride = group_size // 2
3333
while stride > 0:
3434
# Waiting for each 2x2 addition into given workgroup
35-
ndpx.barrier(ndpx.LOCAL_MEM_FENCE)
35+
kapi.group_barrier(gr)
3636

3737
# Add elements 2 by 2 between local_id and local_id + stride
3838
if local_id < stride:

0 commit comments

Comments
 (0)