Skip to content

Commit a1d8b17

Browse files
authored
Merge pull request numpy#15181 from lagru/trim_zeros
ENH: Add nd-support to trim_zeros
2 parents 3126b97 + 177eceb commit a1d8b17

File tree

3 files changed

+138
-27
lines changed

3 files changed

+138
-27
lines changed

numpy/lib/_function_base_impl.py

Lines changed: 109 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1843,28 +1843,79 @@ def sort_complex(a):
18431843
return b
18441844

18451845

1846-
def _trim_zeros(filt, trim=None):
1846+
def _arg_trim_zeros(filt):
1847+
"""Return indices of the first and last non-zero element.
1848+
1849+
Parameters
1850+
----------
1851+
filt : array_like
1852+
Input array.
1853+
1854+
Returns
1855+
-------
1856+
start, stop : ndarray
1857+
Two arrays containing the indices of the first and last non-zero
1858+
element in each dimension.
1859+
1860+
See also
1861+
--------
1862+
trim_zeros
1863+
1864+
Examples
1865+
--------
1866+
>>> import numpy as np
1867+
>>> _arg_trim_zeros(np.array([0, 0, 1, 1, 0]))
1868+
(array([2]), array([3]))
1869+
"""
1870+
nonzero = (
1871+
np.argwhere(filt)
1872+
if filt.dtype != np.object_
1873+
# Historically, `trim_zeros` treats `None` in an object array
1874+
# as non-zero while argwhere doesn't, account for that
1875+
else np.argwhere(filt != 0)
1876+
)
1877+
if nonzero.size == 0:
1878+
start = stop = np.array([], dtype=np.intp)
1879+
else:
1880+
start = nonzero.min(axis=0)
1881+
stop = nonzero.max(axis=0)
1882+
return start, stop
1883+
1884+
1885+
def _trim_zeros(filt, trim=None, axis=None):
18471886
return (filt,)
18481887

18491888

