Skip to content

Commit 440fb77

Browse files
committed
Refactor diff tests and improve coverage
1 parent 52a8d23 commit 440fb77

File tree

1 file changed

+66
-75
lines changed

1 file changed

+66
-75
lines changed

dpctl/tests/test_tensor_diff.py

Lines changed: 66 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,18 @@ def test_diff_basic(dt):
4848
skip_if_dtype_not_supported(dt, q)
4949

5050
x = dpt.asarray([9, 12, 7, 17, 10, 18, 15, 9, 8, 8], dtype=dt, sycl_queue=q)
51-
res = dpt.diff(x)
5251
op = dpt.not_equal if x.dtype is dpt.bool else dpt.subtract
53-
expected_res = op(x[1:], x[:-1])
54-
if dpt.dtype(dt).kind in "fc":
55-
assert dpt.allclose(res, expected_res)
56-
else:
57-
assert dpt.all(res == expected_res)
5852

59-
res = dpt.diff(x, n=5)
60-
expected_res = x
61-
for _ in range(5):
62-
expected_res = op(expected_res[1:], expected_res[:-1])
63-
if dpt.dtype(dt).kind in "fc":
64-
assert dpt.allclose(res, expected_res)
65-
else:
66-
assert dpt.all(res == expected_res)
53+
# test both n=2 and n>2 branches
54+
for n in [1, 2, 5]:
55+
res = dpt.diff(x, n=n)
56+
expected_res = x
57+
for _ in range(n):
58+
expected_res = op(expected_res[1:], expected_res[:-1])
59+
if dpt.dtype(dt).kind in "fc":
60+
assert dpt.allclose(res, expected_res)
61+
else:
62+
assert dpt.all(res == expected_res)
6763

6864

6965
def test_diff_axis():
@@ -73,17 +69,15 @@ def test_diff_axis():
7369
dpt.asarray([9, 12, 7, 17, 10, 18, 15, 9, 8, 8], dtype="i4"), (3, 4, 1)
7470
)
7571
x[:, ::2, :] = 0
76-
res = dpt.diff(x, n=1, axis=1)
77-
expected_res = dpt.subtract(x[:, 1:, :], x[:, :-1, :])
78-
assert dpt.all(res == expected_res)
79-
80-
res = dpt.diff(x, n=3, axis=1)
81-
expected_res = x
82-
for _ in range(3):
83-
expected_res = dpt.subtract(
84-
expected_res[:, 1:, :], expected_res[:, :-1, :]
85-
)
86-
assert dpt.all(res == expected_res)
72+
73+
for n in [1, 2, 3]:
74+
res = dpt.diff(x, n=3, axis=1)
75+
expected_res = x
76+
for _ in range(3):
77+
expected_res = dpt.subtract(
78+
expected_res[:, 1:, :], expected_res[:, :-1, :]
79+
)
80+
assert dpt.all(res == expected_res)
8781

8882

8983
def test_diff_prepend_append_type_promotion():
@@ -179,33 +173,28 @@ def test_diff_prepend_append_py_scalars(sh, axis):
179173
sl2 = tuple(sl2)
180174

181175
r = dpt.diff(arr, axis=axis, prepend=zero, append=zero)
182-
assert isinstance(r, dpt.usm_ndarray)
183176
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
184177
assert r.shape[axis] == arr.shape[axis] + 2 - n
185178
assert dpt.all(r[sl1] == 1)
186179
assert dpt.all(r[sl2] == -1)
187180

188181
r = dpt.diff(arr, axis=axis, prepend=zero)
189-
assert isinstance(r, dpt.usm_ndarray)
190182
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
191183
assert r.shape[axis] == arr.shape[axis] + 1 - n
192184
assert dpt.all(r[sl1] == 1)
193185

194186
r = dpt.diff(arr, axis=axis, append=zero)
195-
assert isinstance(r, dpt.usm_ndarray)
196187
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
197188
assert r.shape[axis] == arr.shape[axis] + 1 - n
198189
assert dpt.all(r[sl2] == -1)
199190

200191
r = dpt.diff(arr, axis=axis, prepend=dpt.asarray(zero), append=zero)
201-
assert isinstance(r, dpt.usm_ndarray)
202192
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
203193
assert r.shape[axis] == arr.shape[axis] + 2 - n
204194
assert dpt.all(r[sl1] == 1)
205195
assert dpt.all(r[sl2] == -1)
206196

207197
r = dpt.diff(arr, axis=axis, prepend=zero, append=dpt.asarray(zero))
208-
assert isinstance(r, dpt.usm_ndarray)
209198
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
210199
assert r.shape[axis] == arr.shape[axis] + 2 - n
211200
assert dpt.all(r[sl1] == 1)
@@ -218,54 +207,36 @@ def test_tensor_diff_append_prepend_arrays():
218207
n = 1
219208
axis = 0
220209

