Skip to content

Commit 99c4aa1

Browse files
author
Diptorup Deb
authored
Merge pull request #1359 from IntelPython/feature/overload_dimensions_attribute_for_indexers
Feature/overload dimensions attribute for indexers
2 parents 6e28999 + 554e52d commit 99c4aa1

File tree

4 files changed

+77
-7
lines changed

4 files changed

+77
-7
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import llvmlite.ir as llvmir
1010
from numba.core import cgutils, types
1111
from numba.core.errors import TypingError
12-
from numba.extending import intrinsic, overload_method
12+
from numba.extending import intrinsic, overload_attribute, overload_method
1313

1414
from numba_dpex.core.types.kernel_api.index_space_ids import (
1515
GroupType,
@@ -248,3 +248,24 @@ def ol_nd_item_get_group_impl(nd_item):
248248
return _intrinsic_get_group(nd_item)
249249

250250
return ol_nd_item_get_group_impl
251+
252+
253+
@overload_attribute(GroupType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME)
254+
@overload_attribute(ItemType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME)
255+
@overload_attribute(
256+
NdItemType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME
257+
)
258+
def ol_nd_item_dimensions(item):
259+
"""
260+
SPIR-V overload for :meth:`numba_dpex.kernel_api.<generic_item>.dimensions`.
261+
262+
Generates the same LLVM IR instruction as dpcpp for the
263+
`sycl::<generic_item>::dimensions` attribute.
264+
"""
265+
dimensions = item.ndim
266+
267+
# pylint: disable=unused-argument
268+
def ol_nd_item_get_group_impl(item):
269+
return dimensions
270+
271+
return ol_nd_item_get_group_impl

numba_dpex/experimental/typeof.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,11 @@ def typeof_item(val: Item, c):
6868
Returns: A numba_dpex.experimental.core.types.kernel_api.items.ItemType
6969
instance.
7070
"""
71-
return ItemType(val.ndim)
71+
return ItemType(val.dimensions)
7272

7373

7474
@typeof_impl.register(NdItem)
75-
def typeof_nditem(val, c):
75+
def typeof_nditem(val: NdItem, c):
7676
"""Registers the type inference implementation function for a
7777
numba_dpex.kernel_api.NdItem PyObject.
7878
@@ -83,4 +83,4 @@ def typeof_nditem(val, c):
8383
Returns: A numba_dpex.experimental.core.types.kernel_api.items.NdItemType
8484
instance.
8585
"""
86-
return NdItemType(val.ndim)
86+
return NdItemType(val.dimensions)

numba_dpex/kernel_api/index_space_ids.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,14 @@ def leader(self):
9898
"""
9999
return self._leader
100100

101+
@property
102+
def dimensions(self) -> int:
103+
"""Returns the rank of a Group object.
104+
Returns:
105+
int: Number of dimensions in the Group object
106+
"""
107+
return self._global_range.ndim
108+
101109
@leader.setter
102110
def leader(self, work_item_id):
103111
"""Sets the leader attribute for the group."""
@@ -147,7 +155,7 @@ def get_range(self, idx):
147155
return self._extent[idx]
148156

149157
@property
150-
def ndim(self) -> int:
158+
def dimensions(self) -> int:
151159
"""Returns the rank of a Item object.
152160
153161
Returns:
@@ -228,10 +236,10 @@ def get_group(self):
228236
return self._group
229237

230238
@property
231-
def ndim(self) -> int:
239+
def dimensions(self) -> int:
232240
"""Returns the rank of a NdItem object.
233241
234242
Returns:
235243
int: Number of dimensions in the NdItem object
236244
"""
237-
return self._global_item.ndim
245+
return self._global_item.dimensions

numba_dpex/tests/experimental/test_index_space_ids.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,24 @@ def set_local_ones_nd_item(nd_item: NdItem, a):
6363
a[i] = 1
6464

6565

66+
@dpex_exp.kernel
67+
def set_dimensions_item(item: Item, a):
68+
i = item.get_id(0)
69+
a[i] = item.dimensions
70+
71+
72+
@dpex_exp.kernel
73+
def set_dimensions_nd_item(nd_item: NdItem, a):
74+
i = nd_item.get_global_id(0)
75+
a[i] = nd_item.dimensions
76+
77+
78+
@dpex_exp.kernel
79+
def set_dimensions_group(nd_item: NdItem, a):
80+
i = nd_item.get_global_id(0)
81+
a[i] = nd_item.get_group().dimensions
82+
83+
6684
def _get_group_id_driver(nditem: NdItem, a):
6785
i = nditem.get_global_id(0)
6886
g = nditem.get_group()
@@ -149,6 +167,29 @@ def test_nd_item_get_local_id():
149167
)
150168

151169

170+
@pytest.mark.parametrize("dims", [1, 2, 3])
171+
def test_item_dimensions(dims):
172+
a = dpnp.zeros(_SIZE, dtype=dpnp.float32)
173+
rng = [1] * dims
174+
rng[0] = a.size
175+
dpex_exp.call_kernel(set_dimensions_item, dpex.Range(*rng), a)
176+
177+
assert np.array_equal(a.asnumpy(), dims * np.ones(a.size, dtype=np.float32))
178+
179+
180+
@pytest.mark.parametrize("dims", [1, 2, 3])
181+
@pytest.mark.parametrize(
182+
"kernel", [set_dimensions_nd_item, set_dimensions_group]
183+
)
184+
def test_nd_item_dimensions(dims, kernel):
185+
a = dpnp.zeros(_SIZE, dtype=dpnp.float32)
186+
rng, grp = [1] * dims, [1] * dims
187+
rng[0], grp[0] = a.size, _GROUP_SIZE
188+
dpex_exp.call_kernel(kernel, dpex.NdRange(rng, grp), a)
189+
190+
assert np.array_equal(a.asnumpy(), dims * np.ones(a.size, dtype=np.float32))
191+
192+
152193
def test_error_item_get_global_id():
153194
a = dpnp.zeros(_SIZE, dtype=dpnp.float32)
154195

0 commit comments

Comments
 (0)