Skip to content

Commit 53ebda1

Browse files
committed
ENH: Support trimming nd-arrays that are all-zero
1 parent c75168b commit 53ebda1

File tree

2 files changed

+26
-17
lines changed

2 files changed

+26
-17
lines changed

numpy/lib/function_base.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1557,21 +1557,18 @@ def sort_complex(a):
15571557
return b
15581558

15591559

1560-
def _arg_trim_zeros(filt, trim=None):
1560+
def _arg_trim_zeros(filt):
15611561
return (filt, filt)
15621562

15631563

15641564
@array_function_dispatch(_arg_trim_zeros)
1565-
def arg_trim_zeros(filt, trim='fb'):
1565+
def arg_trim_zeros(filt):
15661566
"""Return indices of the first and last non-zero element.
15671567
15681568
Parameters
15691569
----------
15701570
filt : array_like
15711571
Input array.
1572-
trim : str, optional
1573-
A string with 'f' representing trim from front and 'b' to trim from
1574-
back. By default, zeros are trimmed from the front and back.
15751572
15761573
Returns
15771574
-------
@@ -1583,19 +1580,12 @@ def arg_trim_zeros(filt, trim='fb'):
15831580
--------
15841581
trim_zeros
15851582
"""
1586-
filt = np.asarray(filt)
1587-
trim = trim.lower()
1588-
15891583
nonzero = np.argwhere(filt)
15901584
if nonzero.size == 0:
1591-
if trim.startswith('b'):
1592-
start = stop = np.zeros(filt.ndim, dtype=np.intp)
1593-
else:
1594-
start = stop = np.array(filt.shape, dtype=np.intp)
1585+
start = stop = nonzero
15951586
else:
15961587
start = nonzero.min(axis=0)
15971588
stop = nonzero.max(axis=0)
1598-
15991589
return start, stop
16001590

16011591

@@ -1646,18 +1636,31 @@ def trim_zeros(filt, trim='fb', axis=-1):
16461636
[1, 2]
16471637
16481638
"""
1649-
start, stop = arg_trim_zeros(filt, trim)
1639+
start, stop = arg_trim_zeros(filt)
16501640
stop += 1 # Adjust for slicing
1641+
ndim = start.shape[-1]
1642+
1643+
if start.size == 0:
1644+
# filt is all-zero -> assign same values to start and stop so that
1645+
# resulting slice will be empty
1646+
start = stop = np.zeros(ndim, dtype=np.intp)
1647+
else:
1648+
trim = trim.lower()
1649+
if 'f' not in trim:
1650+
start = (None,) * ndim
1651+
if 'b' not in trim:
1652+
stop = (None,) * ndim
16511653

16521654
if start.size == 1:
1653-
# filt is 1D -> use multi-dimensional slicing only when necessary,
1654-
# this allows preservation of the non-array input types
1655+
# filt is 1D -> don't use multi-dimensional slicing to preserve
1656+
# non-array input types
16551657
sl = slice(start[0], stop[0])
16561658
elif axis is None:
16571659
# trim all axes
16581660
sl = tuple(slice(*x) for x in zip(start, stop))
16591661
else:
1660-
# only trim given axis
1662+
# only trim single axis
1663+
axis = normalize_axis_index(axis, ndim)
16611664
sl = (slice(None),) * axis + (slice(start[axis], stop[axis]),) + (...,)
16621665

16631666
return filt[sl]

numpy/lib/tests/test_function_base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,12 @@ def test_nd_basic(self, ndim):
11401140
res = trim_zeros(b, axis=None)
11411141
assert_array_equal(a, res)
11421142

1143+
@pytest.mark.parametrize("ndim", (0, 1, 2, 3))
1144+
def test_allzero(self, ndim):
1145+
a = np.zeros((3,) * ndim)
1146+
res = trim_zeros(a, axis=None)
1147+
assert_array_equal(res, np.zeros((0,) * ndim))
1148+
11431149

11441150
class TestExtins(object):
11451151

0 commit comments

Comments
 (0)