221-
sz = 5
222-
arr = dpt.arange(sz, 2 * sz, dtype="i4")
223-
prepend = dpt.arange(sz, dtype="i4")
224-
append = dpt.arange(2 * sz, 3 * sz, dtype="i4")
225-
const_diff = 1
226-
227-
r = dpt.diff(arr, axis=axis, prepend=prepend, append=append)
228-
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
229-
assert (
230-
r.shape[axis]
231-
== arr.shape[axis] + prepend.shape[axis] + append.shape[axis] - n
232-
)
233-
assert dpt.all(r == const_diff)
234-
235-
r = dpt.diff(arr, axis=axis, prepend=prepend)
236-
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
237-
assert r.shape[axis] == arr.shape[axis] + prepend.shape[axis] - n
238-
assert dpt.all(r == const_diff)
239-
240-
r = dpt.diff(arr, axis=axis, append=append)
241-
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
242-
assert r.shape[axis] == arr.shape[axis] + append.shape[axis] - n
243-
assert dpt.all(r == const_diff)
244-
245-
sh = (3, 4, 5)
246-
sz = prod(sh)
247-
arr = dpt.reshape(dpt.arange(sz, 2 * sz, dtype="i4"), sh)
248-
prepend = dpt.reshape(dpt.arange(sz, dtype="i4"), sh)
249-
append = dpt.reshape(dpt.arange(2 * sz, 3 * sz, dtype="i4"), sh)
250-
const_diff = prod(sh[axis + 1 :])
210+
for sh in [(5,), (3, 4, 5)]:
211+
sz = prod(sh)
212+
arr = dpt.reshape(dpt.arange(sz, 2 * sz, dtype="i4"), sh)
213+
prepend = dpt.reshape(dpt.arange(sz, dtype="i4"), sh)
214+
append = dpt.reshape(dpt.arange(2 * sz, 3 * sz, dtype="i4"), sh)
215+
const_diff = sz / sh[axis]
251216

252-
r = dpt.diff(arr, axis=axis, prepend=prepend, append=append)
253-
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
254-
assert (
255-
r.shape[axis]
256-
== arr.shape[axis] + prepend.shape[axis] + append.shape[axis] - n
257-
)
258-
assert dpt.all(r == const_diff)
217+
r = dpt.diff(arr, axis=axis, prepend=prepend, append=append)
218+
assert all(
219+
r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis
220+
)
221+
assert (
222+
r.shape[axis]
223+
== arr.shape[axis] + prepend.shape[axis] + append.shape[axis] - n
224+
)
225+
assert dpt.all(r == const_diff)
259226

260-
r = dpt.diff(arr, axis=axis, prepend=prepend)
261-
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
262-
assert r.shape[axis] == arr.shape[axis] + prepend.shape[axis] - n
263-
assert dpt.all(r == const_diff)
227+
r = dpt.diff(arr, axis=axis, prepend=prepend)
228+
assert all(
229+
r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis
230+
)
231+
assert r.shape[axis] == arr.shape[axis] + prepend.shape[axis] - n
232+
assert dpt.all(r == const_diff)
264233

265-
r = dpt.diff(arr, axis=axis, append=append)
266-
assert all(r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis)
267-
assert r.shape[axis] == arr.shape[axis] + append.shape[axis] - n
268-
assert dpt.all(r == const_diff)
234+
r = dpt.diff(arr, axis=axis, append=append)
235+
assert all(
236+
r.shape[i] == arr.shape[i] for i in range(arr.ndim) if i != axis
237+
)
238+
assert r.shape[axis] == arr.shape[axis] + append.shape[axis] - n
239+
assert dpt.all(r == const_diff)
269240

270241

271242
def test_diff_wrong_append_prepend_shape():
@@ -332,6 +303,26 @@ def test_diff_compute_follows_data():
332303
append=ar3,
333304
)
334305

306+
assert_raises_regex(
307+
ExecutionPlacementError,
308+
"Execution placement can not be unambiguously inferred from input "
309+
"arguments",
310+
dpt.diff,
311+
ar1,
312+
prepend=ar2,
313+
append=0,
314+
)
315+
316+
assert_raises_regex(
317+
ExecutionPlacementError,
318+
"Execution placement can not be unambiguously inferred from input "
319+
"arguments",
320+
dpt.diff,
321+
ar1,
322+
prepend=0,
323+
append=ar2,
324+
)
325+
335326
assert_raises_regex(
336327
ExecutionPlacementError,
337328
"Execution placement can not be unambiguously inferred from input "

0 commit comments

Comments
 (0)