Skip to content

Commit 2dca68b

Browse files
committed
Type-annotate pytest.warns
1 parent d7ee3da commit 2dca68b

File tree

1 file changed

+87
-23
lines changed

1 file changed

+87
-23
lines changed

src/_pytest/recwarn.py

Lines changed: 87 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
11
""" recording warnings during test function execution. """
2-
import inspect
32
import re
43
import warnings
4+
from types import TracebackType
5+
from typing import Any
6+
from typing import Callable
7+
from typing import Iterator
8+
from typing import List
9+
from typing import Optional
10+
from typing import overload
11+
from typing import Pattern
12+
from typing import Tuple
13+
from typing import Union
514

615
from _pytest.fixtures import yield_fixture
716
from _pytest.outcomes import fail
817

18+
if False: # TYPE_CHECKING
19+
from typing import Type
20+
921

1022
@yield_fixture
1123
def recwarn():
@@ -42,7 +54,32 @@ def deprecated_call(func=None, *args, **kwargs):
4254
return warns((DeprecationWarning, PendingDeprecationWarning), *args, **kwargs)
4355

4456

45-
def warns(expected_warning, *args, match=None, **kwargs):
57+
@overload
58+
def warns(
59+
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
60+
*,
61+
match: Optional[Union[str, Pattern]] = ...
62+
) -> "WarningsChecker":
63+
... # pragma: no cover
64+
65+
66+
@overload
67+
def warns(
68+
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
69+
func: Callable,
70+
*args: Any,
71+
match: Optional[Union[str, Pattern]] = ...,
72+
**kwargs: Any
73+
) -> Union[Any]:
74+
... # pragma: no cover
75+
76+
77+
def warns(
78+
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
79+
*args: Any,
80+
match: Optional[Union[str, Pattern]] = None,
81+
**kwargs: Any
82+
) -> Union["WarningsChecker", Any]:
4683
r"""Assert that code raises a particular class of warning.
4784
4885
Specifically, the parameter ``expected_warning`` can be a warning class or
@@ -101,81 +138,107 @@ class WarningsRecorder(warnings.catch_warnings):
101138
def __init__(self):
102139
super().__init__(record=True)
103140
self._entered = False
104-
self._list = []
141+
self._list = [] # type: List[warnings._Record]
105142

106143
@property
107-
def list(self):
144+
def list(self) -> List["warnings._Record"]:
108145
"""The list of recorded warnings."""
109146
return self._list
110147

111-
def __getitem__(self, i):
148+
def __getitem__(self, i: int) -> "warnings._Record":
112149
"""Get a recorded warning by index."""
113150
return self._list[i]
114151

115-
def __iter__(self):
152+
def __iter__(self) -> Iterator["warnings._Record"]:
116153
"""Iterate through the recorded warnings."""
117154
return iter(self._list)
118155

119-
def __len__(self):
156+
def __len__(self) -> int:
120157
"""The number of recorded warnings."""
121158
return len(self._list)
122159

123-
def pop(self, cls=Warning):
160+
def pop(self, cls: "Type[Warning]" = Warning) -> "warnings._Record":
124161
"""Pop the first recorded warning, raise exception if not exists."""
125162
for i, w in enumerate(self._list):
126163
if issubclass(w.category, cls):
127164
return self._list.pop(i)
128165
__tracebackhide__ = True
129166
raise AssertionError("%r not found in warning list" % cls)
130167

131-
def clear(self):
168+
def clear(self) -> None:
132169
"""Clear the list of recorded warnings."""
133170
self._list[:] = []
134171

135-
def __enter__(self):
172+
# Type ignored because it doesn't exactly warnings.catch_warnings.__enter__
173+
# -- it returns a List but we only emulate one.
174+
def __enter__(self) -> "WarningsRecorder": # type: ignore
136175
if self._entered:
137176
__tracebackhide__ = True
138177
raise RuntimeError("Cannot enter %r twice" % self)
139-
self._list = super().__enter__()
178+
_list = super().__enter__()
179+
# record=True means it's None.
180+
assert _list is not None
181+
self._list = _list
140182
warnings.simplefilter("always")
141183
return self
142184

143-
def __exit__(self, *exc_info):
185+
def __exit__(
186+
self,
187+
exc_type: Optional["Type[BaseException]"],
188+
exc_val: Optional[BaseException],
189+
exc_tb: Optional[TracebackType],
190+
) -> bool:
144191
if not self._entered:
145192
__tracebackhide__ = True
146193
raise RuntimeError("Cannot exit %r without entering first" % self)
147194

148-
super().__exit__(*exc_info)
195+
super().__exit__(exc_type, exc_val, exc_tb)
149196

150197
# Built-in catch_warnings does not reset entered state so we do it
151198
# manually here for this context manager to become reusable.
152199
self._entered = False
153200

201+
return False
202+
154203

155204
class WarningsChecker(WarningsRecorder):
156-
def __init__(self, expected_warning=None, match_expr=None):
205+
def __init__(
206+
self,
207+
expected_warning: Optional[
208+
Union["Type[Warning]", Tuple["Type[Warning]", ...]]
209+
] = None,
210+
match_expr: Optional[Union[str, Pattern]] = None,
211+
) -> None:
157212
super().__init__()
158213

159214
msg = "exceptions must be derived from Warning, not %s"
160-
if isinstance(expected_warning, tuple):
215+
if expected_warning is None:
216+
expected_warning_tup = None
217+
elif isinstance(expected_warning, tuple):
161218
for exc in expected_warning:
162-
if not inspect.isclass(exc):
219+
if not issubclass(exc, Warning):
163220
raise TypeError(msg % type(exc))
164-
elif inspect.isclass(expected_warning):
165-
expected_warning = (expected_warning,)
166-
elif expected_warning is not None:
221+
expected_warning_tup = expected_warning
222+
elif issubclass(expected_warning, Warning):
223+
expected_warning_tup = (expected_warning,)
224+
else:
167225
raise TypeError(msg % type(expected_warning))
168226

169-
self.expected_warning = expected_warning
227+
self.expected_warning = expected_warning_tup
170228
self.match_expr = match_expr
171229

172-
def __exit__(self, *exc_info):
173-
super().__exit__(*exc_info)
230+
def __exit__(
231+
self,
232+
exc_type: Optional["Type[BaseException]"],
233+
exc_val: Optional[BaseException],
234+
exc_tb: Optional[TracebackType],
235+
) -> bool:
236+
super().__exit__(exc_type, exc_val, exc_tb)
174237

175238
__tracebackhide__ = True
176239

177240
# only check if we're not currently handling an exception
178-
if all(a is None for a in exc_info):
241+
if exc_type is None and exc_val is None and exc_tb is None:
179242
if self.expected_warning is not None:
180243
if not any(issubclass(r.category, self.expected_warning) for r in self):
181244
__tracebackhide__ = True
@@ -200,3 +263,4 @@ def __exit__(self, *exc_info):
200263
[each.message for each in self],
201264
)
202265
)
266+
return False

0 commit comments

Comments
 (0)