Skip to content

Commit 554e52d

Browse files
committed
Overload generic item's attribute 'dimensions'
1 parent bff5c52 commit 554e52d

File tree

2 files changed

+63
-1
lines changed

2 files changed

+63
-1
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import llvmlite.ir as llvmir
1010
from numba.core import cgutils, types
1111
from numba.core.errors import TypingError
12-
from numba.extending import intrinsic, overload_method
12+
from numba.extending import intrinsic, overload_attribute, overload_method
1313

1414
from numba_dpex.core.types.kernel_api.index_space_ids import (
1515
GroupType,
@@ -248,3 +248,24 @@ def ol_nd_item_get_group_impl(nd_item):
248248
return _intrinsic_get_group(nd_item)
249249

250250
return ol_nd_item_get_group_impl
251+
252+
253+
@overload_attribute(GroupType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME)
254+
@overload_attribute(ItemType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME)
255+
@overload_attribute(
256+
NdItemType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME
257+
)
258+
def ol_nd_item_dimensions(item):
259+
"""
260+
SPIR-V overload for :meth:`numba_dpex.kernel_api.<generic_item>.dimensions`.
261+
262+
Generates the same LLVM IR instruction as dpcpp for the
263+
`sycl::<generic_item>::dimensions` attribute.
264+
"""
265+
dimensions = item.ndim
266+
267+
# pylint: disable=unused-argument
268+
def ol_nd_item_get_group_impl(item):
269+
return dimensions
270+
271+
return ol_nd_item_get_group_impl

numba_dpex/tests/experimental/test_index_space_ids.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,24 @@ def set_local_ones_nd_item(nd_item: NdItem, a):
6363
a[i] = 1
6464

6565

66+
@dpex_exp.kernel
67+
def set_dimensions_item(item: Item, a):
68+
i = item.get_id(0)
69+
a[i] = item.dimensions
70+
71+
72+
@dpex_exp.kernel
73+
def set_dimensions_nd_item(nd_item: NdItem, a):
74+
i = nd_item.get_global_id(0)
75+
a[i] = nd_item.dimensions
76+
77+
78+
@dpex_exp.kernel
79+
def set_dimensions_group(nd_item: NdItem, a):
80+
i = nd_item.get_global_id(0)
81+
a[i] = nd_item.get_group().dimensions
82+
83+
6684
def _get_group_id_driver(nditem: NdItem, a):
6785
i = nditem.get_global_id(0)
6886
g = nditem.get_group()
@@ -149,6 +167,29 @@ def test_nd_item_get_local_id():
149167
)
150168

151169

170+
@pytest.mark.parametrize("dims", [1, 2, 3])
171+
def test_item_dimensions(dims):
172+
a = dpnp.zeros(_SIZE, dtype=dpnp.float32)
173+
rng = [1] * dims
174+
rng[0] = a.size
175+
dpex_exp.call_kernel(set_dimensions_item, dpex.Range(*rng), a)
176+
177+
assert np.array_equal(a.asnumpy(), dims * np.ones(a.size, dtype=np.float32))
178+
179+
180+
@pytest.mark.parametrize("dims", [1, 2, 3])
181+
@pytest.mark.parametrize(
182+
"kernel", [set_dimensions_nd_item, set_dimensions_group]
183+
)
184+
def test_nd_item_dimensions(dims, kernel):
185+
a = dpnp.zeros(_SIZE, dtype=dpnp.float32)
186+
rng, grp = [1] * dims, [1] * dims
187+
rng[0], grp[0] = a.size, _GROUP_SIZE
188+
dpex_exp.call_kernel(kernel, dpex.NdRange(rng, grp), a)
189+
190+
assert np.array_equal(a.asnumpy(), dims * np.ones(a.size, dtype=np.float32))
191+
192+
152193
def test_error_item_get_global_id():
153194
a = dpnp.zeros(_SIZE, dtype=dpnp.float32)
154195

0 commit comments

Comments
 (0)