Skip to content

Commit 90ab8ae

Browse files
ZzEeKkAaDiptorup Deb
authored andcommitted
Use item object for experimental kernels in tests
1 parent 99c4aa1 commit 90ab8ae

10 files changed

+84
-76
lines changed

numba_dpex/tests/experimental/codegen/test_inline_threshold_codegen.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55
import dpctl
66
from numba.core import types
77

8-
import numba_dpex as dpex
98
from numba_dpex import DpctlSyclQueue, DpnpNdArray
109
from numba_dpex import experimental as dpex_exp
1110
from numba_dpex import int64
11+
from numba_dpex.core.types.kernel_api.index_space_ids import ItemType
12+
from numba_dpex.kernel_api import Item
1213

1314

14-
def kernel_func(a, b, c):
15-
i = dpex.get_global_id(0)
15+
def kernel_func(item: Item, a, b, c):
16+
i = item.get_id(0)
1617
c[i] = a[i] + b[i]
1718

1819

@@ -36,7 +37,7 @@ def test_codegen_with_max_inline_threshold():
3637

3738
queue_ty = DpctlSyclQueue(dpctl.SyclQueue())
3839
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
39-
kernel_sig = types.void(i64arr_ty, i64arr_ty, i64arr_ty)
40+
kernel_sig = types.void(ItemType(1), i64arr_ty, i64arr_ty, i64arr_ty)
4041

4142
disp = dpex_exp.kernel(inline_threshold=3)(kernel_func)
4243
disp.compile(kernel_sig)
@@ -57,7 +58,7 @@ def test_codegen_without_max_inline_threshold():
5758

5859
queue_ty = DpctlSyclQueue(dpctl.SyclQueue())
5960
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
60-
kernel_sig = types.void(i64arr_ty, i64arr_ty, i64arr_ty)
61+
kernel_sig = types.void(ItemType(1), i64arr_ty, i64arr_ty, i64arr_ty)
6162

6263
disp = dpex_exp.kernel(kernel_func)
6364
disp.compile(kernel_sig)
@@ -70,4 +71,4 @@ def test_codegen_without_max_inline_threshold():
7071
if not f.is_declaration:
7172
count_of_non_declaration_type_functions += 1
7273

73-
assert count_of_non_declaration_type_functions == 2
74+
assert count_of_non_declaration_type_functions == 3

numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_atomic_fetch_phi.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66
import pytest
77
from numba.core.errors import TypingError
88

9-
import numba_dpex as dpex
109
import numba_dpex.experimental as dpex_exp
11-
from numba_dpex.kernel_api import AtomicRef
10+
from numba_dpex.kernel_api import AtomicRef, Item, Range
1211
from numba_dpex.tests._helper import get_all_dtypes
1312

