Skip to content

Commit 935b1c8

Browse files
committed
MAINT: Address empty and all-zero input
in trim_zeros.
1 parent df69d78 commit 935b1c8

File tree

1 file changed

+26
-22
lines changed

1 file changed

+26
-22
lines changed

numpy/lib/function_base.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1624,46 +1624,45 @@ def trim_zeros(
16241624
"""
16251625
trim = trim.lower()
16261626

1627+
absolutes = np.abs(filt)
1628+
lengths = np.zeros((absolutes.ndim, 2), dtype=np.intp)
1629+
16271630
if axis is None:
16281631
# Apply iteratively to all axes
1629-
axis = range(filt.ndim)
1630-
1632+
axis = range(absolutes.ndim)
16311633
# Normalize axes to 1D-array
16321634
axis = np.asarray(axis, dtype=np.intp)
16331635
if axis.ndim == 0:
16341636
axis = np.asarray([axis], dtype=np.intp)
16351637

1636-
absolutes = np.abs(filt)
1637-
lengths = np.zeros((absolutes.ndim, 2), dtype=np.intp)
1638+
if atol > 0:
1639+
absolutes[absolutes <= atol] = 0
1640+
if rtol > 0:
1641+
absolutes[absolutes <= rtol * absolutes.max()] = 0
1642+
1643+
nonzero = np.nonzero(absolutes)
16381644

16391645
for current_axis in axis:
16401646
absolutes.take([], current_axis) # Raises if axis is out of bounds
16411647
if current_axis < 0:
16421648
current_axis += absolutes.ndim
16431649

1644-
# Reduce to envelope along all axes except the selected one
1645-
reduced = np.moveaxis(absolutes, current_axis, -1)
1646-
for _ in range(absolutes.ndim - 1):
1647-
reduced = reduced.max(axis=0)
1648-
assert reduced.ndim == 1
1649-
1650-
if atol > 0:
1651-
reduced[reduced <= atol] = 0
1652-
if rtol > 0:
1653-
reduced[reduced <= rtol * reduced.max()] = 0
1654-
1655-
# Find start and stop indices for current dimension
1656-
start, stop = np.nonzero(reduced)[0][[0, -1]]
1657-
stop += 1
1650+
if nonzero[current_axis].size > 0:
1651+
start = nonzero[current_axis].min()
1652+
stop = nonzero[current_axis].max()
1653+
stop += 1
1654+
else:
1655+
# In case the input is all-zero, slice only in front
1656+
start = stop = absolutes.shape[current_axis]
1657+
if "f" not in trim:
1658+
# except when only the backside is to be sliced
1659+
stop = 0
16581660

1661+
# Only slice on specified side(s)
16591662
if "f" not in trim:
16601663
start = None
1661-
else:
1662-
lengths[current_axis, 0] = start
16631664
if "b" not in trim:
16641665
stop = None
1665-
else:
1666-
lengths[current_axis, 1] = absolutes.shape[current_axis] - stop
16671666

16681667
# Use multi-dimensional slicing only when necessary, this allows
16691668
# preservation of the non-arrays input types
@@ -1673,6 +1672,11 @@ def trim_zeros(
16731672

16741673
filt = filt[sl]
16751674

1675+
if start is not None:
1676+
lengths[current_axis, 0] = start
1677+
if stop is not None:
1678+
lengths[current_axis, 1] = absolutes.shape[current_axis] - stop
1679+
16761680
if return_lengths is True:
16771681
return filt, lengths
16781682
else:

0 commit comments

Comments
 (0)