@@ -276,3 +276,43 @@ def test_top_k_noncontig():
276276 assert dpt .all (
277277 dpt .sort (r .indices ) == dpt .asarray ([125 , 126 , 127 ])
278278 ), r .indices
279+
280+
281+ def test_top_k_axis0 ():
282+ get_queue_or_skip ()
283+
284+ m , n , k = 128 , 8 , 3
285+ x = dpt .reshape (dpt .arange (m * n , dtype = dpt .int32 ), (m , n ))
286+
287+ r = dpt .top_k (x , k , axis = 0 , mode = "smallest" )
288+ assert r .values .shape == (k , n )
289+ assert r .indices .shape == (k , n )
290+ expected_inds = dpt .reshape (dpt .arange (m , dtype = r .indices .dtype ), (m , 1 ))[
291+ :k , :
292+ ]
293+ assert dpt .all (
294+ dpt .sort (r .indices , axis = 0 ) == dpt .sort (expected_inds , axis = 0 )
295+ )
296+ assert dpt .all (dpt .sort (r .values , axis = 0 ) == dpt .sort (x [:k , :], axis = 0 ))
297+
298+
299+ def test_top_k_validation ():
300+ get_queue_or_skip ()
301+ x = dpt .ones (10 , dtype = dpt .int64 )
302+ with pytest .raises (ValueError ):
303+ # k must be positive
304+ dpt .top_k (x , - 1 )
305+ with pytest .raises (TypeError ):
306+ # argument should be usm_ndarray
307+ dpt .top_k (list (), 2 )
308+ x2 = dpt .reshape (x , (2 , 5 ))
309+ with pytest .raises (ValueError ):
310+ # k must not exceed array dimension
311+ # along specified axis
312+ dpt .top_k (x2 , 100 , axis = 1 )
313+ with pytest .raises (ValueError ):
314+ # for 0d arrays, k must be 1
315+ dpt .top_k (x [0 ], 2 )
316+ with pytest .raises (ValueError ):
317+ # mode must be "largest", or "smallest"
318+ dpt .top_k (x , 2 , mode = "invalid" )
0 commit comments