Skip to content

Commit 0a07cc1

Browse files
Add information displayed on failure, renamed variables
1 parent dfb521f commit 0a07cc1

File tree

1 file changed

+21
-17
lines changed

1 file changed

+21
-17
lines changed

dpctl/tests/test_usm_ndarray_top_k.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,16 @@ def test_topk_1d_largest(dtype, n):
4545

4646
o = dpt.ones(n, dtype=dtype)
4747
z = dpt.zeros(n, dtype=dtype)
48-
zo = dpt.concat((o, z))
49-
inp = dpt.roll(zo, 734)
48+
oz = dpt.concat((o, z))
49+
inp = dpt.roll(oz, 734)
5050
k = 5
5151

5252
s = dpt.top_k(inp, k, mode="largest")
5353
assert s.values.shape == (k,)
5454
assert s.values.dtype == inp.dtype
5555
assert s.indices.shape == (k,)
56-
assert dpt.all(s.values == dpt.ones(k, dtype=dtype))
57-
assert dpt.all(s.values == inp[s.indices])
56+
assert dpt.all(s.values == dpt.ones(k, dtype=dtype)), s.values
57+
assert dpt.all(s.values == inp[s.indices]), s.indices
5858

5959

6060
@pytest.mark.parametrize(
@@ -82,34 +82,38 @@ def test_topk_1d_smallest(dtype, n):
8282

8383
o = dpt.ones(n, dtype=dtype)
8484
z = dpt.zeros(n, dtype=dtype)
85-
zo = dpt.concat((o, z))
86-
inp = dpt.roll(zo, 734)
85+
oz = dpt.concat((o, z))
86+
inp = dpt.roll(oz, 734)
8787
k = 5
8888

8989
s = dpt.top_k(inp, k, mode="smallest")
9090
assert s.values.shape == (k,)
9191
assert s.values.dtype == inp.dtype
9292
assert s.indices.shape == (k,)
93-
assert dpt.all(s.values == dpt.zeros(k, dtype=dtype))
94-
assert dpt.all(s.values == inp[s.indices])
93+
assert dpt.all(s.values == dpt.zeros(k, dtype=dtype)), s.values
94+
assert dpt.all(s.values == inp[s.indices]), s.indices
9595

9696

9797
# triage failing top k radix implementation on CPU
9898
# replicates from Python behavior of radix sort topk implementation
9999
@pytest.mark.parametrize("n", [33, 255, 511, 1021, 8193])
100-
def test_topk_largest_1d_radix_i1_255(n):
100+
def test_topk_largest_1d_radix_i1(n):
101101
get_queue_or_skip()
102102
dt = "i1"
103103

104104
o = dpt.ones(n, dtype=dt)
105105
z = dpt.zeros(n, dtype=dt)
106-
zo = dpt.concat((o, z))
107-
inp = dpt.roll(zo, 734)
106+
oz = dpt.concat((o, z))
107+
inp = dpt.roll(oz, 734)
108108
k = 5
109109

110-
sorted = dpt.copy(dpt.sort(inp, descending=True, kind="radixsort")[:k])
111-
argsorted = dpt.copy(
112-
dpt.argsort(inp, descending=True, kind="radixsort")[:k]
113-
)
114-
assert dpt.all(sorted == dpt.ones(k, dtype=dt))
115-
assert dpt.all(sorted == inp[argsorted])
110+
sorted_v = dpt.sort(inp, descending=True, kind="radixsort")
111+
argsorted = dpt.argsort(inp, descending=True, kind="radixsort")
112+
113+
assert dpt.all(sorted_v == inp[argsorted])
114+
115+
topk_vals = dpt.copy(sorted_v[:k])
116+
topk_inds = dpt.copy(argsorted[:k])
117+
118+
assert dpt.all(topk_vals == dpt.ones(k, dtype=dt))
119+
assert dpt.all(topk_vals == inp[topk_inds]), topk_inds

0 commit comments

Comments
 (0)