@@ -276,3 +276,43 @@ def test_top_k_noncontig():
276
276
assert dpt .all (
277
277
dpt .sort (r .indices ) == dpt .asarray ([125 , 126 , 127 ])
278
278
), r .indices
279
+
280
+
281
+ def test_top_k_axis0 ():
282
+ get_queue_or_skip ()
283
+
284
+ m , n , k = 128 , 8 , 3
285
+ x = dpt .reshape (dpt .arange (m * n , dtype = dpt .int32 ), (m , n ))
286
+
287
+ r = dpt .top_k (x , k , axis = 0 , mode = "smallest" )
288
+ assert r .values .shape == (k , n )
289
+ assert r .indices .shape == (k , n )
290
+ expected_inds = dpt .reshape (dpt .arange (m , dtype = r .indices .dtype ), (m , 1 ))[
291
+ :k , :
292
+ ]
293
+ assert dpt .all (
294
+ dpt .sort (r .indices , axis = 0 ) == dpt .sort (expected_inds , axis = 0 )
295
+ )
296
+ assert dpt .all (dpt .sort (r .values , axis = 0 ) == dpt .sort (x [:k , :], axis = 0 ))
297
+
298
+
299
+ def test_top_k_validation ():
300
+ get_queue_or_skip ()
301
+ x = dpt .ones (10 , dtype = dpt .int64 )
302
+ with pytest .raises (ValueError ):
303
+ # k must be positive
304
+ dpt .top_k (x , - 1 )
305
+ with pytest .raises (TypeError ):
306
+ # argument should be usm_ndarray
307
+ dpt .top_k (list (), 2 )
308
+ x2 = dpt .reshape (x , (2 , 5 ))
309
+ with pytest .raises (ValueError ):
310
+ # k must not exceed array dimension
311
+ # along specified axis
312
+ dpt .top_k (x2 , 100 , axis = 1 )
313
+ with pytest .raises (ValueError ):
314
+ # for 0d arrays, k must be 1
315
+ dpt .top_k (x [0 ], 2 )
316
+ with pytest .raises (ValueError ):
317
+ # mode must be "largest", or "smallest"
318
+ dpt .top_k (x , 2 , mode = "invalid" )
0 commit comments