@@ -1557,28 +1557,55 @@ def sort_complex(a):
1557
1557
return b
1558
1558
1559
1559
1560
- def _trim_zeros (filt , trim = None ):
1560
+ def _trim_zeros (
1561
+ filt ,
1562
+ trim = None ,
1563
+ axis = None ,
1564
+ * ,
1565
+ atol = None ,
1566
+ rtol = None ,
1567
+ return_lengths = None
1568
+ ):
1561
1569
return (filt ,)
1562
1570
1563
1571
1564
1572
@array_function_dispatch (_trim_zeros )
1565
- def trim_zeros (filt , trim = 'fb' ):
1566
- """
1567
- Trim the leading and/or trailing zeros from a 1-D array or sequence.
1573
+ def trim_zeros (
1574
+ filt ,
1575
+ trim = 'fb' ,
1576
+ axis = - 1 ,
1577
+ * ,
1578
+ atol = 0 ,
1579
+ rtol = 0 ,
1580
+ return_lengths = False
1581
+ ):
1582
+ """Remove values along a dimension which are zero along all other.
1568
1583
1569
1584
Parameters
1570
1585
----------
1571
- filt : 1-D array or sequence
1586
+ filt : array_like
1572
1587
Input array.
1573
1588
trim : str, optional
1574
1589
A string with 'f' representing trim from front and 'b' to trim from
1575
- back. Default is 'fb', trim zeros from both front and back of the
1576
- array.
1590
+ back. By default, zeros are trimmed from the front and back.
1591
+ axis : int or sequence, optional
1592
+ The axis or a sequence of axes to trim. If None all axes are trimmed.
1593
+ atol : float, optional
1594
+ Absolute tolerance with which a value is considered for trimming.
1595
+ rtol : float, optional
1596
+ Relative tolerance with which a value is considered for trimming.
1597
+ return_lengths : bool, optional
1598
+ Additionally return the number of trimmed samples in each dimension at
1599
+ the front and back.
1577
1600
1578
1601
Returns
1579
1602
-------
1580
- trimmed : 1-D array or sequence
1603
+ trimmed : ndarray or sequence
1581
1604
The result of trimming the input. The input data type is preserved.
1605
+ lengths : ndarray
1606
+ If `return_lengths` was True, an array of shape (``filt.ndim``, 2) is
1607
+ returned. It contains the number of trimmed samples in each dimension
1608
+ at the front and back.
1582
1609
1583
1610
Examples
1584
1611
--------
@@ -1595,22 +1622,62 @@ def trim_zeros(filt, trim='fb'):
1595
1622
[1, 2]
1596
1623
1597
1624
"""
1598
- first = 0
1599
- trim = trim .upper ()
1600
- if 'F' in trim :
1601
- for i in filt :
1602
- if i != 0. :
1603
- break
1604
- else :
1605
- first = first + 1
1606
- last = len (filt )
1607
- if 'B' in trim :
1608
- for i in filt [::- 1 ]:
1609
- if i != 0. :
1610
- break
1611
- else :
1612
- last = last - 1
1613
- return filt [first :last ]
1625
+ trim = trim .lower ()
1626
+
1627
+ if axis is None :
1628
+ # Apply iteratively to all axes
1629
+ axis = range (filt .ndim )
1630
+
1631
+ # Normalize axes to 1D-array
1632
+ axis = np .asarray (axis , dtype = np .intp )
1633
+ if axis .ndim == 0 :
1634
+ axis = np .asarray ([axis ], dtype = np .intp )
1635
+
1636
+ absolutes = np .abs (filt )
1637
+ lengths = np .zeros ((absolutes .ndim , 2 ), dtype = np .intp )
1638
+
1639
+ for current_axis in axis :
1640
+ absolutes .take ([], current_axis ) # Raises if axis is out of bounds
1641
+ if current_axis < 0 :
1642
+ current_axis += absolutes .ndim
1643
+
1644
+ # Reduce to envelope along all axes except the selected one
1645
+ reduced = np .moveaxis (absolutes , current_axis , - 1 )
1646
+ for _ in range (absolutes .ndim - 1 ):
1647
+ reduced = reduced .max (axis = 0 )
1648
+ assert reduced .ndim == 1
1649
+
1650
+ if atol > 0 :
1651
+ reduced [reduced <= atol ] = 0
1652
+ if rtol > 0 :
1653
+ reduced [reduced <= rtol * reduced .max ()] = 0
1654
+
1655
+ # Find start and stop indices for current dimension
1656
+ start , stop = np .nonzero (reduced )[0 ][[0 , - 1 ]]
1657
+ stop += 1
1658
+
1659
+ if "f" not in trim :
1660
+ start = None
1661
+ else :
1662
+ lengths [current_axis , 0 ] = start
1663
+ if "b" not in trim :
1664
+ stop = None
1665
+ else :
1666
+ lengths [current_axis , 1 ] = absolutes .shape [current_axis ] - stop
1667
+
1668
+ # Use multi-dimensional slicing only when necessary, this allows
1669
+ # preservation of the non-arrays input types
1670
+ sl = slice (start , stop )
1671
+ if current_axis != 0 :
1672
+ sl = (slice (None ),) * current_axis + (sl ,) + (...,)
1673
+
1674
+ filt = filt [sl ]
1675
+
1676
+ if return_lengths is True :
1677
+ return filt , lengths
1678
+ else :
1679
+ return filt
1680
+
1614
1681
1615
1682
def _extract_dispatcher (condition , arr ):
1616
1683
return (condition , arr )
0 commit comments