Skip to content

Commit 27f8199

Browse files
committed
Add linear indexing overloads for item & nd_item
1 parent 674ef5b commit 27f8199

File tree

3 files changed

+128
-12
lines changed

3 files changed

+128
-12
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
ItemType,
1717
NdItemType,
1818
)
19+
from numba_dpex.kernel_api import Item, NdItem
1920
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext
2021

2122
from ..target import DPEX_KERNEL_EXP_TARGET_NAME
@@ -274,3 +275,31 @@ def ol_nd_item_get_group_impl(item):
274275
return dimensions
275276

276277
return ol_nd_item_get_group_impl
278+
279+
280+
def _generate_method_overload(method):
281+
"""Generate a naive overload for a method that takes no arguments other than self."""
282+
283+
def ol_method(self): # pylint: disable=unused-argument
284+
return method
285+
286+
return ol_method
287+
288+
289+
def register_jitable_method(type_, method):
290+
"""
291+
Register a regular Python method that can be executed by the
292+
Python interpreter and can also be compiled into a nopython
293+
function when referenced by other jit'ed functions.
294+
295+
Same as register_jitable, but for methods that take no arguments other than self.
296+
"""
297+
overloaded_method = _generate_method_overload(method)
298+
overload_method(type_, method.__name__, target=DPEX_KERNEL_EXP_TARGET_NAME)(
299+
overloaded_method
300+
)
301+
302+
303+
register_jitable_method(ItemType, Item.get_linear_id)
304+
register_jitable_method(NdItemType, NdItem.get_global_linear_id)
305+
register_jitable_method(NdItemType, NdItem.get_local_linear_id)

numba_dpex/kernel_api/index_space_ids.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,14 @@ def get_linear_id(self):
128128
Returns:
129129
int: The linear id.
130130
"""
131-
if len(self._extent) == 1:
132-
return self._index[0]
133-
if len(self._extent) == 2:
134-
return self._index[0] * self._extent[1] + self._index[1]
131+
if self.dimensions == 1:
132+
return self.get_id(0)
133+
if self.dimensions == 2:
134+
return self.get_id(0) * self.get_range(1) + self.get_id(1)
135135
return (
136-
(self._index[0] * self._extent[1] * self._extent[2])
137-
+ (self._index[1] * self._extent[2])
138-
+ (self._index[2])
136+
(self.get_id(0) * self.get_range(1) * self.get_range(2))
137+
+ (self.get_id(1) * self.get_range(2))
138+
+ (self.get_id(2))
139139
)
140140

141141
def get_id(self, idx):
@@ -193,7 +193,24 @@ def get_global_linear_id(self):
193193
Returns:
194194
int: The global linear id.
195195
"""
196-
return self._global_item.get_linear_id()
196+
# Instead of calling self._global_item.get_linear_id(), the linearization
197+
# logic is duplicated here so that the method can be JIT compiled by
198+
# numba-dpex and works in both Python and Numba nopython modes.
199+
if self.dimensions == 1:
200+
return self.get_global_id(0)
201+
if self.dimensions == 2:
202+
return self.get_global_id(0) * self.get_global_range(
203+
1
204+
) + self.get_global_id(1)
205+
return (
206+
(
207+
self.get_global_id(0)
208+
* self.get_global_range(1)
209+
* self.get_global_range(2)
210+
)
211+
+ (self.get_global_id(1) * self.get_global_range(2))
212+
+ (self.get_global_id(2))
213+
)
197214

