@@ -1582,7 +1582,7 @@ def arg_trim_zeros(filt):
1582
1582
"""
1583
1583
nonzero = np .argwhere (filt )
1584
1584
if nonzero .size == 0 :
1585
- start = stop = nonzero
1585
+ start = stop = np . array ([], dtype = np . intp )
1586
1586
else :
1587
1587
start = nonzero .min (axis = 0 )
1588
1588
stop = nonzero .max (axis = 0 )
@@ -1636,20 +1636,20 @@ def trim_zeros(filt, trim='fb', axis=-1):
1636
1636
[1, 2]
1637
1637
1638
1638
"""
1639
- start , stop = arg_trim_zeros (filt )
1639
+ filt_ = np .asarray (filt )
1640
+ start , stop = arg_trim_zeros (filt_ )
1640
1641
stop += 1 # Adjust for slicing
1641
- ndim = start .shape [- 1 ]
1642
1642
1643
1643
if start .size == 0 :
1644
1644
# filt is all-zero -> assign same values to start and stop so that
1645
1645
# resulting slice will be empty
1646
- start = stop = np .zeros (ndim , dtype = np .intp )
1646
+ start = stop = np .zeros (filt_ . ndim , dtype = np .intp )
1647
1647
else :
1648
1648
trim = trim .lower ()
1649
1649
if 'f' not in trim :
1650
- start = (None ,) * ndim
1650
+ start = (None ,) * filt_ . ndim
1651
1651
if 'b' not in trim :
1652
- stop = (None ,) * ndim
1652
+ stop = (None ,) * filt_ . ndim
1653
1653
1654
1654
if len (start ) == 1 :
1655
1655
# filt is 1D -> don't use multi-dimensional slicing to preserve
@@ -1660,7 +1660,7 @@ def trim_zeros(filt, trim='fb', axis=-1):
1660
1660
sl = tuple (slice (* x ) for x in zip (start , stop ))
1661
1661
else :
1662
1662
# only trim single axis
1663
- axis = normalize_axis_index (axis , ndim )
1663
+ axis = normalize_axis_index (axis , filt_ . ndim )
1664
1664
sl = (slice (None ),) * axis + (slice (start [axis ], stop [axis ]),) + (...,)
1665
1665
1666
1666
return filt [sl ]
0 commit comments