@@ -1624,46 +1624,45 @@ def trim_zeros(
1624
1624
"""
1625
1625
trim = trim .lower ()
1626
1626
1627
+ absolutes = np .abs (filt )
1628
+ lengths = np .zeros ((absolutes .ndim , 2 ), dtype = np .intp )
1629
+
1627
1630
if axis is None :
1628
1631
# Apply iteratively to all axes
1629
- axis = range (filt .ndim )
1630
-
1632
+ axis = range (absolutes .ndim )
1631
1633
# Normalize axes to 1D-array
1632
1634
axis = np .asarray (axis , dtype = np .intp )
1633
1635
if axis .ndim == 0 :
1634
1636
axis = np .asarray ([axis ], dtype = np .intp )
1635
1637
1636
- absolutes = np .abs (filt )
1637
- lengths = np .zeros ((absolutes .ndim , 2 ), dtype = np .intp )
1638
+ if atol > 0 :
1639
+ absolutes [absolutes <= atol ] = 0
1640
+ if rtol > 0 :
1641
+ absolutes [absolutes <= rtol * absolutes .max ()] = 0
1642
+
1643
+ nonzero = np .nonzero (absolutes )
1638
1644
1639
1645
for current_axis in axis :
1640
1646
absolutes .take ([], current_axis ) # Raises if axis is out of bounds
1641
1647
if current_axis < 0 :
1642
1648
current_axis += absolutes .ndim
1643
1649
1644
- # Reduce to envelope along all axes except the selected one
1645
- reduced = np .moveaxis (absolutes , current_axis , - 1 )
1646
- for _ in range (absolutes .ndim - 1 ):
1647
- reduced = reduced .max (axis = 0 )
1648
- assert reduced .ndim == 1
1649
-
1650
- if atol > 0 :
1651
- reduced [reduced <= atol ] = 0
1652
- if rtol > 0 :
1653
- reduced [reduced <= rtol * reduced .max ()] = 0
1654
-
1655
- # Find start and stop indices for current dimension
1656
- start , stop = np .nonzero (reduced )[0 ][[0 , - 1 ]]
1657
- stop += 1
1650
+ if nonzero [current_axis ].size > 0 :
1651
+ start = nonzero [current_axis ].min ()
1652
+ stop = nonzero [current_axis ].max ()
1653
+ stop += 1
1654
+ else :
1655
+ # In case the input is all-zero, slice only in front
1656
+ start = stop = absolutes .shape [current_axis ]
1657
+ if "f" not in trim :
1658
+ # except when only the backside is to be sliced
1659
+ stop = 0
1658
1660
1661
+ # Only slice on specified side(s)
1659
1662
if "f" not in trim :
1660
1663
start = None
1661
- else :
1662
- lengths [current_axis , 0 ] = start
1663
1664
if "b" not in trim :
1664
1665
stop = None
1665
- else :
1666
- lengths [current_axis , 1 ] = absolutes .shape [current_axis ] - stop
1667
1666
1668
1667
# Use multi-dimensional slicing only when necessary, this allows
1669
1668
# preservation of the non-arrays input types
@@ -1673,6 +1672,11 @@ def trim_zeros(
1673
1672
1674
1673
filt = filt [sl ]
1675
1674
1675
+ if start is not None :
1676
+ lengths [current_axis , 0 ] = start
1677
+ if stop is not None :
1678
+ lengths [current_axis , 1 ] = absolutes .shape [current_axis ] - stop
1679
+
1676
1680
if return_lengths is True :
1677
1681
return filt , lengths
1678
1682
else :
0 commit comments