diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py index 401b7c67a7f4..e588729c0066 100644 --- a/dpnp/dpnp_iface_manipulation.py +++ b/dpnp/dpnp_iface_manipulation.py @@ -3900,25 +3900,40 @@ def transpose(a, axes=None): permute_dims = transpose # permute_dims is an alias for transpose -def trim_zeros(filt, trim="fb"): +def trim_zeros(filt, trim="fb", axis=None): """ - Trim the leading and/or trailing zeros from a 1-D array. + Remove values along a dimension which are zero along all other. For full documentation refer to :obj:`numpy.trim_zeros`. Parameters ---------- filt : {dpnp.ndarray, usm_ndarray} - Input 1-D array. - trim : str, optional - A string with 'f' representing trim from front and 'b' to trim from - back. By defaults, trim zeros from both front and back of the array. + Input array. + trim : {"fb", "f", "b"}, optional + A string with `"f"` representing trim from front and `"b"` to trim from + back. By default, zeros are trimmed on both sides. Front and back refer + to the edges of a dimension, with "front" referring to the side with + the lowest index 0, and "back" referring to the highest index + (or index -1). Default: ``"fb"``. + axis : {None, int}, optional + If ``None``, `filt` is cropped such, that the smallest bounding box is + returned that still contains all values which are not zero. + If an `axis` is specified, `filt` will be sliced in that dimension only + on the sides specified by `trim`. The remaining area will be the + smallest that still contains all values which are not zero. + Default: ``None``. Returns ------- out : dpnp.ndarray - The result of trimming the input. + The result of trimming the input. The number of dimensions and the + input data type are preserved. + + Notes + ----- + For all-zero arrays, the first axis is trimmed first. Examples -------- @@ -3927,42 +3942,66 @@ def trim_zeros(filt, trim="fb"): >>> np.trim_zeros(a) array([1, 2, 3, 0, 2, 1]) - >>> np.trim_zeros(a, 'b') + >>> np.trim_zeros(a, trim='b') array([0, 0, 0, 1, 2, 3, 0, 2, 1]) + Multiple dimensions are supported: + + >>> b = np.array([[0, 0, 2, 3, 0, 0], + ... [0, 1, 0, 3, 0, 0], + ... [0, 0, 0, 0, 0, 0]]) + >>> np.trim_zeros(b) + array([[0, 2, 3], + [1, 0, 3]]) + + >>> np.trim_zeros(b, axis=-1) + array([[0, 2, 3], + [1, 0, 3], + [0, 0, 0]]) + """ dpnp.check_supported_arrays_type(filt) - if filt.ndim == 0: - raise TypeError("0-d array cannot be trimmed") - if filt.ndim > 1: - raise ValueError("Multi-dimensional trim is not supported") if not isinstance(trim, str): raise TypeError("only string trim is supported") - trim = trim.upper() - if not any(x in trim for x in "FB"): - return filt # no trim rule is specified + trim = trim.lower() + if trim not in ["fb", "bf", "f", "b"]: + raise ValueError(f"unexpected character(s) in `trim`: {trim!r}") + + nd = filt.ndim + if axis is not None: + axis = normalize_axis_index(axis, nd) if filt.size == 0: return filt # no trailing zeros in empty array - a = dpnp.nonzero(filt)[0] - a_size = a.size - if a_size == 0: - # 'filt' is array of zeros - return dpnp.empty_like(filt, shape=(0,)) + non_zero = dpnp.argwhere(filt) + if non_zero.size == 0: + # `filt` has all zeros, so assign `start` and `stop` to the same value, + # then the resulting slice will be empty + start = stop = dpnp.zeros_like(filt, shape=nd, dtype=dpnp.intp) + else: + if "f" in trim: + start = non_zero.min(axis=0) + else: + start = (None,) * nd - first = 0 - if "F" in trim: - first = a[0] + if "b" in trim: + stop = non_zero.max(axis=0) + stop += 1 # Adjust for slicing + else: + stop = (None,) * nd - last = filt.size - if "B" in trim: - last = a[-1] + 1 + if axis is None: + # trim all axes + sl = tuple(slice(*x) for x in zip(start, stop)) + else: + # only trim single axis + sl = (slice(None),) * axis + (slice(start[axis], stop[axis]),) + (...,) - return filt[first:last] + return filt[sl] def unique( diff --git a/dpnp/tests/test_manipulation.py b/dpnp/tests/test_manipulation.py index 97c951ad1562..f3a39069f0a0 100644 --- a/dpnp/tests/test_manipulation.py +++ b/dpnp/tests/test_manipulation.py @@ -1378,6 +1378,20 @@ def test_basic(self, dtype): expected = numpy.trim_zeros(a) assert_array_equal(result, expected) + @testing.with_requires("numpy>=2.2") + @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) + @pytest.mark.parametrize("trim", ["F", "B", "fb"]) + @pytest.mark.parametrize("ndim", [0, 1, 2, 3]) + def test_basic_nd(self, dtype, trim, ndim): + a = numpy.ones((2,) * ndim, dtype=dtype) + a = numpy.pad(a, (2, 1), mode="constant", constant_values=0) + ia = dpnp.array(a) + + for axis in list(range(ndim)) + [None]: + result = dpnp.trim_zeros(ia, trim=trim, axis=axis) + expected = numpy.trim_zeros(a, trim=trim, axis=axis) + assert_array_equal(result, expected) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) @pytest.mark.parametrize("trim", ["F", "B"]) def test_trim(self, dtype, trim): @@ -1398,6 +1412,19 @@ def test_all_zero(self, dtype, trim): expected = numpy.trim_zeros(a, trim) assert_array_equal(result, expected) + @testing.with_requires("numpy>=2.2") + @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) + @pytest.mark.parametrize("trim", ["F", "B", "fb"]) + @pytest.mark.parametrize("ndim", [0, 1, 2, 3]) + def test_all_zero_nd(self, dtype, trim, ndim): + a = numpy.zeros((3,) * ndim, dtype=dtype) + ia = dpnp.array(a) + + for axis in list(range(ndim)) + [None]: + result = dpnp.trim_zeros(ia, trim=trim, axis=axis) + expected = numpy.trim_zeros(a, trim=trim, axis=axis) + assert_array_equal(result, expected) + def test_size_zero(self): a = numpy.zeros(0) ia = dpnp.array(a) @@ -1416,17 +1443,11 @@ def test_overflow(self, a): expected = numpy.trim_zeros(a) assert_array_equal(result, expected) - # TODO: modify once SAT-7616 - # numpy 2.2 validates trim rules - @testing.with_requires("numpy<2.2") - def test_trim_no_rule(self): - a = numpy.array([0, 0, 1, 0, 2, 3, 4, 0]) - ia = dpnp.array(a) - trim = "ADE" # no "F" or "B" in trim string - - result = dpnp.trim_zeros(ia, trim) - expected = numpy.trim_zeros(a, trim) - assert_array_equal(result, expected) + @testing.with_requires("numpy>=2.2") + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_trim_no_fb_in_rule(self, xp): + a = xp.array([0, 0, 1, 0, 2, 3, 4, 0]) + assert_raises(ValueError, xp.trim_zeros, a, "ADE") def test_list_array(self): assert_raises(TypeError, dpnp.trim_zeros, [0, 0, 1, 0, 2, 3, 4, 0]) diff --git a/dpnp/tests/third_party/cupy/manipulation_tests/test_add_remove.py b/dpnp/tests/third_party/cupy/manipulation_tests/test_add_remove.py index bcbb74806838..134001450ed5 100644 --- a/dpnp/tests/third_party/cupy/manipulation_tests/test_add_remove.py +++ b/dpnp/tests/third_party/cupy/manipulation_tests/test_add_remove.py @@ -387,8 +387,7 @@ def test_trim_back_zeros(self, xp, dtype): a = xp.array([1, 0, 2, 3, 0, 5, 0, 0, 0], dtype=dtype) return xp.trim_zeros(a, trim=self.trim) - # TODO: remove once SAT-7616 - @testing.with_requires("numpy<2.2") + @pytest.mark.skip("0-d array is supported") @testing.for_all_dtypes() def test_trim_zero_dim(self, dtype): for xp in (numpy, cupy): @@ -396,8 +395,7 @@ def test_trim_zero_dim(self, dtype): with pytest.raises(TypeError): xp.trim_zeros(a, trim=self.trim) - # TODO: remove once SAT-7616 - @testing.with_requires("numpy<2.2") + @pytest.mark.skip("nd array is supported") @testing.for_all_dtypes() def test_trim_ndim(self, dtype): for xp in (numpy, cupy):