@@ -1843,28 +1843,79 @@ def sort_complex(a):
1843
1843
return b
1844
1844
1845
1845
1846
- def _trim_zeros (filt , trim = None ):
1846
+ def _arg_trim_zeros (filt ):
1847
+ """Return indices of the first and last non-zero element.
1848
+
1849
+ Parameters
1850
+ ----------
1851
+ filt : array_like
1852
+ Input array.
1853
+
1854
+ Returns
1855
+ -------
1856
+ start, stop : ndarray
1857
+ Two arrays containing the indices of the first and last non-zero
1858
+ element in each dimension.
1859
+
1860
+ See also
1861
+ --------
1862
+ trim_zeros
1863
+
1864
+ Examples
1865
+ --------
1866
+ >>> import numpy as np
1867
+ >>> _arg_trim_zeros(np.array([0, 0, 1, 1, 0]))
1868
+ (array([2]), array([3]))
1869
+ """
1870
+ nonzero = (
1871
+ np .argwhere (filt )
1872
+ if filt .dtype != np .object_
1873
+ # Historically, `trim_zeros` treats `None` in an object array
1874
+ # as non-zero while argwhere doesn't, account for that
1875
+ else np .argwhere (filt != 0 )
1876
+ )
1877
+ if nonzero .size == 0 :
1878
+ start = stop = np .array ([], dtype = np .intp )
1879
+ else :
1880
+ start = nonzero .min (axis = 0 )
1881
+ stop = nonzero .max (axis = 0 )
1882
+ return start , stop
1883
+
1884
+
1885
+ def _trim_zeros (filt , trim = None , axis = None ):
1847
1886
return (filt ,)
1848
1887
1849
1888
1850
1889
@array_function_dispatch (_trim_zeros )
1851
- def trim_zeros (filt , trim = 'fb' ):
1852
- """
1853
- Trim the leading and/or trailing zeros from a 1-D array or sequence.
1890
+ def trim_zeros (filt , trim = 'fb' , axis = None ):
1891
+ """Remove values along a dimension which are zero along all other.
1854
1892
1855
1893
Parameters
1856
1894
----------
1857
- filt : 1-D array or sequence
1895
+ filt : array_like
1858
1896
Input array.
1859
- trim : str , optional
1897
+ trim : {"fb", "f", "b"} , optional
1860
1898
A string with 'f' representing trim from front and 'b' to trim from
1861
- back. Default is 'fb', trim zeros from both front and back of the
1862
- array.
1899
+ back. By default, zeros are trimmed on both sides.
1900
+ Front and back refer to the edges of a dimension, with "front" refering
1901
+ to the side with the lowest index 0, and "back" refering to the highest
1902
+ index (or index -1).
1903
+ axis : int or sequence, optional
1904
+ If None, `filt` is cropped such, that the smallest bounding box is
1905
+ returned that still contains all values which are not zero.
1906
+ If an axis is specified, `filt` will be sliced in that dimension only
1907
+ on the sides specified by `trim`. The remaining area will be the
1908
+ smallest that still contains all values wich are not zero.
1863
1909
1864
1910
Returns
1865
1911
-------
1866
- trimmed : 1-D array or sequence
1867
- The result of trimming the input. The input data type is preserved.
1912
+ trimmed : ndarray or sequence
1913
+ The result of trimming the input. The number of dimensions and the
1914
+ input data type are preserved.
1915
+
1916
+ Notes
1917
+ -----
1918
+ For all-zero arrays, the first axis is trimmed first.
1868
1919
1869
1920
Examples
1870
1921
--------
@@ -1873,32 +1924,63 @@ def trim_zeros(filt, trim='fb'):
1873
1924
>>> np.trim_zeros(a)
1874
1925
array([1, 2, 3, 0, 2, 1])
1875
1926
1876
- >>> np.trim_zeros(a, 'b')
1927
+ >>> np.trim_zeros(a, trim= 'b')
1877
1928
array([0, 0, 0, ..., 0, 2, 1])
1878
1929
1930
+ Multiple dimensions are supported.
1931
+
1932
+ >>> b = np.array([[0, 0, 2, 3, 0, 0],
1933
+ ... [0, 1, 0, 3, 0, 0],
1934
+ ... [0, 0, 0, 0, 0, 0]])
1935
+ >>> np.trim_zeros(b)
1936
+ array([[0, 2, 3],
1937
+ [1, 0, 3]])
1938
+
1939
+ >>> np.trim_zeros(b, axis=-1)
1940
+ array([[0, 2, 3],
1941
+ [1, 0, 3],
1942
+ [0, 0, 0]])
1943
+
1879
1944
The input data type is preserved, list/tuple in means list/tuple out.
1880
1945
1881
1946
>>> np.trim_zeros([0, 1, 2, 0])
1882
1947
[1, 2]
1883
1948
1884
1949
"""
1950
+ filt_ = np .asarray (filt )
1951
+
1952
+ trim = trim .lower ()
1953
+ if trim not in {"fb" , "bf" , "f" , "b" }:
1954
+ raise ValueError (f"unexpected character(s) in `trim`: { trim !r} " )
1955
+
1956
+ start , stop = _arg_trim_zeros (filt_ )
1957
+ stop += 1 # Adjust for slicing
1958
+
1959
+ if start .size == 0 :
1960
+ # filt is all-zero -> assign same values to start and stop so that
1961
+ # resulting slice will be empty
1962
+ start = stop = np .zeros (filt_ .ndim , dtype = np .intp )
1963
+ else :
1964
+ if 'f' not in trim :
1965
+ start = (None ,) * filt_ .ndim
1966
+ if 'b' not in trim :
1967
+ stop = (None ,) * filt_ .ndim
1968
+
1969
+ if len (start ) == 1 :
1970
+ # filt is 1D -> don't use multi-dimensional slicing to preserve
1971
+ # non-array input types
1972
+ sl = slice (start [0 ], stop [0 ])
1973
+ elif axis is None :
1974
+ # trim all axes
1975
+ sl = tuple (slice (* x ) for x in zip (start , stop ))
1976
+ else :
1977
+ # only trim single axis
1978
+ axis = normalize_axis_index (axis , filt_ .ndim )
1979
+ sl = (slice (None ),) * axis + (slice (start [axis ], stop [axis ]),) + (...,)
1980
+
1981
+ trimmed = filt [sl ]
1982
+ return trimmed
1885
1983
1886
- first = 0
1887
- trim = trim .upper ()
1888
- if 'F' in trim :
1889
- for i in filt :
1890
- if i != 0. :
1891
- break
1892
- else :
1893
- first = first + 1
1894
- last = len (filt )
1895
- if 'B' in trim :
1896
- for i in filt [::- 1 ]:
1897
- if i != 0. :
1898
- break
1899
- else :
1900
- last = last - 1
1901
- return filt [first :last ]
1902
1984
1903
1985
1904
1986
def _extract_dispatcher (condition , arr ):
0 commit comments