Skip to content

Commit 3d8c8fe

Browse files
authored
Merge pull request #1368 from IntelPython/feature/add_linear_indexing_overload
Feature/add linear indexing overload
2 parents 7f1880d + f953ca2 commit 3d8c8fe

File tree

3 files changed

+330
-72
lines changed

3 files changed

+330
-72
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
ItemType,
1717
NdItemType,
1818
)
19+
from numba_dpex.kernel_api import Group, Item, NdItem
1920
from numba_dpex.kernel_api_impl.spirv.target import SPIRVTargetContext
2021

2122
from ..target import DPEX_KERNEL_EXP_TARGET_NAME
@@ -180,26 +181,31 @@ def ol_item_get_index_impl(item, dim):
180181
return ol_item_gen_index
181182

182183

183-
_index_const_overload_methods = [
184-
(ItemType, "get_id", _intrinsic_spirv_global_invocation_id),
185-
(ItemType, "get_range", _intrinsic_spirv_global_size),
186-
(NdItemType, "get_global_id", _intrinsic_spirv_global_invocation_id),
187-
(NdItemType, "get_local_id", _intrinsic_spirv_local_invocation_id),
188-
(NdItemType, "get_global_range", _intrinsic_spirv_global_size),
189-
(NdItemType, "get_local_range", _intrinsic_spirv_workgroup_size),
190-
(GroupType, "get_group_id", _intrinsic_spirv_workgroup_id),
191-
(GroupType, "get_group_range", _intrinsic_spirv_numworkgroups),
192-
(GroupType, "get_local_range", _intrinsic_spirv_workgroup_size),
193-
]
184+
def register_index_const_methods():
    """Register indexing related methods that can be defined as spirv const.

    Every row of the table below names a numba type, the method to overload
    on that type and the intrinsic passed to ``generate_index_overload``.
    The generated function is then registered with ``overload_method`` for
    the ``DPEX_KERNEL_EXP_TARGET_NAME`` target.
    """
    index_const_table = (
        (ItemType, "get_id", _intrinsic_spirv_global_invocation_id),
        (ItemType, "get_range", _intrinsic_spirv_global_size),
        (NdItemType, "get_global_id", _intrinsic_spirv_global_invocation_id),
        (NdItemType, "get_local_id", _intrinsic_spirv_local_invocation_id),
        (NdItemType, "get_global_range", _intrinsic_spirv_global_size),
        (NdItemType, "get_local_range", _intrinsic_spirv_workgroup_size),
        (GroupType, "get_group_id", _intrinsic_spirv_workgroup_id),
        (GroupType, "get_group_range", _intrinsic_spirv_numworkgroups),
        (GroupType, "get_local_range", _intrinsic_spirv_workgroup_size),
    )

    for numba_type, method_name, spirv_intrinsic in index_const_table:
        # Build the overload first, then hand it to the decorator returned
        # by overload_method.
        overload_impl = generate_index_overload(numba_type, spirv_intrinsic)
        decorate = overload_method(
            numba_type, method_name, target=DPEX_KERNEL_EXP_TARGET_NAME
        )
        decorate(overload_impl)


register_index_const_methods()
203209

204210

205211
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
@@ -269,3 +275,37 @@ def ol_nd_item_get_group_impl(item):
269275
return dimensions
270276

271277
return ol_nd_item_get_group_impl
278+
279+
280+
def _generate_method_overload(method):
    """Build a trivial overload for a method whose only argument is self.

    Args:
        method: The Python function that should serve as the implementation.

    Returns:
        A one-argument function (taking ``self``) that ignores its argument
        and returns ``method`` unchanged.
    """

    # The parameter must stay named ``self`` so that it lines up with the
    # signature of the method being returned.
    def overload_stub(self):  # pylint: disable=unused-argument
        return method

    return overload_stub
287+
288+
289+
def register_jitable_method(type_, method):
    """Make a regular Python method usable from jit'ed kernel code.

    The method remains executable by the Python interpreter and is also
    registered, under its own ``__name__``, as an overload on ``type_`` for
    the ``DPEX_KERNEL_EXP_TARGET_NAME`` target so it can be compiled into a
    nopython function when referenced by other jit'ed functions.

    Same as register_jitable, but for methods with no arguments.

    Args:
        type_: The numba type on which the method is overloaded.
        method: The Python method to register; it must take only ``self``.
    """
    method_name = method.__name__
    overload_stub = _generate_method_overload(method)
    decorate = overload_method(
        type_, method_name, target=DPEX_KERNEL_EXP_TARGET_NAME
    )
    decorate(overload_stub)
301+
302+
303+
# Register the pure-Python linear id/range helpers of the kernel API classes
# on their numba type counterparts. The order of registration is preserved.
for _numba_type, _py_method in (
    (ItemType, Item.get_linear_id),
    (ItemType, Item.get_linear_range),
    (NdItemType, NdItem.get_global_linear_id),
    (NdItemType, NdItem.get_global_linear_range),
    (NdItemType, NdItem.get_local_linear_range),
    (NdItemType, NdItem.get_local_linear_id),
    (GroupType, Group.get_group_linear_id),
    (GroupType, Group.get_group_linear_range),
    (GroupType, Group.get_local_linear_range),
):
    register_jitable_method(_numba_type, _py_method)

numba_dpex/kernel_api/index_space_ids.py

Lines changed: 84 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,20 @@ def get_group_id(self, dim):
4242

