Skip to content

Commit e998f72

Browse files
author
Diptorup Deb
authored
Merge pull request #1330 from IntelPython/feature/group_id_overloads
Adds overloads for group indexing functions
2 parents 6290156 + a106444 commit e998f72

File tree

3 files changed

+121
-9
lines changed

3 files changed

+121
-9
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,26 @@ def _intrinsic_spirv_workgroup_size(
123123
)
124124

125125

126+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
127+
def _intrinsic_spirv_workgroup_id(
128+
ty_context, ty_dim # pylint: disable=unused-argument
129+
):
130+
"""Generates instruction to get index from BuiltInWorkgroupId."""
131+
return _intrinsic_spirv_global_index_const(
132+
ty_context, ty_dim, "BuiltInWorkgroupId"
133+
)
134+
135+
136+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
137+
def _intrinsic_spirv_numworkgroups(
138+
ty_context, ty_dim # pylint: disable=unused-argument
139+
):
140+
"""Generates instruction to get index from BuiltInNumWorkgroups."""
141+
return _intrinsic_spirv_global_index_const(
142+
ty_context, ty_dim, "BuiltInNumWorkgroups"
143+
)
144+
145+
126146
def generate_index_overload(_type, _intrinsic):
127147
"""Generates overload for the index method that generates specific IR from
128148
provided intrinsic."""
@@ -167,6 +187,9 @@ def ol_item_get_index_impl(item, dim):
167187
(NdItemType, "get_local_id", _intrinsic_spirv_local_invocation_id),
168188
(NdItemType, "get_global_range", _intrinsic_spirv_global_size),
169189
(NdItemType, "get_local_range", _intrinsic_spirv_workgroup_size),
190+
(GroupType, "get_group_id", _intrinsic_spirv_workgroup_id),
191+
(GroupType, "get_group_range", _intrinsic_spirv_numworkgroups),
192+
(GroupType, "get_local_range", _intrinsic_spirv_workgroup_size),
170193
]
171194

172195
for index_overload in _index_const_overload_methods:

numba_dpex/kernel_api/index_space_ids.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,11 @@ def get_group_linear_id(self):
5252
+ (self._index[2])
5353
)
5454

55-
def get_group_range(self):
56-
"""Returns a range representing the number of groups in the nd-range."""
57-
return self._group_range
55+
def get_group_range(self, dim):
56+
"""Returns a the extent of the range representing the number of groups
57+
in the nd-range for a specified dimension.
58+
"""
59+
return self._group_range[dim]
5860

5961
def get_group_linear_range(self):
6062
"""Return the total number of work-groups in the nd_range."""
@@ -64,12 +66,12 @@ def get_group_linear_range(self):
6466

6567
return num_wg
6668

