Skip to content

Commit b50d516

Browse files
committed
Add more tests to cover the use case
1 parent 23dbe91 commit b50d516

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

dpnp/tests/test_manipulation.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1852,17 +1852,28 @@ def test_equal_nan(self, eq_nan_kwd):
18521852

18531853
# TODO: uncomment once numpy 2.4.0 release is published
18541854
# @testing.with_requires("numpy>=2.4.0")
1855-
def test_1d_equal_nan_axis0(self):
1855+
@pytest.mark.parametrize("axis", [0, -1])
1856+
def test_1d_equal_nan_axis(self, axis):
18561857
a = numpy.array([numpy.nan, 0, 0, numpy.nan])
18571858
ia = dpnp.array(a)
18581859

1859-
result = dpnp.unique(ia, axis=0, equal_nan=True)
1860-
expected = numpy.unique(a, axis=0, equal_nan=True)
1860+
result = dpnp.unique(ia, axis=axis, equal_nan=True)
1861+
expected = numpy.unique(a, axis=axis, equal_nan=True)
18611862
# TODO: remove when numpy#29372 is released
18621863
if numpy_version() < "2.4.0":
18631864
expected = numpy.array([0.0, numpy.nan])
18641865
assert_array_equal(result, expected)
18651866

1867+
# TODO: uncomment once numpy 2.4.0 release is published
1868+
# @testing.with_requires("numpy>=2.4.0")
1869+
@pytest.mark.parametrize("equal_nan", [True, False])
1870+
# @pytest.mark.parametrize("xp", [numpy, dpnp])
1871+
@pytest.mark.parametrize("xp", [dpnp])
1872+
def test_1d_axis_float_raises_typeerror(self, xp, equal_nan):
1873+
a = xp.array([xp.nan, 0, 0, xp.nan])
1874+
with pytest.raises(TypeError, match="integer argument expected"):
1875+
xp.unique(a, axis=0.0, equal_nan=equal_nan)
1876+
18661877
@testing.with_requires("numpy>=2.0.1")
18671878
@pytest.mark.parametrize("dt", get_float_complex_dtypes())
18681879
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)