Skip to content

Commit 437efe2

Browse files
committed
Add get_group_linear_id to overaloads
1 parent 27f8199 commit 437efe2

File tree

3 files changed

+41
-20
lines changed

3 files changed

+41
-20
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
ItemType,
1717
NdItemType,
1818
)
19-
from numba_dpex.kernel_api import Item, NdItem
19+
from numba_dpex.kernel_api import Group, Item, NdItem
2020
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext
2121

2222
from ..target import DPEX_KERNEL_EXP_TARGET_NAME
@@ -303,3 +303,4 @@ def register_jitable_method(type_, method):
303303
register_jitable_method(ItemType, Item.get_linear_id)
304304
register_jitable_method(NdItemType, NdItem.get_global_linear_id)
305305
register_jitable_method(NdItemType, NdItem.get_local_linear_id)
306+
register_jitable_method(GroupType, Group.get_group_linear_id)

numba_dpex/kernel_api/index_space_ids.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,20 @@ def get_group_id(self, dim):
4242

4343
def get_group_linear_id(self):
4444
"""Returns a linearized version of the work-group index."""
45-
if len(self._index) == 1:
46-
return self._index[0]
47-
if len(self._index) == 2:
48-
return self._index[0] * self._group_range[1] + self._index[1]
45+
if self.dimensions == 1:
46+
return self.get_group_id(0)
47+
if self.dimensions == 2:
48+
return self.get_group_id(0) * self.get_group_range(
49+
1
50+
) + self.get_group_id(1)
4951
return (
50-
(self._index[0] * self._group_range[1] * self._group_range[2])
51-
+ (self._index[1] * self._group_range[2])
52-
+ (self._index[2])
52+
(
53+
self.get_group_id(0)
54+
* self.get_group_range(1)
55+
* self.get_group_range(2)
56+
)
57+
+ (self.get_group_id(1) * self.get_group_range(2))
58+
+ (self.get_group_id(2))
5359
)
5460

5561
def get_group_range(self, dim):

numba_dpex/tests/experimental/test_index_space_ids.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,12 @@ def _get_group_id_driver(nditem: NdItem, a):
8787
a[i] = g.get_group_id(0)
8888

8989

90+
def _get_group_linear_id_driver(nditem: NdItem, a):
91+
i = nditem.get_global_linear_id()
92+
g = nditem.get_group()
93+
a[i] = g.get_group_linear_id()
94+
95+
9096
def _get_group_range_driver(nditem: NdItem, a):
9197
i = nditem.get_global_id(0)
9298
g = nditem.get_group()
@@ -206,21 +212,29 @@ def test_no_item():
206212
)
207213

208214

209-
def test_get_group_id():
210-
global_size = 100
211-
group_size = 20
212-
num_groups = global_size // group_size
215+
@pytest.mark.parametrize(
216+
"driver,rng",
217+
[
218+
(_get_group_id_driver, dpex.NdRange((_SIZE,), (_GROUP_SIZE,))),
219+
(_get_group_linear_id_driver, dpex.NdRange((_SIZE,), (_GROUP_SIZE,))),
220+
(
221+
_get_group_linear_id_driver,
222+
dpex.NdRange((1, 1, _SIZE), (1, 1, _GROUP_SIZE)),
223+
),
224+
],
225+
)
226+
def test_get_group_id(driver, rng):
227+
num_groups = _SIZE // _GROUP_SIZE
213228

214-
a = dpnp.empty(global_size, dtype=dpnp.int32)
215-
ka = dpnp.empty(global_size, dtype=dpnp.int32)
216-
expected = np.empty(global_size, dtype=np.int32)
217-
ndrange = NdRange((global_size,), (group_size,))
218-
dpex_exp.call_kernel(dpex_exp.kernel(_get_group_id_driver), ndrange, a)
219-
kapi_call_kernel(_get_group_id_driver, ndrange, ka)
229+
a = dpnp.empty(_SIZE, dtype=dpnp.int32)
230+
ka = dpnp.empty(_SIZE, dtype=dpnp.int32)
231+
expected = np.empty(_SIZE, dtype=np.int32)
232+
dpex_exp.call_kernel(dpex_exp.kernel(driver), rng, a)
233+
kapi_call_kernel(driver, rng, ka)
220234

221235
for gid in range(num_groups):
222-
for lid in range(group_size):
223-
expected[gid * group_size + lid] = gid
236+
for lid in range(_GROUP_SIZE):
237+
expected[gid * _GROUP_SIZE + lid] = gid
224238

225239
assert np.array_equal(a.asnumpy(), expected)
226240
assert np.array_equal(ka.asnumpy(), expected)

0 commit comments

Comments
 (0)