4343
def get_group_linear_id(self):
    """Returns a linearized version of the work-group index.

    Only the public accessors (``dimensions``, ``get_group_id`` and
    ``get_group_range``) are used. The last dimension varies fastest:
    for three dimensions the result is
    ``id0 * range1 * range2 + id1 * range2 + id2``.
    """
    if self.dimensions == 1:
        return self.get_group_id(0)

    if self.dimensions == 2:
        major = self.get_group_id(0) * self.get_group_range(1)
        return major + self.get_group_id(1)

    # Any other dimensionality is handled as the three-dimensional case.
    plane_offset = (
        self.get_group_id(0)
        * self.get_group_range(1)
        * self.get_group_range(2)
    )
    row_offset = self.get_group_id(1) * self.get_group_range(2)
    return plane_offset + row_offset + self.get_group_id(2)
5460

5561
def get_group_range(self, dim):
@@ -61,8 +67,8 @@ def get_group_range(self, dim):
6167
def get_group_linear_range(self):
    """Return the total number of work-groups in the nd_range.

    Computed as the product of ``get_group_range(dim)`` over every
    dimension in ``range(self.dimensions)``.
    """
    total_groups = 1
    for dim in range(self.dimensions):
        total_groups = total_groups * self.get_group_range(dim)

    return total_groups
6874

@@ -76,8 +82,8 @@ def get_local_range(self, dim):
7682
def get_local_linear_range(self):
    """Return the total number of work-items in the work-group.

    Computed as the product of ``get_local_range(dim)`` over every
    dimension in ``range(self.dimensions)``.
    """
    total_items = 1
    for dim in range(self.dimensions):
        total_items = total_items * self.get_local_range(dim)

    return total_items
8389

@@ -128,14 +134,14 @@ def get_linear_id(self):
128134
Returns:
129135
int: The linear id.
130136
"""
131-
if len(self._extent) == 1:
132-
return self._index[0]
133-
if len(self._extent) == 2:
134-
return self._index[0] * self._extent[1] + self._index[1]
137+
if self.dimensions == 1:
138+
return self.get_id(0)
139+
if self.dimensions == 2:
140+
return self.get_id(0) * self.get_range(1) + self.get_id(1)
135141
return (
136-
(self._index[0] * self._extent[1] * self._extent[2])
137-
+ (self._index[1] * self._extent[2])
138-
+ (self._index[2])
142+
(self.get_id(0) * self.get_range(1) * self.get_range(2))
143+
+ (self.get_id(1) * self.get_range(2))
144+
+ (self.get_id(2))
139145
)
140146

141147
def get_id(self, idx):
@@ -146,6 +152,14 @@ def get_id(self, idx):
146152
"""
147153
return self._index[idx]
148154

155+
def get_linear_range(self):
    """Return the total number of work-items in the item's range.

    Computed as the product of ``get_range(dim)`` over every dimension in
    ``range(self.dimensions)``.

    Returns:
        int: The linear range.
    """
    total_items = 1
    for dim in range(self.dimensions):
        total_items = total_items * self.get_range(dim)

    return total_items
162+
149163
def get_range(self, idx):
150164
"""Get the range size for a specific dimension.
151165
@@ -193,7 +207,24 @@ def get_global_linear_id(self):
193207
Returns:
194208
int: The global linear id.
195209
"""
196-
return self._global_item.get_linear_id()
210+
# Instead of calling self._global_item.get_linear_id(), the linearization
211+
# logic is duplicated here so that the method can be JIT compiled by
212+
# numba-dpex and works in both Python and Numba nopython modes.
213+
if self.dimensions == 1:
214+
return self.get_global_id(0)
215+
if self.dimensions == 2:
216+
return self.get_global_id(0) * self.get_global_range(
217+
1
218+
) + self.get_global_id(1)
219+
return (
220+
(
221+
self.get_global_id(0)
222+
* self.get_global_range(1)
223+
* self.get_global_range(2)
224+
)
225+
+ (self.get_global_id(1) * self.get_global_range(2))
226+
+ (self.get_global_id(2))
227+
)
197228

198229
def get_local_id(self, idx):
199230
"""Get the local id for a specific dimension.
@@ -210,7 +241,24 @@ def get_local_linear_id(self):
210241
Returns:
211242
int: The local linear id.
212243
"""
213-
return self._local_item.get_linear_id()
244+
# Instead of calling self._local_item.get_linear_id(), the linearization
245+
# logic is duplicated here so that the method can be JIT compiled by
246+
# numba-dpex and works in both Python and Numba nopython modes.
247+
if self.dimensions == 1:
248+
return self.get_local_id(0)
249+
if self.dimensions == 2:
250+
return self.get_local_id(0) * self.get_local_range(
251+
1
252+
) + self.get_local_id(1)
253+
return (
254+
(
255+
self.get_local_id(0)
256+
* self.get_local_range(1)
257+
* self.get_local_range(2)
258+
)
259+
+ (self.get_local_id(1) * self.get_local_range(2))
260+
+ (self.get_local_id(2))
261+
)
214262

215263
def get_global_range(self, idx):
216264
"""Get the global range size for a specific dimension.
@@ -228,6 +276,22 @@ def get_local_range(self, idx):
228276
"""
229277
return self._local_item.get_range(idx)
230278

279+
def get_local_linear_range(self):
    """Return the total number of work-items in the work-group.

    Computed as the product of ``get_local_range(dim)`` over every
    dimension in ``range(self.dimensions)``.

    Returns:
        int: The local linear range.
    """
    total_items = 1
    for dim in range(self.dimensions):
        total_items = total_items * self.get_local_range(dim)

    return total_items
286+
287+
def get_global_linear_range(self):
    """Return the total number of work-items in the global range.

    Computed as the product of ``get_global_range(dim)`` over every
    dimension in ``range(self.dimensions)``.

    Returns:
        int: The global linear range.
    """
    total_items = 1
    for dim in range(self.dimensions):
        total_items = total_items * self.get_global_range(dim)

    return total_items
294+
231295
def get_group(self):
232296
"""Returns the group.
233297

0 commit comments

Comments
 (0)