@@ -1557,21 +1557,18 @@ def sort_complex(a):
1557
1557
return b
1558
1558
1559
1559
1560
- def _arg_trim_zeros (filt , trim = None ):
1560
+ def _arg_trim_zeros (filt ):
1561
1561
return (filt , filt )
1562
1562
1563
1563
1564
1564
@array_function_dispatch (_arg_trim_zeros )
1565
- def arg_trim_zeros (filt , trim = 'fb' ):
1565
+ def arg_trim_zeros (filt ):
1566
1566
"""Return indices of the first and last non-zero element.
1567
1567
1568
1568
Parameters
1569
1569
----------
1570
1570
filt : array_like
1571
1571
Input array.
1572
- trim : str, optional
1573
- A string with 'f' representing trim from front and 'b' to trim from
1574
- back. By default, zeros are trimmed from the front and back.
1575
1572
1576
1573
Returns
1577
1574
-------
@@ -1583,19 +1580,12 @@ def arg_trim_zeros(filt, trim='fb'):
1583
1580
--------
1584
1581
trim_zeros
1585
1582
"""
1586
- filt = np .asarray (filt )
1587
- trim = trim .lower ()
1588
-
1589
1583
nonzero = np .argwhere (filt )
1590
1584
if nonzero .size == 0 :
1591
- if trim .startswith ('b' ):
1592
- start = stop = np .zeros (filt .ndim , dtype = np .intp )
1593
- else :
1594
- start = stop = np .array (filt .shape , dtype = np .intp )
1585
+ start = stop = nonzero
1595
1586
else :
1596
1587
start = nonzero .min (axis = 0 )
1597
1588
stop = nonzero .max (axis = 0 )
1598
-
1599
1589
return start , stop
1600
1590
1601
1591
@@ -1646,18 +1636,31 @@ def trim_zeros(filt, trim='fb', axis=-1):
1646
1636
[1, 2]
1647
1637
1648
1638
"""
1649
- start , stop = arg_trim_zeros (filt , trim )
1639
+ start , stop = arg_trim_zeros (filt )
1650
1640
stop += 1 # Adjust for slicing
1641
+ ndim = start .shape [- 1 ]
1642
+
1643
+ if start .size == 0 :
1644
+ # filt is all-zero -> assign same values to start and stop so that
1645
+ # resulting slice will be empty
1646
+ start = stop = np .zeros (ndim , dtype = np .intp )
1647
+ else :
1648
+ trim = trim .lower ()
1649
+ if 'f' not in trim :
1650
+ start = (None ,) * ndim
1651
+ if 'b' not in trim :
1652
+ stop = (None ,) * ndim
1651
1653
1652
1654
if start .size == 1 :
1653
- # filt is 1D -> use multi-dimensional slicing only when necessary,
1654
- # this allows preservation of the non-array input types
1655
+ # filt is 1D -> don't use multi-dimensional slicing to preserve
1656
+ # non-array input types
1655
1657
sl = slice (start [0 ], stop [0 ])
1656
1658
elif axis is None :
1657
1659
# trim all axes
1658
1660
sl = tuple (slice (* x ) for x in zip (start , stop ))
1659
1661
else :
1660
- # only trim given axis
1662
+ # only trim single axis
1663
+ axis = normalize_axis_index (axis , ndim )
1661
1664
sl = (slice (None ),) * axis + (slice (start [axis ], stop [axis ]),) + (...,)
1662
1665
1663
1666
return filt [sl ]
0 commit comments