Skip to content

Commit f54205e

Browse files
author
Diptorup Deb
committed
Adds overloads for group indexing functions
1 parent 6290156 commit f54205e

File tree

3 files changed

+118
-9
lines changed

3 files changed

+118
-9
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,26 @@ def _intrinsic_spirv_workgroup_size(
123123
)
124124

125125

126+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
127+
def _intrinsic_spirv_workgroup_id(
128+
ty_context, ty_dim # pylint: disable=unused-argument
129+
):
130+
"""Generates instruction to get index from BuiltInWorkgroupId."""
131+
return _intrinsic_spirv_global_index_const(
132+
ty_context, ty_dim, "BuiltInWorkgroupId"
133+
)
134+
135+
136+
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
137+
def _intrinsic_spirv_numworkgroups(
138+
ty_context, ty_dim # pylint: disable=unused-argument
139+
):
140+
"""Generates instruction to get index from BuiltInNumWorkgroups."""
141+
return _intrinsic_spirv_global_index_const(
142+
ty_context, ty_dim, "BuiltInNumWorkgroups"
143+
)
144+
145+
126146
def generate_index_overload(_type, _intrinsic):
127147
"""Generates overload for the index method that generates specific IR from
128148
provided intrinsic."""
@@ -167,6 +187,9 @@ def ol_item_get_index_impl(item, dim):
167187
(NdItemType, "get_local_id", _intrinsic_spirv_local_invocation_id),
168188
(NdItemType, "get_global_range", _intrinsic_spirv_global_size),
169189
(NdItemType, "get_local_range", _intrinsic_spirv_workgroup_size),
190+
(GroupType, "get_group_id", _intrinsic_spirv_workgroup_id),
191+
(GroupType, "get_group_range", _intrinsic_spirv_numworkgroups),
192+
(GroupType, "get_local_range", _intrinsic_spirv_workgroup_size),
170193
]
171194

172195
for index_overload in _index_const_overload_methods:

numba_dpex/kernel_api/index_space_ids.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,11 @@ def get_group_linear_id(self):
5252
+ (self._index[2])
5353
)
5454

55-
def get_group_range(self):
56-
"""Returns a range representing the number of groups in the nd-range."""
57-
return self._group_range
55+
def get_group_range(self, dim):
56+
"""Returns a the extent of the range representing the number of groups
57+
in the nd-range for a specified dimension.
58+
"""
59+
return self._group_range[dim]
5860

5961
def get_group_linear_range(self):
6062
"""Return the total number of work-groups in the nd_range."""
@@ -64,12 +66,12 @@ def get_group_linear_range(self):
6466

6567
return num_wg
6668

67-
def get_local_range(self):
68-
"""Returns a SYCL range representing all dimensions of the local
69-
range. This local range may have been provided by the programmer, or
70-
chosen by the SYCL runtime.
69+
def get_local_range(self, dim):
70+
"""Returns the extent of the SYCL range representing all dimensions
71+
of the local range for a specified dimension. This local range may
72+
have been provided by the programmer, or chosen by the SYCL runtime.
7173
"""
72-
return self._local_range
74+
return self._local_range[dim]
7375

7476
def get_local_linear_range(self):
7577
"""Return the total number of work-items in the work-group."""

numba_dpex/tests/experimental/test_index_space_ids.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
import numba_dpex as dpex
1313
import numba_dpex.experimental as dpex_exp
14-
from numba_dpex.kernel_api import Item, NdItem
14+
from numba_dpex.kernel_api import Item, NdItem, NdRange
15+
from numba_dpex.kernel_api import call_kernel as kapi_call_kernel
1516
from numba_dpex.tests._helper import skip_windows
1617

1718
_SIZE = 16
@@ -63,6 +64,24 @@ def set_local_ones_nd_item(nd_item: NdItem, a):
6364
a[i] = 1
6465

6566

67+
def _get_group_id_driver(nditem: NdItem, a):
68+
i = nditem.get_global_id(0)
69+
g = nditem.get_group()
70+
a[i] = g.get_group_id(0)
71+
72+
73+
def _get_group_range_driver(nditem: NdItem, a):
74+
i = nditem.get_global_id(0)
75+
g = nditem.get_group()
76+
a[i] = g.get_group_range(0)
77+
78+
79+
def _get_group_local_range_driver(nditem: NdItem, a):
80+
i = nditem.get_global_id(0)
81+
g = nditem.get_group()
82+
a[i] = g.get_local_range(0)
83+
84+
6685
def test_item_get_id():
6786
a = dpnp.zeros(_SIZE, dtype=dpnp.float32)
6887
dpex_exp.call_kernel(set_ones_item, dpex.Range(a.size), a)
@@ -155,6 +174,71 @@ def test_no_item():
155174
)
156175

157176

177+
def test_get_group_id():
178+
179+
global_size = 100
180+
group_size = 20
181+
num_groups = global_size // group_size
182+
183+
a = dpnp.empty(global_size, dtype=dpnp.int32)
184+
ka = dpnp.empty(global_size, dtype=dpnp.int32)
185+
expected = np.empty(global_size, dtype=np.int32)
186+
ndrange = NdRange((global_size,), (group_size,))
187+
dpex_exp.call_kernel(dpex_exp.kernel(_get_group_id_driver), ndrange, a)
188+
kapi_call_kernel(_get_group_id_driver, ndrange, ka)
189+
190+
for gid in range(num_groups):
191+
for lid in range(group_size):
192+
expected[gid * group_size + lid] = gid
193+
194+
assert np.array_equal(a.asnumpy(), expected)
195+
assert np.array_equal(ka.asnumpy(), expected)
196+
197+
198+
def test_get_group_range():
199+
200+
global_size = 100
201+
group_size = 20
202+
num_groups = global_size // group_size
203+
204+
a = dpnp.empty(global_size, dtype=dpnp.int32)
205+
ka = dpnp.empty(global_size, dtype=dpnp.int32)
206+
expected = np.empty(global_size, dtype=np.int32)
207+
ndrange = NdRange((global_size,), (group_size,))
208+
dpex_exp.call_kernel(dpex_exp.kernel(_get_group_range_driver), ndrange, a)
209+
kapi_call_kernel(_get_group_range_driver, ndrange, ka)
210+
211+
for gid in range(num_groups):
212+
for lid in range(group_size):
213+
expected[gid * group_size + lid] = num_groups
214+
215+
assert np.array_equal(a.asnumpy(), expected)
216+
assert np.array_equal(ka.asnumpy(), expected)
217+
218+
219+
def test_get_group_local_range():
220+
221+
global_size = 100
222+
group_size = 20
223+
num_groups = global_size // group_size
224+
225+
a = dpnp.empty(global_size, dtype=dpnp.int32)
226+
ka = dpnp.empty(global_size, dtype=dpnp.int32)
227+
expected = np.empty(global_size, dtype=np.int32)
228+
ndrange = NdRange((global_size,), (group_size,))
229+
dpex_exp.call_kernel(
230+
dpex_exp.kernel(_get_group_local_range_driver), ndrange, a
231+
)
232+
kapi_call_kernel(_get_group_local_range_driver, ndrange, ka)
233+
234+
for gid in range(num_groups):
235+
for lid in range(group_size):
236+
expected[gid * group_size + lid] = group_size
237+
238+
assert np.array_equal(a.asnumpy(), expected)
239+
assert np.array_equal(ka.asnumpy(), expected)
240+
241+
158242
I_SIZE, J_SIZE, K_SIZE = 2, 3, 4
159243

160244

0 commit comments

Comments
 (0)