Skip to content

Commit 6473b77

Browse files
authored
Merge pull request numpy#25999 from mattip/assert_warns
BUG: fix kwarg handling in assert_warn [skip cirrus][skip azp]
2 parents 64676ba + b081b4d commit 6473b77

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

numpy/testing/_private/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1949,8 +1949,15 @@ def assert_warns(warning_class, *args, **kwargs):
19491949
>>> ret = np.testing.assert_warns(DeprecationWarning, deprecated_func, 4)
19501950
>>> assert ret == 16
19511951
"""
1952-
if not args:
1952+
if not args and not kwargs:
19531953
return _assert_warns_context(warning_class)
1954+
elif len(args) < 1:
1955+
if "match" in kwargs:
1956+
raise RuntimeError(
1957+
"assert_warns does not use 'match' kwarg, "
1958+
"use pytest.warns instead"
1959+
)
1960+
raise RuntimeError("assert_warns(...) needs at least one arg")
19541961

19551962
func = args[0]
19561963
args = args[1:]

numpy/testing/tests/test_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,6 +1035,27 @@ def no_warnings():
10351035
assert_equal(before_filters, after_filters,
10361036
"assert_warns does not preserver warnings state")
10371037

1038+
def test_args(self):
1039+
def f(a=0, b=1):
1040+
warnings.warn("yo")
1041+
return a + b
1042+
1043+
assert assert_warns(UserWarning, f, b=20) == 20
1044+
1045+
with pytest.raises(RuntimeError) as exc:
1046+
# assert_warns cannot do regexp matching, use pytest.warns
1047+
with assert_warns(UserWarning, match="A"):
1048+
warnings.warn("B", UserWarning)
1049+
assert "assert_warns" in str(exc)
1050+
assert "pytest.warns" in str(exc)
1051+
1052+
with pytest.raises(RuntimeError) as exc:
1053+
# assert_warns cannot do regexp matching, use pytest.warns
1054+
with assert_warns(UserWarning, wrong="A"):
1055+
warnings.warn("B", UserWarning)
1056+
assert "assert_warns" in str(exc)
1057+
assert "pytest.warns" not in str(exc)
1058+
10381059
def test_warn_wrong_warning(self):
10391060
def f():
10401061
warnings.warn("yo", DeprecationWarning)

0 commit comments

Comments
 (0)