1413
list_of_supported_dtypes = get_all_dtypes(
@@ -45,8 +44,8 @@ def test_fetch_phi_fn(input_arrays, ref_index, fetch_phi_fn):
4544
"""A test for all fetch_phi atomic functions."""
4645

4746
@dpex_exp.kernel
48-
def _kernel(a, b, ref_index):
49-
i = dpex.get_global_id(0)
47+
def _kernel(item: Item, a, b, ref_index):
48+
i = item.get_id(0)
5049
v = AtomicRef(b, index=ref_index)
5150
getattr(v, fetch_phi_fn)(a[i])
5251

@@ -60,9 +59,9 @@ def _kernel(a, b, ref_index):
6059
# fetch_and, fetch_or, fetch_xor accept only int arguments.
6160
# test for TypingError when float arguments are passed.
6261
with pytest.raises(TypingError):
63-
dpex_exp.call_kernel(_kernel, dpex.Range(10), a, b, ref_index)
62+
dpex_exp.call_kernel(_kernel, Range(10), a, b, ref_index)
6463
else:
65-
dpex_exp.call_kernel(_kernel, dpex.Range(10), a, b, ref_index)
64+
dpex_exp.call_kernel(_kernel, Range(10), a, b, ref_index)
6665
# Verify that `a` accumulated at b[ref_index] by kernel
6766
# matches the `a` accumulated at b[ref_index+1] using Python
6867
for i in range(a.size):
@@ -76,8 +75,8 @@ def test_fetch_phi_retval(fetch_phi_fn):
7675
"""A test for all fetch_phi atomic functions."""
7776

7877
@dpex_exp.kernel
79-
def _kernel(a, b, c):
80-
i = dpex.get_global_id(0)
78+
def _kernel(item: Item, a, b, c):
79+
i = item.get_id(0)
8180
v = AtomicRef(b, index=i)
8281
c[i] = getattr(v, fetch_phi_fn)(a[i])
8382

@@ -89,7 +88,7 @@ def _kernel(a, b, c):
8988
b_copy = dpnp.copy(b)
9089
c_copy = dpnp.copy(c)
9190

92-
dpex_exp.call_kernel(_kernel, dpex.Range(10), a, b, c)
91+
dpex_exp.call_kernel(_kernel, Range(10), a, b, c)
9392

9493
# Verify if the value returned by fetch_phi kernel
9594
# stored into `c` is same as the value returned
@@ -108,8 +107,8 @@ def test_fetch_phi_diff_types(fetch_phi_fn):
108107
"""
109108

110109
@dpex_exp.kernel
111-
def _kernel(a, b):
112-
i = dpex.get_global_id(0)
110+
def _kernel(item: Item, a, b):
111+
i = item.get_id(0)
113112
v = AtomicRef(b, index=0)
114113
getattr(v, fetch_phi_fn)(a[i])
115114

@@ -118,19 +117,19 @@ def _kernel(a, b):
118117
b = dpnp.zeros(N, dtype=dpnp.int32)
119118

120119
with pytest.raises(TypingError):
121-
dpex_exp.call_kernel(_kernel, dpex.Range(10), a, b)
120+
dpex_exp.call_kernel(_kernel, Range(10), a, b)
122121

123122

124123
@dpex_exp.kernel
125-
def atomic_ref_0(a):
126-
i = dpex.get_global_id(0)
124+
def atomic_ref_0(item: Item, a):
125+
i = item.get_id(0)
127126
v = AtomicRef(a, index=0)
128127
v.fetch_add(a[i + 2])
129128

130129

131130
@dpex_exp.kernel
132-
def atomic_ref_1(a):
133-
i = dpex.get_global_id(0)
131+
def atomic_ref_1(item: Item, a):
132+
i = item.get_id(0)
134133
v = AtomicRef(a, index=1)
135134
v.fetch_add(a[i + 2])
136135

@@ -144,24 +143,24 @@ def test_spirv_compiler_flags_add():
144143
N = 10
145144
a = dpnp.ones(N, dtype=dpnp.float32)
146145

147-
dpex_exp.call_kernel(atomic_ref_0, dpex.Range(N - 2), a)
148-
dpex_exp.call_kernel(atomic_ref_1, dpex.Range(N - 2), a)
146+
dpex_exp.call_kernel(atomic_ref_0, Range(N - 2), a)
147+
dpex_exp.call_kernel(atomic_ref_1, Range(N - 2), a)
149148

150149
assert a[0] == N - 1
151150
assert a[1] == N - 1
152151

153152

154153
@dpex_exp.kernel
155-
def atomic_max_0(a):
156-
i = dpex.get_global_id(0)
154+
def atomic_max_0(item: Item, a):
155+
i = item.get_id(0)
157156
v = AtomicRef(a, index=0)
158157
if i != 0:
159158
v.fetch_max(a[i])
160159

161160

162161
@dpex_exp.kernel
163-
def atomic_max_1(a):
164-
i = dpex.get_global_id(0)
162+
def atomic_max_1(item: Item, a):
163+
i = item.get_id(0)
165164
v = AtomicRef(a, index=0)
166165
if i != 0:
167166
v.fetch_max(a[i])
@@ -177,8 +176,8 @@ def test_spirv_compiler_flags_max():
177176
a = dpnp.arange(N, dtype=dpnp.float32)
178177
b = dpnp.arange(N, dtype=dpnp.float32)
179178

180-
dpex_exp.call_kernel(atomic_max_0, dpex.Range(N), a)
181-
dpex_exp.call_kernel(atomic_max_1, dpex.Range(N), b)
179+
dpex_exp.call_kernel(atomic_max_0, Range(N), a)
180+
dpex_exp.call_kernel(atomic_max_1, Range(N), b)
182181

183182
assert a[0] == N - 1
184183
assert b[0] == N - 1

numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_atomic_ref.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,21 @@
66
import pytest
77
from numba.core.errors import TypingError
88

9-
import numba_dpex as dpex
109
import numba_dpex.experimental as dpex_exp
11-
from numba_dpex.kernel_api import AddressSpace, AtomicRef
10+
from numba_dpex.kernel_api import AddressSpace, AtomicRef, Item, Range
1211

1312

1413
def test_atomic_ref_compilation():
1514
@dpex_exp.kernel
16-
def atomic_ref_kernel(a, b):
17-
i = dpex.get_global_id(0)
15+
def atomic_ref_kernel(item: Item, a, b):
16+
i = item.get_id(0)
1817
v = AtomicRef(b, index=0, address_space=AddressSpace.GLOBAL)
1918
v.fetch_add(a[i])
2019

2120
a = dpnp.ones(10)
2221
b = dpnp.zeros(10)
2322
try:
24-
dpex_exp.call_kernel(atomic_ref_kernel, dpex.Range(10), a, b)
23+
dpex_exp.call_kernel(atomic_ref_kernel, Range(10), a, b)
2524
except Exception:
2625
pytest.fail("Unexpected execution failure")
2726

@@ -33,13 +32,13 @@ def test_atomic_ref_compilation_failure():
3332
"""
3433

3534
@dpex_exp.kernel
36-
def atomic_ref_kernel(a, b):
37-
i = dpex.get_global_id(0)
35+
def atomic_ref_kernel(item: Item, a, b):
36+
i = item.get_id(0)
3837
v = AtomicRef(b, index=0, address_space=AddressSpace.LOCAL)
3938
v.fetch_add(a[i])
4039

4140
a = dpnp.ones(10)
4241
b = dpnp.zeros(10)
4342

4443
with pytest.raises(TypingError):
45-
dpex_exp.call_kernel(atomic_ref_kernel, dpex.Range(10), a, b)
44+
dpex_exp.call_kernel(atomic_ref_kernel, Range(10), a, b)

numba_dpex/tests/experimental/test_async_kernel.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
import pytest
88
from numba.core.errors import TypingError
99

10-
import numba_dpex as dpex
1110
import numba_dpex.experimental as exp_dpex
12-
from numba_dpex import Range
1311
from numba_dpex.experimental import testing
12+
from numba_dpex.kernel_api import Item, Range
1413

1514

1615
@exp_dpex.kernel(
@@ -19,8 +18,8 @@
1918
no_cpython_wrapper=True,
2019
no_cfunc_wrapper=True,
2120
)
22-
def add(a, b, c):
23-
i = dpex.get_global_id(0)
21+
def add(item: Item, a, b, c):
22+
i = item.get_id(0)
2423
c[i] = b[i] + a[i]
2524

2625

numba_dpex/tests/experimental/test_compiler_warnings.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
import pytest
77
from numba.core import types
88

9-
import numba_dpex as dpex
109
from numba_dpex import DpctlSyclQueue, DpnpNdArray
1110
from numba_dpex import experimental as dpex_exp
1211
from numba_dpex import int64
12+
from numba_dpex.core.types.kernel_api.index_space_ids import ItemType
13+
from numba_dpex.kernel_api import Item
1314

1415

15-
def _kernel(a, b, c):
16-
i = dpex.get_global_id(0)
16+
def _kernel(item: Item, a, b, c):
17+
i = item.get_id(0)
1718
c[i] = a[i] + b[i]
1819

1920

@@ -30,5 +31,5 @@ def test_inline_threshold_level_warning():
3031
with pytest.warns(UserWarning):
3132
queue_ty = DpctlSyclQueue(dpctl.SyclQueue())
3233
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
33-
kernel_sig = types.void(i64arr_ty, i64arr_ty, i64arr_ty)
34+
kernel_sig = types.void(ItemType(1), i64arr_ty, i64arr_ty, i64arr_ty)
3435
dpex_exp.kernel(inline_threshold=3)(_kernel).compile(kernel_sig)

numba_dpex/tests/experimental/test_inline_threshold_config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66

77
import numba_dpex as dpex
88
from numba_dpex import experimental as dpex_exp
9+
from numba_dpex.kernel_api import Item
910

1011

11-
def kernel_func(a, b, c):
12-
i = dpex.get_global_id(0)
12+
def kernel_func(item: Item, a, b, c):
13+
i = item.get_id(0)
1314
c[i] = a[i] + b[i]
1415

1516

numba_dpex/tests/experimental/test_kernel_specialization.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,30 @@
77
import pytest
88
from numba.core.errors import TypingError
99

10-
import numba_dpex as dpex
1110
import numba_dpex.experimental as dpex_exp
1211
from numba_dpex import DpnpNdArray, float32, int64
1312
from numba_dpex.core.exceptions import InvalidKernelSpecializationError
14-
from numba_dpex.kernel_api import Range
13+
from numba_dpex.core.types.kernel_api.index_space_ids import ItemType
14+
from numba_dpex.kernel_api import Item, Range
1515

1616
i64arrty = DpnpNdArray(ndim=1, dtype=int64, layout="C")
1717
f32arrty = DpnpNdArray(ndim=1, dtype=float32, layout="C")
18+
item_ty = ItemType(ndim=1)
1819

19-
specialized_kernel1 = dpex_exp.kernel((i64arrty, i64arrty, i64arrty))
20+
specialized_kernel1 = dpex_exp.kernel((item_ty, i64arrty, i64arrty, i64arrty))
2021
specialized_kernel2 = dpex_exp.kernel(
21-
[(i64arrty, i64arrty, i64arrty), (f32arrty, f32arrty, f32arrty)]
22+
[
23+
(item_ty, i64arrty, i64arrty, i64arrty),
24+
(item_ty, f32arrty, f32arrty, f32arrty),
25+
]
2226
)
2327

2428

25-
def data_parallel_sum(a, b, c):
29+
def data_parallel_sum(item: Item, a, b, c):
2630
"""
2731
Vector addition using the ``kernel`` decorator.
2832
"""
29-
i = dpex.get_global_id(0)
33+
i = item.get_id(0)
3034
c[i] = a[i] + b[i]
3135

3236

@@ -46,7 +50,9 @@ def test_invalid_specialization_error():
4650
"""Test if an InvalidKernelSpecializationError is raised when attempting to
4751
specialize with NumPy arrays.
4852
"""
49-
specialized_kernel3 = dpex_exp.kernel((int64[::1], int64[::1], int64[::1]))
53+
specialized_kernel3 = dpex_exp.kernel(
54+
(item_ty, int64[::1], int64[::1], int64[::1])
55+
)
5056
with pytest.raises(InvalidKernelSpecializationError):
5157
specialized_kernel3(data_parallel_sum)
5258

@@ -90,11 +96,14 @@ def test_string_specialization():
9096
"""Test if NotImplementedError is raised when signature is a string"""
9197

9298
with pytest.raises(NotImplementedError):
93-
dpex_exp.kernel("(i64arrty, i64arrty, i64arrty)")
99+
dpex_exp.kernel("(item_ty, i64arrty, i64arrty, i64arrty)")
94100

95101
with pytest.raises(NotImplementedError):
96102
dpex_exp.kernel(
97-
["(i64arrty, i64arrty, i64arrty)", "(f32arrty, f32arrty, f32arrty)"]
103+
[
104+
"(item_ty, i64arrty, i64arrty, i64arrty)",
105+
"(item_ty, f32arrty, f32arrty, f32arrty)",
106+
]
98107
)
99108

100109
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)