Skip to content

Commit df69d78

Browse files
committed
ENH: Add nd-support to trim_zeros
Add support for trimming nd-arrays with trim_zeros while preserving the old behavior for 1D input. The new parameter `axis` can specify a single dimension to be trimmed (reducing all other dimensions to the envelope of absolute values). If None or multiple values are specified, all or the selected dimensions are trimmed iteratively. This should make the function applicable to more use cases. Additionally provide the `atol`, `rtol` and `return_lengths` parameters. The first two control what is considered a "zero" to be trimmed, the latter provides the user with the on how much was trimmed.
1 parent 8bc83b5 commit df69d78

File tree

1 file changed

+91
-24
lines changed

1 file changed

+91
-24
lines changed

numpy/lib/function_base.py

Lines changed: 91 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1557,28 +1557,55 @@ def sort_complex(a):
15571557
return b
15581558

15591559

1560-
def _trim_zeros(filt, trim=None):
1560+
def _trim_zeros(
1561+
filt,
1562+
trim=None,
1563+
axis=None,
1564+
*,
1565+
atol=None,
1566+
rtol=None,
1567+
return_lengths=None
1568+
):
15611569
return (filt,)
15621570

15631571

15641572
@array_function_dispatch(_trim_zeros)
1565-
def trim_zeros(filt, trim='fb'):
1566-
"""
1567-
Trim the leading and/or trailing zeros from a 1-D array or sequence.
1573+
def trim_zeros(
1574+
filt,
1575+
trim='fb',
1576+
axis=-1,
1577+
*,
1578+
atol=0,
1579+
rtol=0,
1580+
return_lengths=False
1581+
):
1582+
"""Remove values along a dimension which are zero along all other.
15681583
15691584
Parameters
15701585
----------
1571-
filt : 1-D array or sequence
1586+
filt : array_like
15721587
Input array.
15731588
trim : str, optional
15741589
A string with 'f' representing trim from front and 'b' to trim from
1575-
back. Default is 'fb', trim zeros from both front and back of the
1576-
array.
1590+
back. By default, zeros are trimmed from the front and back.
1591+
axis : int or sequence, optional
1592+
The axis or a sequence of axes to trim. If None all axes are trimmed.
1593+
atol : float, optional
1594+
Absolute tolerance with which a value is considered for trimming.
1595+
rtol : float, optional
1596+
Relative tolerance with which a value is considered for trimming.
1597+
return_lengths : bool, optional
1598+
Additionally return the number of trimmed samples in each dimension at
1599+
the front and back.
15771600
15781601
Returns
15791602
-------
1580-
trimmed : 1-D array or sequence
1603+
trimmed : ndarray or sequence
15811604
The result of trimming the input. The input data type is preserved.
1605+
lengths : ndarray
1606+
If `return_lengths` was True, an array of shape (``filt.ndim``, 2) is
1607+
returned. It contains the number of trimmed samples in each dimension
1608+
at the front and back.
15821609
15831610
Examples
15841611
--------
@@ -1595,22 +1622,62 @@ def trim_zeros(filt, trim='fb'):
15951622
[1, 2]
15961623
15971624
"""
1598-
first = 0
1599-
trim = trim.upper()
1600-
if 'F' in trim:
1601-
for i in filt:
1602-
if i != 0.:
1603-
break
1604-
else:
1605-
first = first + 1
1606-
last = len(filt)
1607-
if 'B' in trim:
1608-
for i in filt[::-1]:
1609-
if i != 0.:
1610-
break
1611-
else:
1612-
last = last - 1
1613-
return filt[first:last]
1625+
trim = trim.lower()
1626+
1627+
if axis is None:
1628+
# Apply iteratively to all axes
1629+
axis = range(filt.ndim)
1630+
1631+
# Normalize axes to 1D-array
1632+
axis = np.asarray(axis, dtype=np.intp)
1633+
if axis.ndim == 0:
1634+
axis = np.asarray([axis], dtype=np.intp)
1635+
1636+
absolutes = np.abs(filt)
1637+
lengths = np.zeros((absolutes.ndim, 2), dtype=np.intp)
1638+
1639+
for current_axis in axis:
1640+
absolutes.take([], current_axis) # Raises if axis is out of bounds
1641+
if current_axis < 0:
1642+
current_axis += absolutes.ndim
1643+
1644+
# Reduce to envelope along all axes except the selected one
1645+
reduced = np.moveaxis(absolutes, current_axis, -1)
1646+
for _ in range(absolutes.ndim - 1):
1647+
reduced = reduced.max(axis=0)
1648+
assert reduced.ndim == 1
1649+
1650+
if atol > 0:
1651+
reduced[reduced <= atol] = 0
1652+
if rtol > 0:
1653+
reduced[reduced <= rtol * reduced.max()] = 0
1654+
1655+
# Find start and stop indices for current dimension
1656+
start, stop = np.nonzero(reduced)[0][[0, -1]]
1657+
stop += 1
1658+
1659+
if "f" not in trim:
1660+
start = None
1661+
else:
1662+
lengths[current_axis, 0] = start
1663+
if "b" not in trim:
1664+
stop = None
1665+
else:
1666+
lengths[current_axis, 1] = absolutes.shape[current_axis] - stop
1667+
1668+
# Use multi-dimensional slicing only when necessary, this allows
1669+
# preservation of the non-arrays input types
1670+
sl = slice(start, stop)
1671+
if current_axis != 0:
1672+
sl = (slice(None),) * current_axis + (sl,) + (...,)
1673+
1674+
filt = filt[sl]
1675+
1676+
if return_lengths is True:
1677+
return filt, lengths
1678+
else:
1679+
return filt
1680+
16141681

16151682
def _extract_dispatcher(condition, arr):
16161683
return (condition, arr)

0 commit comments

Comments
 (0)