Skip to content

Commit f953ca2

Browse files
committed
Register index range methods as jitable
1 parent 86eec7f commit f953ca2

File tree

2 files changed

+112
-20
lines changed

2 files changed

+112
-20
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,11 @@ def register_jitable_method(type_, method):
301301

302302

303303
register_jitable_method(ItemType, Item.get_linear_id)
304+
register_jitable_method(ItemType, Item.get_linear_range)
304305
register_jitable_method(NdItemType, NdItem.get_global_linear_id)
306+
register_jitable_method(NdItemType, NdItem.get_global_linear_range)
307+
register_jitable_method(NdItemType, NdItem.get_local_linear_range)
305308
register_jitable_method(NdItemType, NdItem.get_local_linear_id)
306309
register_jitable_method(GroupType, Group.get_group_linear_id)
310+
register_jitable_method(GroupType, Group.get_group_linear_range)
311+
register_jitable_method(GroupType, Group.get_local_linear_range)

numba_dpex/tests/experimental/test_index_space_ids.py

Lines changed: 107 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,19 @@ def set_last_one_item(item: Item, a):
3535
a[i] = 1
3636

3737

38+
@dpex_exp.kernel
39+
def set_last_one_linear_item(item: Item, a):
40+
i = item.get_linear_range() - 1
41+
a[i] = 1
42+
43+
44+
@dpex_exp.kernel
45+
def set_last_one_linear_nd_item(nd_item: NdItem, a):
46+
i = nd_item.get_global_linear_range() - 1
47+
a[0] = i
48+
a[i] = 1
49+
50+
3851
@dpex_exp.kernel
3952
def set_last_one_nd_item(item: NdItem, a):
4053
if item.get_global_id(0) == 0:
@@ -43,6 +56,20 @@ def set_last_one_nd_item(item: NdItem, a):
4356
a[i] = 1
4457

4558

59+
@dpex_exp.kernel
60+
def set_last_group_one_linear_nd_item(nd_item: NdItem, a):
61+
i = nd_item.get_local_linear_range() - 1
62+
a[0] = i
63+
a[i] = 1
64+
65+
66+
@dpex_exp.kernel
67+
def set_last_group_one_group_linear_nd_item(nd_item: NdItem, a):
68+
i = nd_item.get_group().get_local_linear_range() - 1
69+
a[0] = i
70+
a[i] = 1
71+
72+
4673
@dpex_exp.kernel
4774
def set_last_group_one_nd_item(item: NdItem, a):
4875
if item.get_global_id(0) == 0:
@@ -99,6 +126,12 @@ def _get_group_range_driver(nditem: NdItem, a):
99126
a[i] = g.get_group_range(0)
100127

101128

129+
def _get_group_linear_range_driver(nditem: NdItem, a):
130+
i = nditem.get_global_linear_id()
131+
g = nditem.get_group()
132+
a[i] = g.get_group_linear_range()
133+
134+
102135
def _get_group_local_range_driver(nditem: NdItem, a):
103136
i = nditem.get_global_id(0)
104137
g = nditem.get_group()
@@ -122,11 +155,34 @@ def test_item_get_range():
122155
assert np.array_equal(a.asnumpy(), want)
123156

124157

125-
def test_nd_item_get_global_range():
158+
@pytest.mark.parametrize(
159+
"rng",
160+
[dpex.Range(_SIZE), dpex.Range(1, _GROUP_SIZE, int(_SIZE / _GROUP_SIZE))],
161+
)
162+
def test_item_get_linear_range(rng):
126163
a = dpnp.zeros(_SIZE, dtype=dpnp.float32)
127-
dpex_exp.call_kernel(
128-
set_last_one_nd_item, dpex.NdRange((a.size,), (_GROUP_SIZE,)), a
129-
)
164+
dpex_exp.call_kernel(set_last_one_linear_item, rng, a)
165+
166+
want = np.zeros(a.size, dtype=np.float32)
167+
want[-1] = 1
168+
169+
assert np.array_equal(a.asnumpy(), want)
170+
171+
172+
@pytest.mark.parametrize(
173+
"kernel,rng",
174+
[
175+
(set_last_one_nd_item, dpex.NdRange((_SIZE,), (_GROUP_SIZE,))),
176+
(set_last_one_linear_nd_item, dpex.NdRange((_SIZE,), (_GROUP_SIZE,))),
177+
(
178+
set_last_one_linear_nd_item,
179+
dpex.NdRange((1, 1, _SIZE), (1, 1, _GROUP_SIZE)),
180+
),
181+
],
182+
)
183+
def test_nd_item_get_global_range(kernel, rng):
184+
a = dpnp.zeros(_SIZE, dtype=dpnp.float32)
185+
dpex_exp.call_kernel(kernel, rng, a)
130186

