Skip to content

Commit 6b483fd

Browse files
author
Diptorup Deb
authored
Merge pull request #1285 from chudur-budur/fix/test-strided-kernel-dpnp-array
Refactor all tests for strided dpnp arrays in kernel with different layouts
2 parents 12ebcd3 + 52f4ed1 commit 6b483fd

File tree

1 file changed

+228
-55
lines changed

1 file changed

+228
-55
lines changed

numba_dpex/tests/experimental/test_strided_dpnp_array_in_kernel.py

Lines changed: 228 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,108 +2,281 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import math
6+
57
import dpnp
8+
import numpy as np
9+
import pytest
610

711
import numba_dpex
812
import numba_dpex.experimental as exp_dpex
9-
from numba_dpex import NdRange, Range
13+
from numba_dpex import Range
14+
15+
16+
def get_order(a):
17+
"""Get order of an array.
18+
19+
Args:
20+
a (numpy.ndarray, dpnp.ndarray): Input array.
21+
22+
Raises:
23+
Exception: _description_
24+
25+
Returns:
26+
str: 'C' if c-contiguous, 'F' if f-contiguous or 'A' if aligned.
27+
"""
28+
if a.flags.c_contiguous and not a.flags.f_contiguous:
29+
return "C"
30+
elif not a.flags.c_contiguous and a.flags.f_contiguous:
31+
return "F"
32+
elif a.flags.c_contiguous and a.flags.f_contiguous:
33+
return "A"
34+
else:
35+
raise Exception("Unknown order/layout")
1036

1137

1238
@exp_dpex.kernel
1339
def change_values_1d(x, v):
40+
"""Assign values in a 1d dpnp.ndarray
41+
42+
Args:
43+
x (dpnp.ndarray): Input array.
44+
v (int): Value to be assigned.
45+
"""
1446
i = numba_dpex.get_global_id(0)
15-
p = x[i] # getitem
16-
p = v
17-
x[i] = p # setitem
47+
x[i] = v
48+
49+
50+
def change_values_1d_func(a, p):
51+
"""Assign values in a 1d numpy.ndarray
52+
53+
Args:
54+
a (numpy.ndarray): Input array.
55+
p (int): Value to be assigned.
56+
"""
57+
for i in range(a.shape[0]):
58+
a[i] = p
1859

1960

2061
@exp_dpex.kernel
2162
def change_values_2d(x, v):
63+
"""Assign values in a 2d dpnp.ndarray
64+
65+
Args:
66+
x (dpnp.ndarray): Input array.
67+
v (int): Value to be assigned.
68+
"""
2269
i = numba_dpex.get_global_id(0)
2370
j = numba_dpex.get_global_id(1)
24-
p = x[i, j] # getitem
25-
p = v
26-
x[i, j] = p # setitem
71+
x[i, j] = v
72+
73+
74+
def change_values_2d_func(a, p):
75+
"""Assign values in a 2d numpy.ndarray
76+
77+
Args:
78+
a (numpy.ndarray): Input array.
79+
p (int): Value to be assigned.
80+
"""
81+
for i in range(a.shape[0]):
82+
for j in range(a.shape[1]):
83+
a[i, j] = p
2784

2885

2986
@exp_dpex.kernel
3087
def change_values_3d(x, v):
88+
"""Assign values in a 3d dpnp.ndarray
89+
90+
Args:
91+
x (dpnp.ndarray): Input array.
92+
v (int): Value to be assigned.
93+
"""
3194
i = numba_dpex.get_global_id(0)
3295
j = numba_dpex.get_global_id(1)
3396
k = numba_dpex.get_global_id(2)
34-
p = x[i, j, k] # getitem
35-
p = v
36-
x[i, j, k] = p # setitem
97+
x[i, j, k] = v
3798

3899

39-
def test_strided_dpnp_array_in_kernel():
100+
def change_values_3d_func(a, p):
101+
"""Assign values in a 3d numpy.ndarray
102+
103+
Args:
104+
a (numpy.ndarray): Input array.
105+
p (int): Value to be assigned.
106+
"""
107+
for i in range(a.shape[0]):
108+
for j in range(a.shape[1]):
109+
for k in range(a.shape[2]):
110+
a[i, j, k] = p
111+
112+
113+
@pytest.mark.parametrize("s", [1, 2, 3, 4, 5, 6, 7])
114+
def test_1d_strided_dpnp_array_in_kernel(s):
40115
"""
41116
Tests if we can correctly handle a strided 1d dpnp array
42117
inside dpex kernel.
43118
"""
44-
N = 1024
45-
out = dpnp.arange(0, N * 2, dtype=dpnp.int64)
46-
b = out[::2]
119+
N = 256
120+
k = -3
121+
122+
t = np.arange(0, N, dtype=dpnp.int64)
123+
u = dpnp.asarray(t)
124+
125+
v = u[::s]
126+
exp_dpex.call_kernel(change_values_1d, Range(v.shape[0]), v, k)
47127

