Skip to content

Commit 74d1184

Browse files
author
Vahid Tavanashad
committed
add new tests
1 parent a4b6818 commit 74d1184

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

tests/test_sort.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from numpy.testing import assert_array_equal, assert_equal, assert_raises
55

66
import dpnp
7+
from tests.third_party.cupy import testing
78

89
from .helper import (
910
assert_dtype_allclose,
@@ -61,14 +62,26 @@ def test_argsort_ndarray(self, dtype, axis):
6162
expected = np_array.argsort(axis=axis)
6263
assert_dtype_allclose(result, expected)
6364

64-
def test_argsort_stable(self):
65+
@pytest.mark.parametrize("kind", [None, "stable"])
66+
def test_sort_kind(self, kind):
6567
np_array = numpy.repeat(numpy.arange(10), 10)
6668
dp_array = dpnp.array(np_array)
6769

68-
result = dpnp.argsort(dp_array, kind="stable")
70+
result = dpnp.argsort(dp_array, kind=kind)
6971
expected = numpy.argsort(np_array, kind="stable")
7072
assert_dtype_allclose(result, expected)
7173

74+
# `stable` keyword is supported in numpy 2.0 and above
75+
@testing.with_requires("numpy>=2.0")
76+
@pytest.mark.parametrize("stable", [None, False, True])
77+
def test_sort_stable(self, stable):
78+
np_array = numpy.repeat(numpy.arange(10), 10)
79+
dp_array = dpnp.array(np_array)
80+
81+
result = dpnp.argsort(dp_array, stable="stable")
82+
expected = numpy.argsort(np_array, stable=True)
83+
assert_dtype_allclose(result, expected)
84+
7285
def test_argsort_zero_dim(self):
7386
np_array = numpy.array(2.5)
7487
dp_array = dpnp.array(np_array)
@@ -295,14 +308,26 @@ def test_sort_ndarray(self, dtype, axis):
295308
np_array.sort(axis=axis)
296309
assert_dtype_allclose(dp_array, np_array)
297310

298-
def test_sort_stable(self):
311+
@pytest.mark.parametrize("kind", [None, "stable"])
312+
def test_sort_kind(self, kind):
299313
np_array = numpy.repeat(numpy.arange(10), 10)
300314
dp_array = dpnp.array(np_array)
301315

302-
result = dpnp.sort(dp_array, kind="stable")
316+
result = dpnp.sort(dp_array, kind=kind)
303317
expected = numpy.sort(np_array, kind="stable")
304318
assert_dtype_allclose(result, expected)
305319

320+
# `stable` keyword is supported in numpy 2.0 and above
321+
@testing.with_requires("numpy>=2.0")
322+
@pytest.mark.parametrize("stable", [None, False, True])
323+
def test_sort_stable(self, stable):
324+
np_array = numpy.repeat(numpy.arange(10), 10)
325+
dp_array = dpnp.array(np_array)
326+
327+
result = dpnp.sort(dp_array, stable="stable")
328+
expected = numpy.sort(np_array, stable=True)
329+
assert_dtype_allclose(result, expected)
330+
306331
def test_sort_ndarray_axis_none(self):
307332
a = numpy.random.uniform(-10, 10, 12)
308333
dp_array = dpnp.array(a).reshape(6, 2)

0 commit comments

Comments
 (0)