Skip to content

Commit f78d172

Browse files
Add more tests for 100% coverage of top_k function
1 parent 5f096b8 commit f78d172

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

dpctl/tests/test_usm_ndarray_top_k.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,3 +276,43 @@ def test_top_k_noncontig():
276276
assert dpt.all(
277277
dpt.sort(r.indices) == dpt.asarray([125, 126, 127])
278278
), r.indices
279+
280+
281+
def test_top_k_axis0():
282+
get_queue_or_skip()
283+
284+
m, n, k = 128, 8, 3
285+
x = dpt.reshape(dpt.arange(m * n, dtype=dpt.int32), (m, n))
286+
287+
r = dpt.top_k(x, k, axis=0, mode="smallest")
288+
assert r.values.shape == (k, n)
289+
assert r.indices.shape == (k, n)
290+
expected_inds = dpt.reshape(dpt.arange(m, dtype=r.indices.dtype), (m, 1))[
291+
:k, :
292+
]
293+
assert dpt.all(
294+
dpt.sort(r.indices, axis=0) == dpt.sort(expected_inds, axis=0)
295+
)
296+
assert dpt.all(dpt.sort(r.values, axis=0) == dpt.sort(x[:k, :], axis=0))
297+
298+
299+
def test_top_k_validation():
300+
get_queue_or_skip()
301+
x = dpt.ones(10, dtype=dpt.int64)
302+
with pytest.raises(ValueError):
303+
# k must be positive
304+
dpt.top_k(x, -1)
305+
with pytest.raises(TypeError):
306+
# argument should be usm_ndarray
307+
dpt.top_k(list(), 2)
308+
x2 = dpt.reshape(x, (2, 5))
309+
with pytest.raises(ValueError):
310+
# k must not exceed array dimension
311+
# along specified axis
312+
dpt.top_k(x2, 100, axis=1)
313+
with pytest.raises(ValueError):
314+
# for 0d arrays, k must be 1
315+
dpt.top_k(x[0], 2)
316+
with pytest.raises(ValueError):
317+
# mode must be "largest", or "smallest"
318+
dpt.top_k(x, 2, mode="invalid")

0 commit comments

Comments
 (0)