48-
r = Range(N)
49-
v = -3
50-
exp_dpex.call_kernel(change_values_1d, r, b, v)
128+
x = t[::s]
129+
change_values_1d_func(x, k)
51130

52-
assert (dpnp.asnumpy(b) == v).all()
131+
# check the value of the array view
132+
assert np.all(dpnp.asnumpy(v) == x)
133+
# check the value of the original arrays
134+
assert np.all(dpnp.asnumpy(u) == t)
53135

54136

55-
def test_multievel_strided_dpnp_array_in_kernel():
137+
@pytest.mark.parametrize("s", [2, 3, 4, 5])
138+
def test_multievel_1d_strided_dpnp_array_in_kernel(s):
56139
"""
57140
Tests if we can correctly handle a multilevel strided 1d dpnp array
58141
inside dpex kernel.
59142
"""
60-
N = 128
61-
out = dpnp.arange(0, N * 2, dtype=dpnp.int64)
62-
v = -3
143+
N = 256
144+
k = -3
145+
146+
t = dpnp.arange(0, N, dtype=dpnp.int64)
147+
u = dpnp.asarray(t)
148+
149+
v, x = u, t
150+
while v.shape[0] > 1:
151+
v = v[::s]
152+
exp_dpex.call_kernel(change_values_1d, Range(v.shape[0]), v, k)
153+
154+
x = x[::s]
155+
change_values_1d_func(x, k)
63156

64-
b = out
65-
n = N
66-
K = 7
67-
for _ in range(K):
68-
b = b[::2]
69-
exp_dpex.call_kernel(change_values_1d, Range(n), b, v)
70-
assert (dpnp.asnumpy(b) == v).all()
71-
n = int(n / 2)
157+
# check the value of the array view
158+
assert np.all(dpnp.asnumpy(v) == x)
159+
# check the value of the original arrays
160+
assert np.all(dpnp.asnumpy(u) == t)
72161

73162

74-
def test_multilevel_2d_strided_dpnp_array_in_kernel():
163+
@pytest.mark.parametrize("s1", [2, 4, 6, 8])
164+
@pytest.mark.parametrize("s2", [1, 3, 5, 7])
165+
@pytest.mark.parametrize("order", ["C", "F"])
166+
def test_2d_strided_dpnp_array_in_kernel(s1, s2, order):
167+
"""
168+
Tests if we can correctly handle a strided 2d dpnp array
169+
inside dpex kernel.
170+
"""
171+
M, N = 13, 31
172+
k = -3
173+
174+
t = np.arange(0, M * N, dtype=np.int64).reshape(M, N, order=order)
175+
u = dpnp.asarray(t)
176+
177+
# check order, sanity check
178+
assert get_order(u) == order
179+
180+
v = u[::s1, ::s2]
181+
exp_dpex.call_kernel(change_values_2d, Range(*v.shape), v, k)
182+
183+
x = t[::s1, ::s2]
184+
change_values_2d_func(x, k)
185+
186+
# check the value of the array view
187+
assert np.all(dpnp.asnumpy(v) == x)
188+
# check the value of the original arrays
189+
assert np.all(dpnp.asnumpy(u) == t)
190+
191+
192+
@pytest.mark.parametrize("s1", [2, 4, 6, 8])
193+
@pytest.mark.parametrize("s2", [3, 5, 7, 9])
194+
@pytest.mark.parametrize("order", ["C", "F"])
195+
def test_multilevel_2d_strided_dpnp_array_in_kernel(s1, s2, order):
75196
"""
76197
Tests if we can correctly handle a multilevel strided 2d dpnp array
77198
inside dpex kernel.
78199
"""
79-
N = 128
80-
out, _ = dpnp.mgrid[0 : N * 2, 0 : N * 2] # noqa: E203
81-
v = -3
200+
M, N = 13, 31
201+
k = -3
202+
203+
t = np.arange(0, M * N, dtype=np.int64).reshape(M, N, order=order)
204+
u = dpnp.asarray(t)
205+
206+
# check order, sanity check
207+
assert get_order(u) == order
82208

