@@ -45,16 +45,16 @@ def test_topk_1d_largest(dtype, n):
4545
4646 o = dpt .ones (n , dtype = dtype )
4747 z = dpt .zeros (n , dtype = dtype )
48- zo = dpt .concat ((o , z ))
49- inp = dpt .roll (zo , 734 )
48+ oz = dpt .concat ((o , z ))
49+ inp = dpt .roll (oz , 734 )
5050 k = 5
5151
5252 s = dpt .top_k (inp , k , mode = "largest" )
5353 assert s .values .shape == (k ,)
5454 assert s .values .dtype == inp .dtype
5555 assert s .indices .shape == (k ,)
56- assert dpt .all (s .values == dpt .ones (k , dtype = dtype ))
57- assert dpt .all (s .values == inp [s .indices ])
56+ assert dpt .all (s .values == dpt .ones (k , dtype = dtype )), s . values
57+ assert dpt .all (s .values == inp [s .indices ]), s . indices
5858
5959
6060@pytest .mark .parametrize (
@@ -82,34 +82,38 @@ def test_topk_1d_smallest(dtype, n):
8282
8383 o = dpt .ones (n , dtype = dtype )
8484 z = dpt .zeros (n , dtype = dtype )
85- zo = dpt .concat ((o , z ))
86- inp = dpt .roll (zo , 734 )
85+ oz = dpt .concat ((o , z ))
86+ inp = dpt .roll (oz , 734 )
8787 k = 5
8888
8989 s = dpt .top_k (inp , k , mode = "smallest" )
9090 assert s .values .shape == (k ,)
9191 assert s .values .dtype == inp .dtype
9292 assert s .indices .shape == (k ,)
93- assert dpt .all (s .values == dpt .zeros (k , dtype = dtype ))
94- assert dpt .all (s .values == inp [s .indices ])
93+ assert dpt .all (s .values == dpt .zeros (k , dtype = dtype )), s . values
94+ assert dpt .all (s .values == inp [s .indices ]), s . indices
9595
9696
9797# triage failing top k radix implementation on CPU
9898# replicates from Python behavior of radix sort topk implementation
9999@pytest .mark .parametrize ("n" , [33 , 255 , 511 , 1021 , 8193 ])
100- def test_topk_largest_1d_radix_i1_255 (n ):
100+ def test_topk_largest_1d_radix_i1 (n ):
101101 get_queue_or_skip ()
102102 dt = "i1"
103103
104104 o = dpt .ones (n , dtype = dt )
105105 z = dpt .zeros (n , dtype = dt )
106- zo = dpt .concat ((o , z ))
107- inp = dpt .roll (zo , 734 )
106+ oz = dpt .concat ((o , z ))
107+ inp = dpt .roll (oz , 734 )
108108 k = 5
109109
110- sorted = dpt .copy (dpt .sort (inp , descending = True , kind = "radixsort" )[:k ])
111- argsorted = dpt .copy (
112- dpt .argsort (inp , descending = True , kind = "radixsort" )[:k ]
113- )
114- assert dpt .all (sorted == dpt .ones (k , dtype = dt ))
115- assert dpt .all (sorted == inp [argsorted ])
110+ sorted_v = dpt .sort (inp , descending = True , kind = "radixsort" )
111+ argsorted = dpt .argsort (inp , descending = True , kind = "radixsort" )
112+
113+ assert dpt .all (sorted_v == inp [argsorted ])
114+
115+ topk_vals = dpt .copy (sorted_v [:k ])
116+ topk_inds = dpt .copy (argsorted [:k ])
117+
118+ assert dpt .all (topk_vals == dpt .ones (k , dtype = dt ))
119+ assert dpt .all (topk_vals == inp [topk_inds ]), topk_inds
0 commit comments