@@ -1557,12 +1557,54 @@ def sort_complex(a):
1557
1557
return b
1558
1558
1559
1559
1560
- def _trim_zeros (filt , trim = None , axis = None , * , return_lengths = None ):
1560
+ def _arg_trim_zeros (filt , trim = None ):
1561
+ return (filt , filt )
1562
+
1563
+
1564
+ @array_function_dispatch (_arg_trim_zeros )
1565
+ def arg_trim_zeros (filt , trim = 'fb' ):
1566
+ """Return indices of the first and last non-zero element.
1567
+
1568
+ Parameters
1569
+ ----------
1570
+ filt : array_like
1571
+ Input array.
1572
+ trim : str, optional
1573
+ A string with 'f' representing trim from front and 'b' to trim from
1574
+ back. By default, zeros are trimmed from the front and back.
1575
+
1576
+ Returns
1577
+ -------
1578
+ start, stop : ndarray
1579
+ Two arrays containing the indices of the first and last non-zero
1580
+ element in each dimension.
1581
+
1582
+ See also
1583
+ --------
1584
+ trim_zeros
1585
+ """
1586
+ filt = np .asarray (filt )
1587
+ trim = trim .lower ()
1588
+
1589
+ nonzero = np .argwhere (filt )
1590
+ if nonzero .size == 0 :
1591
+ if trim .startswith ('b' ):
1592
+ start = stop = np .zeros (filt .ndim , dtype = np .intp )
1593
+ else :
1594
+ start = stop = np .array (filt .shape , dtype = np .intp )
1595
+ else :
1596
+ start = nonzero .min (axis = 0 )
1597
+ stop = nonzero .max (axis = 0 )
1598
+
1599
+ return start , stop
1600
+
1601
+
1602
+ def _trim_zeros (filt , trim = None , axis = None ):
1561
1603
return (filt ,)
1562
1604
1563
1605
1564
1606
@array_function_dispatch (_trim_zeros )
1565
- def trim_zeros (filt , trim = 'fb' , axis = - 1 , * , return_lengths = False ):
1607
+ def trim_zeros (filt , trim = 'fb' , axis = - 1 ):
1566
1608
"""Remove values along a dimension which are zero along all other.
1567
1609
1568
1610
Parameters
@@ -1573,23 +1615,21 @@ def trim_zeros(filt, trim='fb', axis=-1, *, return_lengths=False):
1573
1615
A string with 'f' representing trim from front and 'b' to trim from
1574
1616
back. By default, zeros are trimmed from the front and back.
1575
1617
axis : int or sequence, optional
1576
- The axis or a sequence of axes to trim. If None all axes are trimmed.
1577
- return_lengths : bool, optional
1578
- Additionally return the number of trimmed samples in each dimension at
1579
- the front and back.
1618
+ The axis to trim. If None all axes are trimmed.
1580
1619
1581
1620
Returns
1582
1621
-------
1583
1622
trimmed : ndarray or sequence
1584
1623
The result of trimming the input. The input data type is preserved.
1585
- lengths : ndarray
1586
- If `return_lengths` was True, an array of shape (``filt.ndim``, 2) is
1587
- returned. It contains the number of trimmed samples in each dimension
1588
- at the front and back.
1624
+
1625
+ See also
1626
+ --------
1627
+ arg_trim_zeros
1589
1628
1590
1629
Notes
1591
1630
-----
1592
- For all-zero arrays, the first axis is trimmed first.
1631
+ For all-zero arrays, the first axis is trimmed depending on the order in
1632
+ `trim`.
1593
1633
1594
1634
Examples
1595
1635
--------
@@ -1606,57 +1646,21 @@ def trim_zeros(filt, trim='fb', axis=-1, *, return_lengths=False):
1606
1646
[1, 2]
1607
1647
1608
1648
"""
1609
- trim = trim .lower ()
1610
-
1611
- absolutes = np .abs (filt )
1612
- nonzero = np .nonzero (absolutes )
1613
- lengths = np .zeros ((absolutes .ndim , 2 ), dtype = np .intp )
1614
-
1615
- if axis is None :
1616
- # Apply iteratively to all axes
1617
- axis = range (absolutes .ndim )
1618
- # Normalize axes to 1D-array
1619
- axis = np .asarray (axis , dtype = np .intp )
1620
- if axis .ndim == 0 :
1621
- axis = np .asarray ([axis ], dtype = np .intp )
1622
-
1623
- for current_axis in axis :
1624
- current_axis = normalize_axis_index (current_axis , absolutes .ndim )
1625
-
1626
- if nonzero [current_axis ].size > 0 :
1627
- start = nonzero [current_axis ].min ()
1628
- stop = nonzero [current_axis ].max () + 1
1629
- else :
1630
- # In case the input is all-zero, slice depending on preference
1631
- # given by user
1632
- if trim .startswith ("b" ):
1633
- start = stop = 0
1634
- else :
1635
- start = stop = absolutes .shape [current_axis ]
1636
-
1637
- # Only slice on specified side(s)
1638
- if "f" not in trim :
1639
- start = None
1640
- if "b" not in trim :
1641
- stop = None
1642
-
1643
- # Use multi-dimensional slicing only when necessary, this allows
1644
- # preservation of the non-array input types
1645
- sl = slice (start , stop )
1646
- if current_axis != 0 :
1647
- sl = (slice (None ),) * current_axis + (sl ,) + (...,)
1648
-
1649
- filt = filt [sl ]
1650
-
1651
- if start is not None :
1652
- lengths [current_axis , 0 ] = start
1653
- if stop is not None :
1654
- lengths [current_axis , 1 ] = absolutes .shape [current_axis ] - stop
1655
-
1656
- if return_lengths is True :
1657
- return filt , lengths
1649
+ start , stop = arg_trim_zeros (filt , trim )
1650
+ stop += 1 # Adjust for slicing
1651
+
1652
+ if start .size == 1 :
1653
+ # filt is 1D -> use multi-dimensional slicing only when necessary,
1654
+ # this allows preservation of the non-array input types
1655
+ sl = slice (start [0 ], stop [0 ])
1656
+ elif axis is None :
1657
+ # trim all axes
1658
+ sl = tuple (slice (* x ) for x in zip (start , stop ))
1658
1659
else :
1659
- return filt
1660
+ # only trim given axis
1661
+ sl = (slice (None ),) * axis + (slice (start [axis ], stop [axis ]),) + (...,)
1662
+
1663
+ return filt [sl ]
1660
1664
1661
1665
1662
1666
def _extract_dispatcher (condition , arr ):
0 commit comments