Skip to content

Commit 12c324b

Browse files
authored
Follow up the fix implemented in gh-2530 to fix dpnp.unique with axis=0 and 1d input (#2587)
The PR improves the fix implemented in `dpnp.unique` for axis=0 and 1d input. Also it adds an explicit validation in case of floating axis passed.
1 parent b27a7ca commit 12c324b

File tree

3 files changed

+23
-6
lines changed

3 files changed

+23
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
5858
* Updated `pre-commit` GitHub workflow to pass `no-commit-to-branch` check [#2501](https://github.com/IntelPython/dpnp/pull/2501)
5959
* Updated the math formulas in summary of `dpnp.matvec` and `dpnp.vecmat` to correct a typo [#2503](https://github.com/IntelPython/dpnp/pull/2503)
6060
* Avoided negating unsigned integers in ceil division used in `dpnp.resize` implementation [#2508](https://github.com/IntelPython/dpnp/pull/2508)
61-
* Fixed `dpnp.unique` with 1d input array and `axis=0`, `equal_nan=True` keywords passed where the produced result doesn't collapse the NaNs [#2530](https://github.com/IntelPython/dpnp/pull/2530)
61+
* Fixed `dpnp.unique` with 1d input array and `axis=0`, `equal_nan=True` keywords passed where the produced result doesn't collapse the NaNs [#2530](https://github.com/IntelPython/dpnp/pull/2530), [#2587](https://github.com/IntelPython/dpnp/pull/2587)
6262
* Resolved issue when `dpnp.ndarray` constructor is called with `dpnp.ndarray.data` as `buffer` keyword [#2533](https://github.com/IntelPython/dpnp/pull/2533)
6363
* Fixed `dpnp.linalg.cond` to always return a real dtype [#2547](https://github.com/IntelPython/dpnp/pull/2547)
6464
* Resolved the issue in `dpnp.random` functions to allow any value of `size` where each element is castable to `Py_ssize_t` type [#2578](https://github.com/IntelPython/dpnp/pull/2578)

dpnp/dpnp_iface_manipulation.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4270,7 +4270,13 @@ def unique(
42704270
42714271
"""
42724272

4273-
if axis is None or (axis == 0 and ar.ndim == 1):
4273+
dpnp.check_supported_arrays_type(ar)
4274+
nd = ar.ndim
4275+
4276+
if axis is None or nd == 1:
4277+
if axis is not None:
4278+
normalize_axis_index(axis, nd)
4279+
42744280
return _unique_1d(
42754281
ar, return_index, return_inverse, return_counts, equal_nan
42764282
)
@@ -4280,7 +4286,7 @@ def unique(
42804286
ar = dpnp.moveaxis(ar, axis, 0)
42814287
except AxisError:
42824288
# this removes the "axis1" or "axis2" prefix from the error message
4283-
raise AxisError(axis, ar.ndim) from None
4289+
raise AxisError(axis, nd) from None
42844290

42854291
# reshape input array into a contiguous 2D array
42864292
orig_sh = ar.shape

dpnp/tests/test_manipulation.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1852,17 +1852,28 @@ def test_equal_nan(self, eq_nan_kwd):
18521852

18531853
# TODO: uncomment once numpy 2.4.0 release is published
18541854
# @testing.with_requires("numpy>=2.4.0")
1855-
def test_1d_equal_nan_axis0(self):
1855+
@pytest.mark.parametrize("axis", [0, -1])
1856+
def test_1d_equal_nan_axis(self, axis):
18561857
a = numpy.array([numpy.nan, 0, 0, numpy.nan])
18571858
ia = dpnp.array(a)
18581859

1859-
result = dpnp.unique(ia, axis=0, equal_nan=True)
1860-
expected = numpy.unique(a, axis=0, equal_nan=True)
1860+
result = dpnp.unique(ia, axis=axis, equal_nan=True)
1861+
expected = numpy.unique(a, axis=axis, equal_nan=True)
18611862
# TODO: remove when numpy#29372 is released
18621863
if numpy_version() < "2.4.0":
18631864
expected = numpy.array([0.0, numpy.nan])
18641865
assert_array_equal(result, expected)
18651866

1867+
# TODO: uncomment once numpy 2.4.0 release is published
1868+
# @testing.with_requires("numpy>=2.4.0")
1869+
@pytest.mark.parametrize("equal_nan", [True, False])
1870+
# @pytest.mark.parametrize("xp", [numpy, dpnp])
1871+
@pytest.mark.parametrize("xp", [dpnp])
1872+
def test_1d_axis_float_raises_typeerror(self, xp, equal_nan):
1873+
a = xp.array([xp.nan, 0, 0, xp.nan])
1874+
with pytest.raises(TypeError, match="integer argument expected"):
1875+
xp.unique(a, axis=0.0, equal_nan=equal_nan)
1876+
18661877
@testing.with_requires("numpy>=2.0.1")
18671878
@pytest.mark.parametrize("dt", get_float_complex_dtypes())
18681879
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)