131187
want = np.zeros(a.size, dtype=np.float32)
132188
want[-1] = 1
@@ -135,11 +191,31 @@ def test_nd_item_get_global_range():
135191
assert np.array_equal(a.asnumpy(), want)
136192

137193

138-
def test_nd_item_get_local_range():
194+
@pytest.mark.parametrize(
195+
"kernel,rng",
196+
[
197+
(set_last_group_one_nd_item, dpex.NdRange((_SIZE,), (_GROUP_SIZE,))),
198+
(
199+
set_last_group_one_linear_nd_item,
200+
dpex.NdRange((_SIZE,), (_GROUP_SIZE,)),
201+
),
202+
(
203+
set_last_group_one_linear_nd_item,
204+
dpex.NdRange((1, 1, _SIZE), (1, 1, _GROUP_SIZE)),
205+
),
206+
(
207+
set_last_group_one_group_linear_nd_item,
208+
dpex.NdRange((_SIZE,), (_GROUP_SIZE,)),
209+
),
210+
(
211+
set_last_group_one_group_linear_nd_item,
212+
dpex.NdRange((1, 1, _SIZE), (1, 1, _GROUP_SIZE)),
213+
),
214+
],
215+
)
216+
def test_nd_item_get_local_range(kernel, rng):
139217
a = dpnp.zeros(_SIZE, dtype=dpnp.float32)
140-
dpex_exp.call_kernel(
141-
set_last_group_one_nd_item, dpex.NdRange((a.size,), (_GROUP_SIZE,)), a
142-
)
218+
dpex_exp.call_kernel(kernel, rng, a)
143219

144220
want = np.zeros(a.size, dtype=np.float32)
145221
want[_GROUP_SIZE - 1] = 1
@@ -240,21 +316,32 @@ def test_get_group_id(driver, rng):
240316
assert np.array_equal(ka.asnumpy(), expected)
241317

242318

243-
def test_get_group_range():
244-
global_size = 100
245-
group_size = 20
246-
num_groups = global_size // group_size
319+
@pytest.mark.parametrize(
320+
"driver,rng",
321+
[
322+
(_get_group_range_driver, dpex.NdRange((_SIZE,), (_GROUP_SIZE,))),
323+
(
324+
_get_group_linear_range_driver,
325+
dpex.NdRange((_SIZE,), (_GROUP_SIZE,)),
326+
),
327+
(
328+
_get_group_linear_range_driver,
329+
dpex.NdRange((1, 1, _SIZE), (1, 1, _GROUP_SIZE)),
330+
),
331+
],
332+
)
333+
def test_get_group_range(driver, rng):
334+
num_groups = _SIZE // _GROUP_SIZE
247335

248-
a = dpnp.empty(global_size, dtype=dpnp.int32)
249-
ka = dpnp.empty(global_size, dtype=dpnp.int32)
250-
expected = np.empty(global_size, dtype=np.int32)
251-
ndrange = NdRange((global_size,), (group_size,))
252-
dpex_exp.call_kernel(dpex_exp.kernel(_get_group_range_driver), ndrange, a)
253-
kapi_call_kernel(_get_group_range_driver, ndrange, ka)
336+
a = dpnp.empty(_SIZE, dtype=dpnp.int32)
337+
ka = dpnp.empty(_SIZE, dtype=dpnp.int32)
338+
expected = np.empty(_SIZE, dtype=np.int32)
339+
dpex_exp.call_kernel(dpex_exp.kernel(driver), rng, a)
340+
kapi_call_kernel(driver, rng, ka)
254341

255342
for gid in range(num_groups):
256-
for lid in range(group_size):
257-
expected[gid * group_size + lid] = num_groups
343+
for lid in range(_GROUP_SIZE):
344+
expected[gid * _GROUP_SIZE + lid] = num_groups
258345

259346
assert np.array_equal(a.asnumpy(), expected)
260347
assert np.array_equal(ka.asnumpy(), expected)

0 commit comments

Comments
 (0)