Skip to content

Commit fa9a0a7

Browse files
committed
MAINT: Make arg_trim_zeros output consistent
Ensure that the returned start and stop arrays of indices have exactly one dimension.
1 parent aa422b7 commit fa9a0a7

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

numpy/lib/function_base.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,7 +1582,7 @@ def arg_trim_zeros(filt):
15821582
"""
15831583
nonzero = np.argwhere(filt)
15841584
if nonzero.size == 0:
1585-
start = stop = nonzero
1585+
start = stop = np.array([], dtype=np.intp)
15861586
else:
15871587
start = nonzero.min(axis=0)
15881588
stop = nonzero.max(axis=0)
@@ -1636,20 +1636,20 @@ def trim_zeros(filt, trim='fb', axis=-1):
16361636
[1, 2]
16371637
16381638
"""
1639-
start, stop = arg_trim_zeros(filt)
1639+
filt_ = np.asarray(filt)
1640+
start, stop = arg_trim_zeros(filt_)
16401641
stop += 1 # Adjust for slicing
1641-
ndim = start.shape[-1]
16421642

16431643
if start.size == 0:
16441644
# filt is all-zero -> assign same values to start and stop so that
16451645
# resulting slice will be empty
1646-
start = stop = np.zeros(ndim, dtype=np.intp)
1646+
start = stop = np.zeros(filt_.ndim, dtype=np.intp)
16471647
else:
16481648
trim = trim.lower()
16491649
if 'f' not in trim:
1650-
start = (None,) * ndim
1650+
start = (None,) * filt_.ndim
16511651
if 'b' not in trim:
1652-
stop = (None,) * ndim
1652+
stop = (None,) * filt_.ndim
16531653

16541654
if len(start) == 1:
16551655
# filt is 1D -> don't use multi-dimensional slicing to preserve
@@ -1660,7 +1660,7 @@ def trim_zeros(filt, trim='fb', axis=-1):
16601660
sl = tuple(slice(*x) for x in zip(start, stop))
16611661
else:
16621662
# only trim single axis
1663-
axis = normalize_axis_index(axis, ndim)
1663+
axis = normalize_axis_index(axis, filt_.ndim)
16641664
sl = (slice(None),) * axis + (slice(start[axis], stop[axis]),) + (...,)
16651665

16661666
return filt[sl]

0 commit comments

Comments
 (0)