@@ -3900,25 +3900,40 @@ def transpose(a, axes=None):
39003900permute_dims = transpose # permute_dims is an alias for transpose
39013901
39023902
3903- def trim_zeros (filt , trim = "fb" ):
3903+ def trim_zeros (filt , trim = "fb" , axis = None ):
39043904 """
3905- Trim the leading and/or trailing zeros from a 1-D array .
3905+ Remove values along a dimension which are zero along all other .
39063906
39073907 For full documentation refer to :obj:`numpy.trim_zeros`.
39083908
39093909 Parameters
39103910 ----------
39113911 filt : {dpnp.ndarray, usm_ndarray}
3912- Input 1-D array.
3913- trim : str, optional
3914- A string with 'f' representing trim from front and 'b' to trim from
3915- back. By defaults, trim zeros from both front and back of the array.
3912+ Input array.
3913+ trim : {"fb", "f", "b"}, optional
3914+ A string with `"f"` representing trim from front and `"b"` to trim from
3915+ back. By default, zeros are trimmed on both sides. Front and back refer
3916+ to the edges of a dimension, with "front" referring to the side with
3917+ the lowest index 0, and "back" referring to the highest index
3918+ (or index -1).
39163919 Default: ``"fb"``.
3920+ axis : {None, int}, optional
3921+ If ``None``, `filt` is cropped such, that the smallest bounding box is
3922+ returned that still contains all values which are not zero.
3923+ If an `axis` is specified, `filt` will be sliced in that dimension only
3924+ on the sides specified by `trim`. The remaining area will be the
3925+ smallest that still contains all values which are not zero.
3926+ Default: ``None``.
39173927
39183928 Returns
39193929 -------
39203930 out : dpnp.ndarray
3921- The result of trimming the input.
3931+ The result of trimming the input. The number of dimensions and the
3932+ input data type are preserved.
3933+
3934+ Notes
3935+ -----
3936+ For all-zero arrays, the first axis is trimmed first.
39223937
39233938 Examples
39243939 --------
@@ -3927,42 +3942,66 @@ def trim_zeros(filt, trim="fb"):
39273942 >>> np.trim_zeros(a)
39283943 array([1, 2, 3, 0, 2, 1])
39293944
3930- >>> np.trim_zeros(a, 'b')
3945+ >>> np.trim_zeros(a, trim= 'b')
39313946 array([0, 0, 0, 1, 2, 3, 0, 2, 1])
39323947
3948+ Multiple dimensions are supported:
3949+
3950+ >>> b = np.array([[0, 0, 2, 3, 0, 0],
3951+ ... [0, 1, 0, 3, 0, 0],
3952+ ... [0, 0, 0, 0, 0, 0]])
3953+ >>> np.trim_zeros(b)
3954+ array([[0, 2, 3],
3955+ [1, 0, 3]])
3956+
3957+ >>> np.trim_zeros(b, axis=-1)
3958+ array([[0, 2, 3],
3959+ [1, 0, 3],
3960+ [0, 0, 0]])
3961+
39333962 """
39343963
39353964 dpnp .check_supported_arrays_type (filt )
3936- if filt .ndim == 0 :
3937- raise TypeError ("0-d array cannot be trimmed" )
3938- if filt .ndim > 1 :
3939- raise ValueError ("Multi-dimensional trim is not supported" )
39403965
39413966 if not isinstance (trim , str ):
39423967 raise TypeError ("only string trim is supported" )
39433968
3944- trim = trim .upper ()
3945- if not any (x in trim for x in "FB" ):
3946- return filt # no trim rule is specified
3969+ trim = trim .lower ()
3970+ if trim not in ["fb" , "bf" , "f" , "b" ]:
3971+ raise ValueError (f"unexpected character(s) in `trim`: { trim !r} " )
3972+
3973+ nd = filt .ndim
3974+ if axis is not None :
3975+ axis = normalize_axis_index (axis , nd )
39473976
39483977 if filt .size == 0 :
39493978 return filt # no trailing zeros in empty array
39503979
3951- a = dpnp .nonzero (filt )[0 ]
3952- a_size = a .size
3953- if a_size == 0 :
3954- # 'filt' is array of zeros
3955- return dpnp .empty_like (filt , shape = (0 ,))
3980+ non_zero = dpnp .argwhere (filt )
3981+ if non_zero .size == 0 :
3982+ # `filt` has all zeros, so assign `start` and `stop` to the same value,
3983+ # then the resulting slice will be empty
3984+ start = stop = dpnp .zeros_like (filt , shape = nd , dtype = dpnp .intp )
3985+ else :
3986+ if "f" in trim :
3987+ start = non_zero .min (axis = 0 )
3988+ else :
3989+ start = (None ,) * nd
39563990
3957- first = 0
3958- if "F" in trim :
3959- first = a [0 ]
3991+ if "b" in trim :
3992+ stop = non_zero .max (axis = 0 )
3993+ stop += 1 # Adjust for slicing
3994+ else :
3995+ stop = (None ,) * nd
39603996
3961- last = filt .size
3962- if "B" in trim :
3963- last = a [- 1 ] + 1
3997+ if axis is None :
3998+ # trim all axes
3999+ sl = tuple (slice (* x ) for x in zip (start , stop ))
4000+ else :
4001+ # only trim single axis
4002+ sl = (slice (None ),) * axis + (slice (start [axis ], stop [axis ]),) + (...,)
39644003
3965- return filt [first : last ]
4004+ return filt [sl ]
39664005
39674006
39684007def unique (
0 commit comments