Skip to content

Commit c75168b

Browse files
committed
ENH: Add arg_trim_zeros
as it's own function. trim_zeros uses its output to newly support the nd-case.
1 parent 5d6b9d1 commit c75168b

File tree

1 file changed

+65
-61
lines changed

1 file changed

+65
-61
lines changed

numpy/lib/function_base.py

Lines changed: 65 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1557,12 +1557,54 @@ def sort_complex(a):
15571557
return b
15581558

15591559

1560-
def _trim_zeros(filt, trim=None, axis=None, *, return_lengths=None):
1560+
def _arg_trim_zeros(filt, trim=None):
1561+
return (filt, filt)
1562+
1563+
1564+
@array_function_dispatch(_arg_trim_zeros)
1565+
def arg_trim_zeros(filt, trim='fb'):
1566+
"""Return indices of the first and last non-zero element.
1567+
1568+
Parameters
1569+
----------
1570+
filt : array_like
1571+
Input array.
1572+
trim : str, optional
1573+
A string with 'f' representing trim from front and 'b' to trim from
1574+
back. By default, zeros are trimmed from the front and back.
1575+
1576+
Returns
1577+
-------
1578+
start, stop : ndarray
1579+
Two arrays containing the indices of the first and last non-zero
1580+
element in each dimension.
1581+
1582+
See also
1583+
--------
1584+
trim_zeros
1585+
"""
1586+
filt = np.asarray(filt)
1587+
trim = trim.lower()
1588+
1589+
nonzero = np.argwhere(filt)
1590+
if nonzero.size == 0:
1591+
if trim.startswith('b'):
1592+
start = stop = np.zeros(filt.ndim, dtype=np.intp)
1593+
else:
1594+
start = stop = np.array(filt.shape, dtype=np.intp)
1595+
else:
1596+
start = nonzero.min(axis=0)
1597+
stop = nonzero.max(axis=0)
1598+
1599+
return start, stop
1600+
1601+
1602+
def _trim_zeros(filt, trim=None, axis=None):
15611603
return (filt,)
15621604

15631605

15641606
@array_function_dispatch(_trim_zeros)
1565-
def trim_zeros(filt, trim='fb', axis=-1, *, return_lengths=False):
1607+
def trim_zeros(filt, trim='fb', axis=-1):
15661608
"""Remove values along a dimension which are zero along all other.
15671609
15681610
Parameters
@@ -1573,23 +1615,21 @@ def trim_zeros(filt, trim='fb', axis=-1, *, return_lengths=False):
15731615
A string with 'f' representing trim from front and 'b' to trim from
15741616
back. By default, zeros are trimmed from the front and back.
15751617
axis : int or sequence, optional
1576-
The axis or a sequence of axes to trim. If None all axes are trimmed.
1577-
return_lengths : bool, optional
1578-
Additionally return the number of trimmed samples in each dimension at
1579-
the front and back.
1618+
The axis to trim. If None all axes are trimmed.
15801619
15811620
Returns
15821621
-------
15831622
trimmed : ndarray or sequence
15841623
The result of trimming the input. The input data type is preserved.
1585-
lengths : ndarray
1586-
If `return_lengths` was True, an array of shape (``filt.ndim``, 2) is
1587-
returned. It contains the number of trimmed samples in each dimension
1588-
at the front and back.
1624+
1625+
See also
1626+
--------
1627+
arg_trim_zeros
15891628
15901629
Notes
15911630
-----
1592-
For all-zero arrays, the first axis is trimmed first.
1631+
For all-zero arrays, the first axis is trimmed depending on the order in
1632+
`trim`.
15931633
15941634
Examples
15951635
--------
@@ -1606,57 +1646,21 @@ def trim_zeros(filt, trim='fb', axis=-1, *, return_lengths=False):
16061646
[1, 2]
16071647
16081648
"""
1609-
trim = trim.lower()
1610-
1611-
absolutes = np.abs(filt)
1612-
nonzero = np.nonzero(absolutes)
1613-
lengths = np.zeros((absolutes.ndim, 2), dtype=np.intp)
1614-
1615-
if axis is None:
1616-
# Apply iteratively to all axes
1617-
axis = range(absolutes.ndim)
1618-
# Normalize axes to 1D-array
1619-
axis = np.asarray(axis, dtype=np.intp)
1620-
if axis.ndim == 0:
1621-
axis = np.asarray([axis], dtype=np.intp)
1622-
1623-
for current_axis in axis:
1624-
current_axis = normalize_axis_index(current_axis, absolutes.ndim)
1625-
1626-
if nonzero[current_axis].size > 0:
1627-
start = nonzero[current_axis].min()
1628-
stop = nonzero[current_axis].max() + 1
1629-
else:
1630-
# In case the input is all-zero, slice depending on preference
1631-
# given by user
1632-
if trim.startswith("b"):
1633-
start = stop = 0
1634-
else:
1635-
start = stop = absolutes.shape[current_axis]
1636-
1637-
# Only slice on specified side(s)
1638-
if "f" not in trim:
1639-
start = None
1640-
if "b" not in trim:
1641-
stop = None
1642-
1643-
# Use multi-dimensional slicing only when necessary, this allows
1644-
# preservation of the non-array input types
1645-
sl = slice(start, stop)
1646-
if current_axis != 0:
1647-
sl = (slice(None),) * current_axis + (sl,) + (...,)
1648-
1649-
filt = filt[sl]
1650-
1651-
if start is not None:
1652-
lengths[current_axis, 0] = start
1653-
if stop is not None:
1654-
lengths[current_axis, 1] = absolutes.shape[current_axis] - stop
1655-
1656-
if return_lengths is True:
1657-
return filt, lengths
1649+
start, stop = arg_trim_zeros(filt, trim)
1650+
stop += 1 # Adjust for slicing
1651+
1652+
if start.size == 1:
1653+
# filt is 1D -> use multi-dimensional slicing only when necessary,
1654+
# this allows preservation of the non-array input types
1655+
sl = slice(start[0], stop[0])
1656+
elif axis is None:
1657+
# trim all axes
1658+
sl = tuple(slice(*x) for x in zip(start, stop))
16581659
else:
1659-
return filt
1660+
# only trim given axis
1661+
sl = (slice(None),) * axis + (slice(start[axis], stop[axis]),) + (...,)
1662+
1663+
return filt[sl]
16601664

16611665

16621666
def _extract_dispatcher(condition, arr):

0 commit comments

Comments
 (0)