Skip to content

Commit a79706c

Browse files
committed
Make _arg_trim_zeros private and add tests
1 parent f69eca8 commit a79706c

File tree

2 files changed

+36
-15
lines changed

2 files changed

+36
-15
lines changed

numpy/lib/_function_base_impl.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1868,11 +1868,6 @@ def sort_complex(a):
18681868

18691869

18701870
def _arg_trim_zeros(filt):
1871-
return (filt, filt)
1872-
1873-
1874-
@array_function_dispatch(_arg_trim_zeros)
1875-
def arg_trim_zeros(filt):
18761871
"""Return indices of the first and last non-zero element.
18771872
18781873
Parameters
@@ -1889,6 +1884,12 @@ def arg_trim_zeros(filt):
18891884
See also
18901885
--------
18911886
trim_zeros
1887+
1888+
Examples
1889+
--------
1890+
>>> import numpy as np
1891+
>>> _arg_trim_zeros(np.array([0, 0, 1, 1, 0]))
1892+
(array([2]), array([3]))
18921893
"""
18931894
nonzero = np.argwhere(filt)
18941895
if nonzero.size == 0:
@@ -1922,14 +1923,9 @@ def trim_zeros(filt, trim='fb', axis=-1):
19221923
trimmed : ndarray or sequence
19231924
The result of trimming the input. The input data type is preserved.
19241925
1925-
See also
1926-
--------
1927-
arg_trim_zeros
1928-
19291926
Notes
19301927
-----
1931-
For all-zero arrays, the first axis is trimmed depending on the order in
1932-
`trim`.
1928+
For all-zero arrays, the first axis is trimmed first.
19331929
19341930
Examples
19351931
--------
@@ -1938,25 +1934,43 @@ def trim_zeros(filt, trim='fb', axis=-1):
19381934
>>> np.trim_zeros(a)
19391935
array([1, 2, 3, 0, 2, 1])
19401936
1941-
>>> np.trim_zeros(a, 'b')
1937+
>>> np.trim_zeros(a, trim='b')
19421938
array([0, 0, 0, ..., 0, 2, 1])
19431939
1940+
Multiple dimensions are supported.
1941+
1942+
>>> b = np.array([[0, 0, 2, 3, 0, 0],
1943+
... [0, 1, 0, 3, 0, 0],
1944+
... [0, 0, 0, 0, 0, 0]])
1945+
>>> np.trim_zeros(b)
1946+
array([[0, 2, 3],
1947+
[1, 0, 3]])
1948+
1949+
>>> np.trim_zeros(b, axis=-1)
1950+
array([[0, 2, 3],
1951+
[1, 0, 3],
1952+
[0, 0, 0]])
1953+
19441954
The input data type is preserved, list/tuple in means list/tuple out.
19451955
19461956
>>> np.trim_zeros([0, 1, 2, 0])
19471957
[1, 2]
19481958
19491959
"""
19501960
filt_ = np.asarray(filt)
1951-
start, stop = arg_trim_zeros(filt_)
1961+
1962+
trim = trim.lower()
1963+
if trim not in {"fb", "bf", "f", "b"}:
1964+
raise ValueError(f"unexpected character(s) in `trim`: {trim!r}")
1965+
1966+
start, stop = _arg_trim_zeros(filt_)
19521967
stop += 1 # Adjust for slicing
19531968

19541969
if start.size == 0:
19551970
# filt is all-zero -> assign same values to start and stop so that
19561971
# resulting slice will be empty
19571972
start = stop = np.zeros(filt_.ndim, dtype=np.intp)
19581973
else:
1959-
trim = trim.lower()
19601974
if 'f' not in trim:
19611975
start = (None,) * filt_.ndim
19621976
if 'b' not in trim:
@@ -1974,7 +1988,8 @@ def trim_zeros(filt, trim='fb', axis=-1):
19741988
axis = normalize_axis_index(axis, filt_.ndim)
19751989
sl = (slice(None),) * axis + (slice(start[axis], stop[axis]),) + (...,)
19761990

1977-
return filt[sl]
1991+
trimmed = filt[sl]
1992+
return trimmed
19781993

19791994

19801995

numpy/lib/tests/test_function_base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,6 +1422,12 @@ def test_trim_arg(self):
14221422
res = trim_zeros(a, trim='')
14231423
assert_array_equal(res, [0, 1, 2, 0])
14241424

1425+
@pytest.mark.parametrize("trim", ("front", ""))
1426+
def test_unexpected_trim_value(self, trim):
1427+
arr = self.a
1428+
with pytest.raises(ValueError, match=r"unexpected character\(s\) in `trim`"):
1429+
trim_zeros(arr, trim=trim)
1430+
14251431

14261432
class TestExtins:
14271433

0 commit comments

Comments
 (0)