67-
def get_local_range(self):
68-
"""Returns a SYCL range representing all dimensions of the local
69-
range. This local range may have been provided by the programmer, or
70-
chosen by the SYCL runtime.
69+
def get_local_range(self, dim):
70+
"""Returns the extent of the SYCL range representing all dimensions
71+
of the local range for a specified dimension. This local range may
72+
have been provided by the programmer, or chosen by the SYCL runtime.
7173
"""
72-
return self._local_range
74+
return self._local_range[dim]
7375

7476
def get_local_linear_range(self):
7577
"""Return the total number of work-items in the work-group."""

numba_dpex/tests/experimental/test_index_space_ids.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
import numba_dpex as dpex
1313
import numba_dpex.experimental as dpex_exp
14-
from numba_dpex.kernel_api import Item, NdItem
14+
from numba_dpex.kernel_api import Item, NdItem, NdRange
15+
from numba_dpex.kernel_api import call_kernel as kapi_call_kernel
1516
from numba_dpex.tests._helper import skip_windows
1617

1718
_SIZE = 16
@@ -63,6 +64,24 @@ def set_local_ones_nd_item(nd_item: NdItem, a):
6364
a[i] = 1
6465

6566

67+
def _get_group_id_driver(nditem: NdItem, a):
68+
i = nditem.get_global_id(0)
69+
g = nditem.get_group()
70+
a[i] = g.get_group_id(0)
71+
72+
73+
def _get_group_range_driver(nditem: NdItem, a):
74+
i = nditem.get_global_id(0)
75+
g = nditem.get_group()
76+
a[i] = g.get_group_range(0)
77+
78+
79+
def _get_group_local_range_driver(nditem: NdItem, a):
80+
i = nditem.get_global_id(0)
81+
g = nditem.get_group()
82+
a[i] = g.get_local_range(0)
83+
84+
6685
def test_item_get_id():
6786
a = dpnp.zeros(_SIZE, dtype=dpnp.float32)
6887
dpex_exp.call_kernel(set_ones_item, dpex.Range(a.size), a)
@@ -155,6 +174,74 @@ def test_no_item():
155174
)
156175

157176

177+
# TODO: https://github.com/IntelPython/numba-dpex/issues/1308
178+
@skip_windows
179+
def test_get_group_id():
180+
global_size = 100
181+
group_size = 20
182+
num_groups = global_size // group_size
183+
184+
a = dpnp.empty(global_size, dtype=dpnp.int32)
185+
ka = dpnp.empty(global_size, dtype=dpnp.int32)
186+
expected = np.empty(global_size, dtype=np.int32)
187+
ndrange = NdRange((global_size,), (group_size,))
188+
dpex_exp.call_kernel(dpex_exp.kernel(_get_group_id_driver), ndrange, a)
189+
kapi_call_kernel(_get_group_id_driver, ndrange, ka)
190+
191+
for gid in range(num_groups):
192+
for lid in range(group_size):
193+
expected[gid * group_size + lid] = gid
194+
195+
assert np.array_equal(a.asnumpy(), expected)
196+
assert np.array_equal(ka.asnumpy(), expected)
197+
198+
199+
# TODO: https://github.com/IntelPython/numba-dpex/issues/1308
200+
@skip_windows
201+
def test_get_group_range():
202+
global_size = 100
203+
group_size = 20
204+
num_groups = global_size // group_size
205+
206+
a = dpnp.empty(global_size, dtype=dpnp.int32)
207+
ka = dpnp.empty(global_size, dtype=dpnp.int32)
208+
expected = np.empty(global_size, dtype=np.int32)
209+
ndrange = NdRange((global_size,), (group_size,))
210+
dpex_exp.call_kernel(dpex_exp.kernel(_get_group_range_driver), ndrange, a)
211+
kapi_call_kernel(_get_group_range_driver, ndrange, ka)
212+
213+
for gid in range(num_groups):
214+
for lid in range(group_size):
215+
expected[gid * group_size + lid] = num_groups
216+
217+
assert np.array_equal(a.asnumpy(), expected)
218+
assert np.array_equal(ka.asnumpy(), expected)
219+
220+
221+
# TODO: https://github.com/IntelPython/numba-dpex/issues/1308
222+
@skip_windows
223+
def test_get_group_local_range():
224+
global_size = 100
225+
group_size = 20
226+
num_groups = global_size // group_size
227+
228+
a = dpnp.empty(global_size, dtype=dpnp.int32)
229+
ka = dpnp.empty(global_size, dtype=dpnp.int32)
230+
expected = np.empty(global_size, dtype=np.int32)
231+
ndrange = NdRange((global_size,), (group_size,))
232+
dpex_exp.call_kernel(
233+
dpex_exp.kernel(_get_group_local_range_driver), ndrange, a
234+
)
235+
kapi_call_kernel(_get_group_local_range_driver, ndrange, ka)
236+
237+
for gid in range(num_groups):
238+
for lid in range(group_size):
239+
expected[gid * group_size + lid] = group_size
240+
241+
assert np.array_equal(a.asnumpy(), expected)
242+
assert np.array_equal(ka.asnumpy(), expected)
243+
244+
158245
I_SIZE, J_SIZE, K_SIZE = 2, 3, 4
159246

160247

0 commit comments

Comments
 (0)