|
4 | 4 | from numpy.testing import assert_array_equal, assert_equal, assert_raises |
5 | 5 |
|
6 | 6 | import dpnp |
| 7 | +from tests.third_party.cupy import testing |
7 | 8 |
|
8 | 9 | from .helper import ( |
9 | 10 | assert_dtype_allclose, |
@@ -61,14 +62,26 @@ def test_argsort_ndarray(self, dtype, axis): |
61 | 62 | expected = np_array.argsort(axis=axis) |
62 | 63 | assert_dtype_allclose(result, expected) |
63 | 64 |
|
64 | | - def test_argsort_stable(self): |
| 65 | + @pytest.mark.parametrize("kind", [None, "stable"]) |
| 66 | + def test_sort_kind(self, kind): |
65 | 67 | np_array = numpy.repeat(numpy.arange(10), 10) |
66 | 68 | dp_array = dpnp.array(np_array) |
67 | 69 |
|
68 | | - result = dpnp.argsort(dp_array, kind="stable") |
| 70 | + result = dpnp.argsort(dp_array, kind=kind) |
69 | 71 | expected = numpy.argsort(np_array, kind="stable") |
70 | 72 | assert_dtype_allclose(result, expected) |
71 | 73 |
|
| 74 | + # `stable` keyword is supported in numpy 2.0 and above |
| 75 | + @testing.with_requires("numpy>=2.0") |
| 76 | + @pytest.mark.parametrize("stable", [None, False, True]) |
| 77 | + def test_sort_stable(self, stable): |
| 78 | + np_array = numpy.repeat(numpy.arange(10), 10) |
| 79 | + dp_array = dpnp.array(np_array) |
| 80 | + |
| 81 | + result = dpnp.argsort(dp_array, stable="stable") |
| 82 | + expected = numpy.argsort(np_array, stable=True) |
| 83 | + assert_dtype_allclose(result, expected) |
| 84 | + |
72 | 85 | def test_argsort_zero_dim(self): |
73 | 86 | np_array = numpy.array(2.5) |
74 | 87 | dp_array = dpnp.array(np_array) |
@@ -295,14 +308,26 @@ def test_sort_ndarray(self, dtype, axis): |
295 | 308 | np_array.sort(axis=axis) |
296 | 309 | assert_dtype_allclose(dp_array, np_array) |
297 | 310 |
|
298 | | - def test_sort_stable(self): |
| 311 | + @pytest.mark.parametrize("kind", [None, "stable"]) |
| 312 | + def test_sort_kind(self, kind): |
299 | 313 | np_array = numpy.repeat(numpy.arange(10), 10) |
300 | 314 | dp_array = dpnp.array(np_array) |
301 | 315 |
|
302 | | - result = dpnp.sort(dp_array, kind="stable") |
| 316 | + result = dpnp.sort(dp_array, kind=kind) |
303 | 317 | expected = numpy.sort(np_array, kind="stable") |
304 | 318 | assert_dtype_allclose(result, expected) |
305 | 319 |
|
| 320 | + # `stable` keyword is supported in numpy 2.0 and above |
| 321 | + @testing.with_requires("numpy>=2.0") |
| 322 | + @pytest.mark.parametrize("stable", [None, False, True]) |
| 323 | + def test_sort_stable(self, stable): |
| 324 | + np_array = numpy.repeat(numpy.arange(10), 10) |
| 325 | + dp_array = dpnp.array(np_array) |
| 326 | + |
| 327 | + result = dpnp.sort(dp_array, stable="stable") |
| 328 | + expected = numpy.sort(np_array, stable=True) |
| 329 | + assert_dtype_allclose(result, expected) |
| 330 | + |
306 | 331 | def test_sort_ndarray_axis_none(self): |
307 | 332 | a = numpy.random.uniform(-10, 10, 12) |
308 | 333 | dp_array = dpnp.array(a).reshape(6, 2) |
|
0 commit comments