18501889
@array_function_dispatch(_trim_zeros)
1851-
def trim_zeros(filt, trim='fb'):
1852-
"""
1853-
Trim the leading and/or trailing zeros from a 1-D array or sequence.
1890+
def trim_zeros(filt, trim='fb', axis=None):
1891+
"""Remove values along a dimension which are zero along all other.
18541892
18551893
Parameters
18561894
----------
1857-
filt : 1-D array or sequence
1895+
filt : array_like
18581896
Input array.
1859-
trim : str, optional
1897+
trim : {"fb", "f", "b"}, optional
18601898
A string with 'f' representing trim from front and 'b' to trim from
1861-
back. Default is 'fb', trim zeros from both front and back of the
1862-
array.
1899+
back. By default, zeros are trimmed on both sides.
1900+
Front and back refer to the edges of a dimension, with "front" referring
1901+
to the side with the lowest index 0, and "back" referring to the highest
1902+
index (or index -1).
1903+
axis : int or sequence, optional
1904+
If None, `filt` is cropped such that the smallest bounding box is
1905+
returned that still contains all values which are not zero.
1906+
If an axis is specified, `filt` will be sliced in that dimension only
1907+
on the sides specified by `trim`. The remaining area will be the
1908+
smallest that still contains all values which are not zero.
18631909
18641910
Returns
18651911
-------
1866-
trimmed : 1-D array or sequence
1867-
The result of trimming the input. The input data type is preserved.
1912+
trimmed : ndarray or sequence
1913+
The result of trimming the input. The number of dimensions and the
1914+
input data type are preserved.
1915+
1916+
Notes
1917+
-----
1918+
For all-zero arrays, the first axis is trimmed first.
18681919
18691920
Examples
18701921
--------
@@ -1873,32 +1924,63 @@ def trim_zeros(filt, trim='fb'):
18731924
>>> np.trim_zeros(a)
18741925
array([1, 2, 3, 0, 2, 1])
18751926
1876-
>>> np.trim_zeros(a, 'b')
1927+
>>> np.trim_zeros(a, trim='b')
18771928
array([0, 0, 0, ..., 0, 2, 1])
18781929
1930+
Multiple dimensions are supported.
1931+
1932+
>>> b = np.array([[0, 0, 2, 3, 0, 0],
1933+
... [0, 1, 0, 3, 0, 0],
1934+
... [0, 0, 0, 0, 0, 0]])
1935+
>>> np.trim_zeros(b)
1936+
array([[0, 2, 3],
1937+
[1, 0, 3]])
1938+
1939+
>>> np.trim_zeros(b, axis=-1)
1940+
array([[0, 2, 3],
1941+
[1, 0, 3],
1942+
[0, 0, 0]])
1943+
18791944
The input data type is preserved, list/tuple in means list/tuple out.
18801945
18811946
>>> np.trim_zeros([0, 1, 2, 0])
18821947
[1, 2]
18831948
18841949
"""
1950+
filt_ = np.asarray(filt)
1951+
1952+
trim = trim.lower()
1953+
if trim not in {"fb", "bf", "f", "b"}:
1954+
raise ValueError(f"unexpected character(s) in `trim`: {trim!r}")
1955+
1956+
start, stop = _arg_trim_zeros(filt_)
1957+
stop += 1 # Adjust for slicing
1958+
1959+
if start.size == 0:
1960+
# filt is all-zero -> assign same values to start and stop so that
1961+
# resulting slice will be empty
1962+
start = stop = np.zeros(filt_.ndim, dtype=np.intp)
1963+
else:
1964+
if 'f' not in trim:
1965+
start = (None,) * filt_.ndim
1966+
if 'b' not in trim:
1967+
stop = (None,) * filt_.ndim
1968+
1969+
if len(start) == 1:
1970+
# filt is 1D -> don't use multi-dimensional slicing to preserve
1971+
# non-array input types
1972+
sl = slice(start[0], stop[0])
1973+
elif axis is None:
1974+
# trim all axes
1975+
sl = tuple(slice(*x) for x in zip(start, stop))
1976+
else:
1977+
# only trim single axis
1978+
axis = normalize_axis_index(axis, filt_.ndim)
1979+
sl = (slice(None),) * axis + (slice(start[axis], stop[axis]),) + (...,)
1980+
1981+
trimmed = filt[sl]
1982+
return trimmed
18851983

1886-
first = 0
1887-
trim = trim.upper()
1888-
if 'F' in trim:
1889-
for i in filt:
1890-
if i != 0.:
1891-
break
1892-
else:
1893-
first = first + 1
1894-
last = len(filt)
1895-
if 'B' in trim:
1896-
for i in filt[::-1]:
1897-
if i != 0.:
1898-
break
1899-
else:
1900-
last = last - 1
1901-
return filt[first:last]
19021984

19031985

19041986
def _extract_dispatcher(condition, arr):

numpy/lib/_function_base_impl.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ def sort_complex(a: ArrayLike) -> NDArray[complexfloating[Any, Any]]: ...
356356
def trim_zeros(
357357
filt: _TrimZerosSequence[_T],
358358
trim: L["f", "b", "fb", "bf"] = ...,
359+
axis: SupportsIndex = ...,
359360
) -> _T: ...
360361

361362
@overload

numpy/lib/tests/test_function_base.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1393,6 +1393,34 @@ def test_list_to_list(self):
13931393
res = trim_zeros(self.a.tolist())
13941394
assert isinstance(res, list)
13951395

1396+
@pytest.mark.parametrize("ndim", (0, 1, 2, 3, 10))
1397+
def test_nd_basic(self, ndim):
1398+
a = np.ones((2,) * ndim)
1399+
b = np.pad(a, (2, 1), mode="constant", constant_values=0)
1400+
res = trim_zeros(b, axis=None)
1401+
assert_array_equal(a, res)
1402+
1403+
@pytest.mark.parametrize("ndim", (0, 1, 2, 3))
1404+
def test_allzero(self, ndim):
1405+
a = np.zeros((3,) * ndim)
1406+
res = trim_zeros(a, axis=None)
1407+
assert_array_equal(res, np.zeros((0,) * ndim))
1408+
1409+
def test_trim_arg(self):
1410+
a = np.array([0, 1, 2, 0])
1411+
1412+
res = trim_zeros(a, trim='f')
1413+
assert_array_equal(res, [1, 2, 0])
1414+
1415+
res = trim_zeros(a, trim='b')
1416+
assert_array_equal(res, [0, 1, 2])
1417+
1418+
@pytest.mark.parametrize("trim", ("front", ""))
1419+
def test_unexpected_trim_value(self, trim):
1420+
arr = self.a
1421+
with pytest.raises(ValueError, match=r"unexpected character\(s\) in `trim`"):
1422+
trim_zeros(arr, trim=trim)
1423+
13961424

13971425
class TestExtins:
13981426

0 commit comments

Comments
 (0)