@@ -1868,11 +1868,6 @@ def sort_complex(a):
1868
1868
1869
1869
1870
1870
def _arg_trim_zeros (filt ):
1871
- return (filt , filt )
1872
-
1873
-
1874
- @array_function_dispatch (_arg_trim_zeros )
1875
- def arg_trim_zeros (filt ):
1876
1871
"""Return indices of the first and last non-zero element.
1877
1872
1878
1873
Parameters
@@ -1889,6 +1884,12 @@ def arg_trim_zeros(filt):
1889
1884
See also
1890
1885
--------
1891
1886
trim_zeros
1887
+
1888
+ Examples
1889
+ --------
1890
+ >>> import numpy as np
1891
+ >>> _arg_trim_zeros(np.array([0, 0, 1, 1, 0]))
1892
+ (array([2]), array([3]))
1892
1893
"""
1893
1894
nonzero = np .argwhere (filt )
1894
1895
if nonzero .size == 0 :
@@ -1922,14 +1923,9 @@ def trim_zeros(filt, trim='fb', axis=-1):
1922
1923
trimmed : ndarray or sequence
1923
1924
The result of trimming the input. The input data type is preserved.
1924
1925
1925
- See also
1926
- --------
1927
- arg_trim_zeros
1928
-
1929
1926
Notes
1930
1927
-----
1931
- For all-zero arrays, the first axis is trimmed depending on the order in
1932
- `trim`.
1928
+ For all-zero arrays, the first axis is trimmed first.
1933
1929
1934
1930
Examples
1935
1931
--------
@@ -1938,25 +1934,43 @@ def trim_zeros(filt, trim='fb', axis=-1):
1938
1934
>>> np.trim_zeros(a)
1939
1935
array([1, 2, 3, 0, 2, 1])
1940
1936
1941
- >>> np.trim_zeros(a, 'b')
1937
+ >>> np.trim_zeros(a, trim= 'b')
1942
1938
array([0, 0, 0, ..., 0, 2, 1])
1943
1939
1940
+ Multiple dimensions are supported.
1941
+
1942
+ >>> b = np.array([[0, 0, 2, 3, 0, 0],
1943
+ ... [0, 1, 0, 3, 0, 0],
1944
+ ... [0, 0, 0, 0, 0, 0]])
1945
+ >>> np.trim_zeros(b)
1946
+ array([[0, 2, 3],
1947
+ [1, 0, 3]])
1948
+
1949
+ >>> np.trim_zeros(b, axis=-1)
1950
+ array([[0, 2, 3],
1951
+ [1, 0, 3],
1952
+ [0, 0, 0]])
1953
+
1944
1954
The input data type is preserved, list/tuple in means list/tuple out.
1945
1955
1946
1956
>>> np.trim_zeros([0, 1, 2, 0])
1947
1957
[1, 2]
1948
1958
1949
1959
"""
1950
1960
filt_ = np .asarray (filt )
1951
- start , stop = arg_trim_zeros (filt_ )
1961
+
1962
+ trim = trim .lower ()
1963
+ if trim not in {"fb" , "bf" , "f" , "b" }:
1964
+ raise ValueError (f"unexpected character(s) in `trim`: { trim !r} " )
1965
+
1966
+ start , stop = _arg_trim_zeros (filt_ )
1952
1967
stop += 1 # Adjust for slicing
1953
1968
1954
1969
if start .size == 0 :
1955
1970
# filt is all-zero -> assign same values to start and stop so that
1956
1971
# resulting slice will be empty
1957
1972
start = stop = np .zeros (filt_ .ndim , dtype = np .intp )
1958
1973
else :
1959
- trim = trim .lower ()
1960
1974
if 'f' not in trim :
1961
1975
start = (None ,) * filt_ .ndim
1962
1976
if 'b' not in trim :
@@ -1974,7 +1988,8 @@ def trim_zeros(filt, trim='fb', axis=-1):
1974
1988
axis = normalize_axis_index (axis , filt_ .ndim )
1975
1989
sl = (slice (None ),) * axis + (slice (start [axis ], stop [axis ]),) + (...,)
1976
1990
1977
- return filt [sl ]
1991
+ trimmed = filt [sl ]
1992
+ return trimmed
1978
1993
1979
1994
1980
1995
0 commit comments