83-
b = out
84-
n = N
85-
K = 7
86-
for _ in range(K):
87-
b = b[::2, ::2]
88-
exp_dpex.call_kernel(change_values_2d, Range(n, n), b, v)
89-
assert (dpnp.asnumpy(b) == v).all()
90-
n = int(n / 2)
209+
v, x = u, t
210+
while v.shape[0] > 1 and v.shape[1] > 1:
211+
v = v[::s1, ::s2]
212+
exp_dpex.call_kernel(change_values_2d, Range(*v.shape), v, k)
91213

214+
x = x[::s1, ::s2]
215+
change_values_2d_func(x, k)
92216

93-
def test_multilevel_3d_strided_dpnp_array_in_kernel():
217+
# check the value of the array view
218+
assert np.all(dpnp.asnumpy(v) == x)
219+
# check the value of the original arrays
220+
assert np.all(dpnp.asnumpy(u) == t)
221+
222+
223+
@pytest.mark.parametrize("s1", [1, 2, 3])
224+
@pytest.mark.parametrize("s2", [2, 3, 4])
225+
@pytest.mark.parametrize("s3", [3, 4, 5])
226+
@pytest.mark.parametrize("order", ["C", "F"])
227+
def test_3d_strided_dpnp_array_in_kernel(s1, s2, s3, order):
228+
"""
229+
Tests if we can correctly handle a strided 3d dpnp array
230+
inside dpex kernel.
231+
"""
232+
M, N, K = 13, 31, 11
233+
k = -3
234+
235+
t = np.arange(0, M * N * K, dtype=np.int64).reshape((M, N, K), order=order)
236+
u = dpnp.asarray(t)
237+
238+
# check order, sanity check
239+
assert get_order(u) == order
240+
241+
v = u[::s1, ::s2, ::s3]
242+
exp_dpex.call_kernel(change_values_3d, Range(*v.shape), v, k)
243+
244+
x = t[::s1, ::s2, ::s3]
245+
change_values_3d_func(x, k)
246+
247+
# check the value of the array view
248+
assert np.all(dpnp.asnumpy(v) == x)
249+
# check the value of the original arrays
250+
assert np.all(dpnp.asnumpy(u) == t)
251+
252+
253+
@pytest.mark.parametrize("s1", [2, 3, 4])
254+
@pytest.mark.parametrize("s2", [3, 4, 5])
255+
@pytest.mark.parametrize("s3", [4, 5, 6])
256+
@pytest.mark.parametrize("order", ["C", "F"])
257+
def test_multilevel_3d_strided_dpnp_array_in_kernel(s1, s2, s3, order):
94258
"""
95259
Tests if we can correctly handle a multilevel strided 3d dpnp array
96260
inside dpex kernel.
97261
"""
98-
N = 128
99-
out, _, _ = dpnp.mgrid[0 : N * 2, 0 : N * 2, 0 : N * 2] # noqa: E203
100-
v = -3
101-
102-
b = out
103-
n = N
104-
K = 7
105-
for _ in range(K):
106-
b = b[::2, ::2, ::2]
107-
exp_dpex.call_kernel(change_values_3d, Range(n, n, n), b, v)
108-
assert (dpnp.asnumpy(b) == v).all()
109-
n = int(n / 2)
262+
M, N, K = 13, 31, 11
263+
k = -3
264+
265+
t = np.arange(0, M * N * K, dtype=np.int64).reshape((M, N, K), order=order)
266+
u = dpnp.asarray(t)
267+
268+
# check order, sanity check
269+
assert get_order(u) == order
270+
271+
v, x = u, t
272+
while v.shape[0] > 1 and v.shape[1] > 1 and v.shape[2] > 1:
273+
v = v[::s1, ::s2, ::s3]
274+
exp_dpex.call_kernel(change_values_3d, Range(*v.shape), v, k)
275+
276+
x = x[::s1, ::s2, ::s3]
277+
change_values_3d_func(x, k)
278+
279+
# check the value of the array view
280+
assert np.all(dpnp.asnumpy(v) == x)
281+
# check the value of the original arrays
282+
assert np.all(dpnp.asnumpy(u) == t)

0 commit comments

Comments
 (0)