@@ -1607,6 +1607,10 @@ def trim_zeros(
1607
1607
returned. It contains the number of trimmed samples in each dimension
1608
1608
at the front and back.
1609
1609
1610
+ Notes
1611
+ -----
1612
+ For all-zero arrays, the first axis is trimmed first.
1613
+
1610
1614
Examples
1611
1615
--------
1612
1616
>>> a = np.array((0, 0, 0, 1, 2, 3, 0, 2, 1, 0))
@@ -1625,6 +1629,12 @@ def trim_zeros(
1625
1629
trim = trim .lower ()
1626
1630
1627
1631
absolutes = np .abs (filt )
1632
+ if atol > 0 :
1633
+ absolutes [absolutes <= atol ] = 0
1634
+ if rtol > 0 :
1635
+ absolutes [absolutes <= rtol * absolutes .max ()] = 0
1636
+ nonzero = np .nonzero (absolutes )
1637
+
1628
1638
lengths = np .zeros ((absolutes .ndim , 2 ), dtype = np .intp )
1629
1639
1630
1640
if axis is None :
@@ -1635,28 +1645,19 @@ def trim_zeros(
1635
1645
if axis .ndim == 0 :
1636
1646
axis = np .asarray ([axis ], dtype = np .intp )
1637
1647
1638
- if atol > 0 :
1639
- absolutes [absolutes <= atol ] = 0
1640
- if rtol > 0 :
1641
- absolutes [absolutes <= rtol * absolutes .max ()] = 0
1642
-
1643
- nonzero = np .nonzero (absolutes )
1644
-
1645
1648
for current_axis in axis :
1646
- absolutes .take ([], current_axis ) # Raises if axis is out of bounds
1647
- if current_axis < 0 :
1648
- current_axis += absolutes .ndim
1649
+ current_axis = normalize_axis_index (current_axis , absolutes .ndim )
1649
1650
1650
1651
if nonzero [current_axis ].size > 0 :
1651
1652
start = nonzero [current_axis ].min ()
1652
- stop = nonzero [current_axis ].max ()
1653
- stop += 1
1653
+ stop = nonzero [current_axis ].max () + 1
1654
1654
else :
1655
- # In case the input is all-zero, slice only in front
1656
- start = stop = absolutes .shape [current_axis ]
1657
- if "f" not in trim :
1658
- # except when only the backside is to be sliced
1659
- stop = 0
1655
+ # In case the input is all-zero, slice depending on preference
1656
+ # given by user
1657
+ if trim .startswith ("b" ):
1658
+ start = stop = 0
1659
+ else :
1660
+ start = stop = absolutes .shape [current_axis ]
1660
1661
1661
1662
# Only slice on specified side(s)
1662
1663
if "f" not in trim :
@@ -1665,7 +1666,7 @@ def trim_zeros(
1665
1666
stop = None
1666
1667
1667
1668
# Use multi-dimensional slicing only when necessary, this allows
1668
- # preservation of the non-arrays input types
1669
+ # preservation of the non-array input types
1669
1670
sl = slice (start , stop )
1670
1671
if current_axis != 0 :
1671
1672
sl = (slice (None ),) * current_axis + (sl ,) + (...,)
0 commit comments