198215
def get_local_id(self, idx):
199216
"""Get the local id for a specific dimension.
@@ -210,7 +227,24 @@ def get_local_linear_id(self):
210227
Returns:
211228
int: The local linear id.
212229
"""
213-
return self._local_item.get_linear_id()
230+
# Instead of calling self._local_item.get_linear_id(), the linearization
231+
# logic is duplicated here so that the method can be JIT compiled by
232+
# numba-dpex and works in both Python and Numba nopython modes.
233+
if self.dimensions == 1:
234+
return self.get_local_id(0)
235+
if self.dimensions == 2:
236+
return self.get_local_id(0) * self.get_local_range(
237+
1
238+
) + self.get_local_id(1)
239+
return (
240+
(
241+
self.get_local_id(0)
242+
* self.get_local_range(1)
243+
* self.get_local_range(2)
244+
)
245+
+ (self.get_local_id(1) * self.get_local_range(2))
246+
+ (self.get_local_id(2))
247+
)
214248

215249
def get_global_range(self, idx):
216250
"""Get the global range size for a specific dimension.

numba_dpex/tests/experimental/test_index_space_ids.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -277,18 +277,71 @@ def set_3d_ones_item(item: Item, a):
277277
j = item.get_id(1)
278278
k = item.get_id(2)
279279

280-
# Since we have different sizes for each dimention, wrong order will result
280+
# Since each dimension has a different size, a wrong index order would
281281
# set some indexes twice and leave others unset.
282282
index = i + I_SIZE * (j + J_SIZE * k)
283283

284284
a[index] = 1
285285

286286

287-
def test_index_order():
287+
@dpex_exp.kernel
288+
def set_3d_ones_item_linear(item: Item, a):
289+
# Since each dimension has a different size, a wrong linearization order would
290+
# set some indexes twice and leave others unset.
291+
index = item.get_linear_id()
292+
293+
a[index] = 1
294+
295+
296+
@dpex_exp.kernel
297+
def set_3d_ones_nd_item_linear(nd_item: NdItem, a):
298+
# Since each dimension has a different size, a wrong linearization order would
299+
# set some indexes twice and leave others unset.
300+
index = nd_item.get_global_linear_id()
301+
302+
a[index] = 1
303+
304+
305+
@dpex_exp.kernel
306+
def set_local_3d_ones_nd_item_linear(nd_item: NdItem, a):
307+
# The local linear id is unique only within a work-group, so work-items from
308+
# different groups write the same indexes and the rest of `a` stays unset.
309+
index = nd_item.get_local_linear_id()
310+
311+
a[index] = 1
312+
313+
314+
@pytest.mark.parametrize("kernel", [set_3d_ones_item, set_3d_ones_item_linear])
315+
def test_item_index_order(kernel):
316+
a = dpnp.zeros(I_SIZE * J_SIZE * K_SIZE, dtype=dpnp.int32)
317+
318+
dpex_exp.call_kernel(kernel, dpex.Range(I_SIZE, J_SIZE, K_SIZE), a)
319+
320+
assert np.array_equal(a.asnumpy(), np.ones(a.size, dtype=np.int32))
321+
322+
323+
def test_nd_item_index_order():
288324
a = dpnp.zeros(I_SIZE * J_SIZE * K_SIZE, dtype=dpnp.int32)
289325

290326
dpex_exp.call_kernel(
291-
set_3d_ones_item, dpex.Range(I_SIZE, J_SIZE, K_SIZE), a
327+
set_3d_ones_nd_item_linear,
328+
dpex.NdRange((I_SIZE, J_SIZE, K_SIZE), (1, 1, K_SIZE)),
329+
a,
292330
)
293331

294332
assert np.array_equal(a.asnumpy(), np.ones(a.size, dtype=np.int32))
333+
334+
335+
def test_nd_item_local_linear_id():
336+
a = dpnp.zeros(I_SIZE * J_SIZE * K_SIZE, dtype=dpnp.int32)
337+
338+
dpex_exp.call_kernel(
339+
set_local_3d_ones_nd_item_linear,
340+
dpex.NdRange((I_SIZE, J_SIZE, K_SIZE), (1, 1, K_SIZE)),
341+
a,
342+
)
343+
344+
assert np.array_equal(
345+
a.asnumpy(),
346+
np.array([1] * K_SIZE + [0] * (a.size - K_SIZE), dtype=np.int32),
347+
)

0 commit comments

Comments
 (0)