@@ -153,81 +153,117 @@ def test_sort_validation():
153153 dpt .sort (dict ())
154154
155155
156+ def test_sort_validation_kind ():
157+ get_queue_or_skip ()
158+
159+ x = dpt .ones (128 , dtype = "u1" )
160+
161+ with pytest .raises (ValueError ):
162+ dpt .sort (x , kind = Ellipsis )
163+
164+ with pytest .raises (ValueError ):
165+ dpt .sort (x , kind = "invalid" )
166+
167+
156168def test_argsort_validation ():
157169 with pytest .raises (TypeError ):
158170 dpt .argsort (dict ())
159171
160172
161- def test_sort_axis0 ():
173+ def test_argsort_validation_kind ():
174+ get_queue_or_skip ()
175+
176+ x = dpt .arange (127 , stop = 0 , step = - 1 , dtype = "i1" )
177+
178+ with pytest .raises (ValueError ):
179+ dpt .argsort (x , kind = Ellipsis )
180+
181+ with pytest .raises (ValueError ):
182+ dpt .argsort (x , kind = "invalid" )
183+
184+
185+ _all_kinds = ["stable" , "mergesort" , "radixsort" ]
186+
187+
188+ @pytest .mark .parametrize ("kind" , _all_kinds )
189+ def test_sort_axis0 (kind ):
162190 get_queue_or_skip ()
163191
164192 n , m = 200 , 30
165193 xf = dpt .arange (n * m , 0 , step = - 1 , dtype = "i4" )
166194 x = dpt .reshape (xf , (n , m ))
167- s = dpt .sort (x , axis = 0 )
195+ s = dpt .sort (x , axis = 0 , kind = kind )
168196
169197 assert dpt .all (s [:- 1 , :] <= s [1 :, :])
170198
171199
172- def test_argsort_axis0 ():
200+ @pytest .mark .parametrize ("kind" , _all_kinds )
201+ def test_argsort_axis0 (kind ):
173202 get_queue_or_skip ()
174203
175204 n , m = 200 , 30
176205 xf = dpt .arange (n * m , 0 , step = - 1 , dtype = "i4" )
177206 x = dpt .reshape (xf , (n , m ))
178- idx = dpt .argsort (x , axis = 0 )
207+ idx = dpt .argsort (x , axis = 0 , kind = kind )
179208
180209 s = dpt .take_along_axis (x , idx , axis = 0 )
181210
182211 assert dpt .all (s [:- 1 , :] <= s [1 :, :])
183212
184213
185- def test_argsort_axis1 ():
214+ @pytest .mark .parametrize ("kind" , _all_kinds )
215+ def test_argsort_axis1 (kind ):
186216 get_queue_or_skip ()
187217
188218 n , m = 200 , 30
189219 xf = dpt .arange (n * m , 0 , step = - 1 , dtype = "i4" )
190220 x = dpt .reshape (xf , (n , m ))
191- idx = dpt .argsort (x , axis = 1 )
221+ idx = dpt .argsort (x , axis = 1 , kind = kind )
192222
193223 s = dpt .take_along_axis (x , idx , axis = 1 )
194224
195225 assert dpt .all (s [:, :- 1 ] <= s [:, 1 :])
196226
197227
198- def test_sort_strided ():
228+ @pytest .mark .parametrize ("kind" , _all_kinds )
229+ def test_sort_strided (kind ):
199230 get_queue_or_skip ()
200231
201232 x_orig = dpt .arange (100 , dtype = "i4" )
202233 x_flipped = dpt .flip (x_orig , axis = 0 )
203- s = dpt .sort (x_flipped )
234+ s = dpt .sort (x_flipped , kind = kind )
204235
205236 assert dpt .all (s == x_orig )
206237
207238
208- def test_argsort_strided ():
239+ @pytest .mark .parametrize ("kind" , _all_kinds )
240+ def test_argsort_strided (kind ):
209241 get_queue_or_skip ()
210242
211243 x_orig = dpt .arange (100 , dtype = "i4" )
212244 x_flipped = dpt .flip (x_orig , axis = 0 )
213- idx = dpt .argsort (x_flipped )
245+ idx = dpt .argsort (x_flipped , kind = kind )
214246 s = dpt .take_along_axis (x_flipped , idx , axis = 0 )
215247
216248 assert dpt .all (s == x_orig )
217249
218250
219- def test_sort_0d_array ():
251+ @pytest .mark .parametrize ("kind" , _all_kinds )
252+ def test_sort_0d_array (kind ):
220253 get_queue_or_skip ()
221254
222255 x = dpt .asarray (1 , dtype = "i4" )
223- assert dpt .sort (x ) == 1
256+ expected = dpt .asarray (1 , dtype = "i4" )
257+ assert dpt .sort (x , kind = kind ) == expected
224258
225259
226- def test_argsort_0d_array ():
260+ @pytest .mark .parametrize ("kind" , _all_kinds )
261+ def test_argsort_0d_array (kind ):
227262 get_queue_or_skip ()
228263
229264 x = dpt .asarray (1 , dtype = "i4" )
230- assert dpt .argsort (x ) == 0
265+ expected = dpt .asarray (0 , dtype = "i4" )
266+ assert dpt .argsort (x , kind = kind ) == expected
231267
232268
233269@pytest .mark .parametrize (
@@ -238,22 +274,23 @@ def test_argsort_0d_array():
238274 "f8" ,
239275 ],
240276)
241- def test_sort_real_fp_nan (dtype ):
277+ @pytest .mark .parametrize ("kind" , _all_kinds )
278+ def test_sort_real_fp_nan (dtype , kind ):
242279 q = get_queue_or_skip ()
243280 skip_if_dtype_not_supported (dtype , q )
244281
245282 x = dpt .asarray (
246283 [- 0.0 , 0.1 , dpt .nan , 0.0 , - 0.1 , dpt .nan , 0.2 , - 0.3 ], dtype = dtype
247284 )
248- s = dpt .sort (x )
285+ s = dpt .sort (x , kind = kind )
249286
250287 expected = dpt .asarray (
251288 [- 0.3 , - 0.1 , - 0.0 , 0.0 , 0.1 , 0.2 , dpt .nan , dpt .nan ], dtype = dtype
252289 )
253290
254291 assert dpt .allclose (s , expected , equal_nan = True )
255292
256- s = dpt .sort (x , descending = True )
293+ s = dpt .sort (x , descending = True , kind = kind )
257294
258295 expected = dpt .asarray (
259296 [dpt .nan , dpt .nan , 0.2 , 0.1 , - 0.0 , 0.0 , - 0.1 , - 0.3 ], dtype = dtype
0 commit comments