Skip to content

Commit 5883cb4

Browse files
committed
Add nd-support to trim_zeros
1 parent 8f05542 commit 5883cb4

File tree

1 file changed

+66
-27
lines changed

1 file changed

+66
-27
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 66 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3900,25 +3900,40 @@ def transpose(a, axes=None):
39003900
permute_dims = transpose # permute_dims is an alias for transpose
39013901

39023902

3903-
def trim_zeros(filt, trim="fb"):
3903+
def trim_zeros(filt, trim="fb", axis=None):
39043904
"""
3905-
Trim the leading and/or trailing zeros from a 1-D array.
3905+
Remove values along a dimension which are zero along all other.
39063906
39073907
For full documentation refer to :obj:`numpy.trim_zeros`.
39083908
39093909
Parameters
39103910
----------
39113911
filt : {dpnp.ndarray, usm_ndarray}
3912-
Input 1-D array.
3913-
trim : str, optional
3914-
A string with 'f' representing trim from front and 'b' to trim from
3915-
back. By defaults, trim zeros from both front and back of the array.
3912+
Input array.
3913+
trim : {"fb", "f", "b"}, optional
3914+
A string with `"f"` representing trim from front and `"b"` to trim from
3915+
back. By default, zeros are trimmed on both sides. Front and back refer
3916+
to the edges of a dimension, with "front" referring to the side with
3917+
the lowest index 0, and "back" referring to the highest index
3918+
(or index -1).
39163919
Default: ``"fb"``.
3920+
axis : {None, int}, optional
3921+
If ``None``, `filt` is cropped such, that the smallest bounding box is
3922+
returned that still contains all values which are not zero.
3923+
If an `axis` is specified, `filt` will be sliced in that dimension only
3924+
on the sides specified by `trim`. The remaining area will be the
3925+
smallest that still contains all values which are not zero.
3926+
Default: ``None``.
39173927
39183928
Returns
39193929
-------
39203930
out : dpnp.ndarray
3921-
The result of trimming the input.
3931+
The result of trimming the input. The number of dimensions and the
3932+
input data type are preserved.
3933+
3934+
Notes
3935+
-----
3936+
For all-zero arrays, the first axis is trimmed first.
39223937
39233938
Examples
39243939
--------
@@ -3927,42 +3942,66 @@ def trim_zeros(filt, trim="fb"):
39273942
>>> np.trim_zeros(a)
39283943
array([1, 2, 3, 0, 2, 1])
39293944
3930-
>>> np.trim_zeros(a, 'b')
3945+
>>> np.trim_zeros(a, trim='b')
39313946
array([0, 0, 0, 1, 2, 3, 0, 2, 1])
39323947
3948+
Multiple dimensions are supported:
3949+
3950+
>>> b = np.array([[0, 0, 2, 3, 0, 0],
3951+
... [0, 1, 0, 3, 0, 0],
3952+
... [0, 0, 0, 0, 0, 0]])
3953+
>>> np.trim_zeros(b)
3954+
array([[0, 2, 3],
3955+
[1, 0, 3]])
3956+
3957+
>>> np.trim_zeros(b, axis=-1)
3958+
array([[0, 2, 3],
3959+
[1, 0, 3],
3960+
[0, 0, 0]])
3961+
39333962
"""
39343963

39353964
dpnp.check_supported_arrays_type(filt)
3936-
if filt.ndim == 0:
3937-
raise TypeError("0-d array cannot be trimmed")
3938-
if filt.ndim > 1:
3939-
raise ValueError("Multi-dimensional trim is not supported")
39403965

39413966
if not isinstance(trim, str):
39423967
raise TypeError("only string trim is supported")
39433968

3944-
trim = trim.upper()
3945-
if not any(x in trim for x in "FB"):
3946-
return filt # no trim rule is specified
3969+
trim = trim.lower()
3970+
if trim not in ["fb", "bf", "f", "b"]:
3971+
raise ValueError(f"unexpected character(s) in `trim`: {trim!r}")
3972+
3973+
nd = filt.ndim
3974+
if axis is not None:
3975+
axis = normalize_axis_index(axis, nd)
39473976

39483977
if filt.size == 0:
39493978
return filt # no trailing zeros in empty array
39503979

3951-
a = dpnp.nonzero(filt)[0]
3952-
a_size = a.size
3953-
if a_size == 0:
3954-
# 'filt' is array of zeros
3955-
return dpnp.empty_like(filt, shape=(0,))
3980+
non_zero = dpnp.argwhere(filt)
3981+
if non_zero.size == 0:
3982+
# `filt` has all zeros, so assign `start` and `stop` to the same value,
3983+
# then the resulting slice will be empty
3984+
start = stop = dpnp.zeros_like(filt, shape=nd, dtype=dpnp.intp)
3985+
else:
3986+
if "f" in trim:
3987+
start = non_zero.min(axis=0)
3988+
else:
3989+
start = (None,) * nd
39563990

3957-
first = 0
3958-
if "F" in trim:
3959-
first = a[0]
3991+
if "b" in trim:
3992+
stop = non_zero.max(axis=0)
3993+
stop += 1 # Adjust for slicing
3994+
else:
3995+
stop = (None,) * nd
39603996

3961-
last = filt.size
3962-
if "B" in trim:
3963-
last = a[-1] + 1
3997+
if axis is None:
3998+
# trim all axes
3999+
sl = tuple(slice(*x) for x in zip(start, stop))
4000+
else:
4001+
# only trim single axis
4002+
sl = (slice(None),) * axis + (slice(start[axis], stop[axis]),) + (...,)
39644003

3965-
return filt[first:last]
4004+
return filt[sl]
39664005

39674006

39684007
def unique(

0 commit comments

Comments
 (0)