1515
1616
1717class TestArgsort :
18+ @pytest .mark .parametrize ("kind" , [None , "stable" , "mergesort" , "radixsort" ])
1819 @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_complex = True ))
19- def test_argsort_dtype (self , dtype ):
20+ def test_basic (self , kind , dtype ):
2021 a = numpy .random .uniform (- 5 , 5 , 10 )
2122 np_array = numpy .array (a , dtype = dtype )
2223 dp_array = dpnp .array (np_array )
2324
24- result = dpnp .argsort (dp_array , kind = "stable" )
25+ result = dpnp .argsort (dp_array , kind = kind )
2526 expected = numpy .argsort (np_array , kind = "stable" )
2627 assert_dtype_allclose (result , expected )
2728
29+ @pytest .mark .parametrize ("kind" , [None , "stable" , "mergesort" , "radixsort" ])
2830 @pytest .mark .parametrize ("dtype" , get_complex_dtypes ())
29- def test_argsort_complex (self , dtype ):
31+ def test_complex (self , kind , dtype ):
3032 a = numpy .random .uniform (- 5 , 5 , 10 )
3133 b = numpy .random .uniform (- 5 , 5 , 10 )
3234 np_array = numpy .array (a + b * 1j , dtype = dtype )
3335 dp_array = dpnp .array (np_array )
3436
35- result = dpnp .argsort (dp_array )
36- expected = numpy .argsort (np_array )
37- assert_dtype_allclose (result , expected )
37+ if kind == "radixsort" :
38+ assert_raises (ValueError , dpnp .argsort , dp_array , kind = kind )
39+ else :
40+ result = dpnp .argsort (dp_array , kind = kind )
41+ expected = numpy .argsort (np_array )
42+ assert_dtype_allclose (result , expected )
3843
3944 @pytest .mark .parametrize ("axis" , [None , - 2 , - 1 , 0 , 1 , 2 ])
40- def test_argsort_axis (self , axis ):
45+ def test_axis (self , axis ):
4146 a = numpy .random .uniform (- 10 , 10 , 36 )
4247 np_array = numpy .array (a ).reshape (3 , 4 , 3 )
4348 dp_array = dpnp .array (np_array )
@@ -48,7 +53,7 @@ def test_argsort_axis(self, axis):
4853
4954 @pytest .mark .parametrize ("dtype" , get_all_dtypes ())
5055 @pytest .mark .parametrize ("axis" , [None , - 2 , - 1 , 0 , 1 ])
51- def test_argsort_ndarray (self , dtype , axis ):
56+ def test_ndarray (self , dtype , axis ):
5257 if dtype and issubclass (dtype , numpy .integer ):
5358 a = numpy .random .choice (
5459 numpy .arange (- 10 , 10 ), replace = False , size = 12
@@ -62,8 +67,9 @@ def test_argsort_ndarray(self, dtype, axis):
6267 expected = np_array .argsort (axis = axis )
6368 assert_dtype_allclose (result , expected )
6469
65- @pytest .mark .parametrize ("kind" , [None , "stable" ])
66- def test_sort_kind (self , kind ):
70+ # this test validates that all different options of kind in dpnp are stable
71+ @pytest .mark .parametrize ("kind" , [None , "stable" , "mergesort" , "radixsort" ])
72+ def test_kind (self , kind ):
6773 np_array = numpy .repeat (numpy .arange (10 ), 10 )
6874 dp_array = dpnp .array (np_array )
6975
@@ -74,15 +80,15 @@ def test_sort_kind(self, kind):
7480 # `stable` keyword is supported in numpy 2.0 and above
7581 @testing .with_requires ("numpy>=2.0" )
7682 @pytest .mark .parametrize ("stable" , [None , False , True ])
77- def test_sort_stable (self , stable ):
83+ def test_stable (self , stable ):
7884 np_array = numpy .repeat (numpy .arange (10 ), 10 )
7985 dp_array = dpnp .array (np_array )
8086
8187 result = dpnp .argsort (dp_array , stable = "stable" )
8288 expected = numpy .argsort (np_array , stable = True )
8389 assert_dtype_allclose (result , expected )
8490
85- def test_argsort_zero_dim (self ):
91+ def test_zero_dim (self ):
8692 np_array = numpy .array (2.5 )
8793 dp_array = dpnp .array (np_array )
8894
@@ -266,29 +272,34 @@ def test_v_scalar(self):
266272
267273
268274class TestSort :
275+ @pytest .mark .parametrize ("kind" , [None , "stable" , "mergesort" , "radixsort" ])
269276 @pytest .mark .parametrize ("dtype" , get_all_dtypes (no_complex = True ))
270- def test_sort_dtype (self , dtype ):
277+ def test_basic (self , kind , dtype ):
271278 a = numpy .random .uniform (- 5 , 5 , 10 )
272279 np_array = numpy .array (a , dtype = dtype )
273280 dp_array = dpnp .array (np_array )
274281
275- result = dpnp .sort (dp_array )
282+ result = dpnp .sort (dp_array , kind = kind )
276283 expected = numpy .sort (np_array )
277284 assert_dtype_allclose (result , expected )
278285
286+ @pytest .mark .parametrize ("kind" , [None , "stable" , "mergesort" , "radixsort" ])
279287 @pytest .mark .parametrize ("dtype" , get_complex_dtypes ())
280- def test_sort_complex (self , dtype ):
288+ def test_complex (self , kind , dtype ):
281289 a = numpy .random .uniform (- 5 , 5 , 10 )
282290 b = numpy .random .uniform (- 5 , 5 , 10 )
283291 np_array = numpy .array (a + b * 1j , dtype = dtype )
284292 dp_array = dpnp .array (np_array )
285293
286- result = dpnp .sort (dp_array )
287- expected = numpy .sort (np_array )
288- assert_dtype_allclose (result , expected )
294+ if kind == "radixsort" :
295+ assert_raises (ValueError , dpnp .argsort , dp_array , kind = kind )
296+ else :
297+ result = dpnp .sort (dp_array , kind = kind )
298+ expected = numpy .sort (np_array )
299+ assert_dtype_allclose (result , expected )
289300
290301 @pytest .mark .parametrize ("axis" , [None , - 2 , - 1 , 0 , 1 , 2 ])
291- def test_sort_axis (self , axis ):
302+ def test_axis (self , axis ):
292303 a = numpy .random .uniform (- 10 , 10 , 36 )
293304 np_array = numpy .array (a ).reshape (3 , 4 , 3 )
294305 dp_array = dpnp .array (np_array )
@@ -299,7 +310,7 @@ def test_sort_axis(self, axis):
299310
300311 @pytest .mark .parametrize ("dtype" , get_all_dtypes ())
301312 @pytest .mark .parametrize ("axis" , [- 2 , - 1 , 0 , 1 ])
302- def test_sort_ndarray (self , dtype , axis ):
313+ def test_ndarray (self , dtype , axis ):
303314 a = numpy .random .uniform (- 10 , 10 , 12 )
304315 np_array = numpy .array (a , dtype = dtype ).reshape (6 , 2 )
305316 dp_array = dpnp .array (np_array )
@@ -308,8 +319,9 @@ def test_sort_ndarray(self, dtype, axis):
308319 np_array .sort (axis = axis )
309320 assert_dtype_allclose (dp_array , np_array )
310321
311- @pytest .mark .parametrize ("kind" , [None , "stable" ])
312- def test_sort_kind (self , kind ):
322+ # this test validates that all different options of kind in dpnp are stable
323+ @pytest .mark .parametrize ("kind" , [None , "stable" , "mergesort" , "radixsort" ])
324+ def test_kind (self , kind ):
313325 np_array = numpy .repeat (numpy .arange (10 ), 10 )
314326 dp_array = dpnp .array (np_array )
315327
@@ -320,21 +332,21 @@ def test_sort_kind(self, kind):
320332 # `stable` keyword is supported in numpy 2.0 and above
321333 @testing .with_requires ("numpy>=2.0" )
322334 @pytest .mark .parametrize ("stable" , [None , False , True ])
323- def test_sort_stable (self , stable ):
335+ def test_stable (self , stable ):
324336 np_array = numpy .repeat (numpy .arange (10 ), 10 )
325337 dp_array = dpnp .array (np_array )
326338
327339 result = dpnp .sort (dp_array , stable = "stable" )
328340 expected = numpy .sort (np_array , stable = True )
329341 assert_dtype_allclose (result , expected )
330342
331- def test_sort_ndarray_axis_none (self ):
343+ def test_ndarray_axis_none (self ):
332344 a = numpy .random .uniform (- 10 , 10 , 12 )
333345 dp_array = dpnp .array (a ).reshape (6 , 2 )
334346 with pytest .raises (TypeError ):
335347 dp_array .sort (axis = None )
336348
337- def test_sort_zero_dim (self ):
349+ def test_zero_dim (self ):
338350 np_array = numpy .array (2.5 )
339351 dp_array = dpnp .array (np_array )
340352
@@ -347,15 +359,20 @@ def test_sort_zero_dim(self):
347359 expected = numpy .sort (np_array , axis = None )
348360 assert_dtype_allclose (result , expected )
349361
350- def test_sort_notimplemented (self ):
362+ def test_error (self ):
351363 dp_array = dpnp .arange (10 )
352364
353- with pytest .raises (NotImplementedError ):
365+ # quicksort is currently not supported
366+ with pytest .raises (ValueError ):
354367 dpnp .sort (dp_array , kind = "quicksort" )
355368
356369 with pytest .raises (NotImplementedError ):
357370 dpnp .sort (dp_array , order = ["age" ])
358371
372+ # both kind and stable are given
373+ with pytest .raises (ValueError ):
374+ dpnp .sort (dp_array , kind = "mergesort" , stable = True )
375+
359376
360377class TestSortComplex :
361378 @pytest .mark .parametrize (
0 commit comments