1
1
""" recording warnings during test function execution. """
2
- import inspect
3
2
import re
4
3
import warnings
4
+ from types import TracebackType
5
+ from typing import Any
6
+ from typing import Callable
7
+ from typing import Iterator
8
+ from typing import List
9
+ from typing import Optional
10
+ from typing import overload
11
+ from typing import Pattern
12
+ from typing import Tuple
13
+ from typing import Union
5
14
6
15
from _pytest .fixtures import yield_fixture
7
16
from _pytest .outcomes import fail
8
17
18
+ if False : # TYPE_CHECKING
19
+ from typing import Type
20
+
9
21
10
22
@yield_fixture
11
23
def recwarn ():
@@ -42,7 +54,32 @@ def deprecated_call(func=None, *args, **kwargs):
42
54
return warns ((DeprecationWarning , PendingDeprecationWarning ), * args , ** kwargs )
43
55
44
56
45
- def warns (expected_warning , * args , match = None , ** kwargs ):
57
+ @overload
58
+ def warns (
59
+ expected_warning : Union ["Type[Warning]" , Tuple ["Type[Warning]" , ...]],
60
+ * ,
61
+ match : Optional [Union [str , Pattern ]] = ...
62
+ ) -> "WarningsChecker" :
63
+ ... # pragma: no cover
64
+
65
+
66
+ @overload
67
+ def warns (
68
+ expected_warning : Union ["Type[Warning]" , Tuple ["Type[Warning]" , ...]],
69
+ func : Callable ,
70
+ * args : Any ,
71
+ match : Optional [Union [str , Pattern ]] = ...,
72
+ ** kwargs : Any
73
+ ) -> Union [Any ]:
74
+ ... # pragma: no cover
75
+
76
+
77
+ def warns (
78
+ expected_warning : Union ["Type[Warning]" , Tuple ["Type[Warning]" , ...]],
79
+ * args : Any ,
80
+ match : Optional [Union [str , Pattern ]] = None ,
81
+ ** kwargs : Any
82
+ ) -> Union ["WarningsChecker" , Any ]:
46
83
r"""Assert that code raises a particular class of warning.
47
84
48
85
Specifically, the parameter ``expected_warning`` can be a warning class or
@@ -101,81 +138,107 @@ class WarningsRecorder(warnings.catch_warnings):
101
138
def __init__ (self ):
102
139
super ().__init__ (record = True )
103
140
self ._entered = False
104
- self ._list = []
141
+ self ._list = [] # type: List[warnings._Record]
105
142
106
143
@property
107
- def list (self ):
144
+ def list (self ) -> List [ "warnings._Record" ] :
108
145
"""The list of recorded warnings."""
109
146
return self ._list
110
147
111
- def __getitem__ (self , i ) :
148
+ def __getitem__ (self , i : int ) -> "warnings._Record" :
112
149
"""Get a recorded warning by index."""
113
150
return self ._list [i ]
114
151
115
- def __iter__ (self ):
152
+ def __iter__ (self ) -> Iterator [ "warnings._Record" ] :
116
153
"""Iterate through the recorded warnings."""
117
154
return iter (self ._list )
118
155
119
- def __len__ (self ):
156
+ def __len__ (self ) -> int :
120
157
"""The number of recorded warnings."""
121
158
return len (self ._list )
122
159
123
- def pop (self , cls = Warning ):
160
+ def pop (self , cls : "Type[Warning]" = Warning ) -> "warnings._Record" :
124
161
"""Pop the first recorded warning, raise exception if not exists."""
125
162
for i , w in enumerate (self ._list ):
126
163
if issubclass (w .category , cls ):
127
164
return self ._list .pop (i )
128
165
__tracebackhide__ = True
129
166
raise AssertionError ("%r not found in warning list" % cls )
130
167
131
- def clear (self ):
168
+ def clear (self ) -> None :
132
169
"""Clear the list of recorded warnings."""
133
170
self ._list [:] = []
134
171
135
- def __enter__ (self ):
172
+ # Type ignored because it doesn't exactly warnings.catch_warnings.__enter__
173
+ # -- it returns a List but we only emulate one.
174
+ def __enter__ (self ) -> "WarningsRecorder" : # type: ignore
136
175
if self ._entered :
137
176
__tracebackhide__ = True
138
177
raise RuntimeError ("Cannot enter %r twice" % self )
139
- self ._list = super ().__enter__ ()
178
+ _list = super ().__enter__ ()
179
+ # record=True means it's None.
180
+ assert _list is not None
181
+ self ._list = _list
140
182
warnings .simplefilter ("always" )
141
183
return self
142
184
143
- def __exit__ (self , * exc_info ):
185
+ def __exit__ (
186
+ self ,
187
+ exc_type : Optional ["Type[BaseException]" ],
188
+ exc_val : Optional [BaseException ],
189
+ exc_tb : Optional [TracebackType ],
190
+ ) -> bool :
144
191
if not self ._entered :
145
192
__tracebackhide__ = True
146
193
raise RuntimeError ("Cannot exit %r without entering first" % self )
147
194
148
- super ().__exit__ (* exc_info )
195
+ super ().__exit__ (exc_type , exc_val , exc_tb )
149
196
150
197
# Built-in catch_warnings does not reset entered state so we do it
151
198
# manually here for this context manager to become reusable.
152
199
self ._entered = False
153
200
201
+ return False
202
+
154
203
155
204
class WarningsChecker (WarningsRecorder ):
156
- def __init__ (self , expected_warning = None , match_expr = None ):
205
+ def __init__ (
206
+ self ,
207
+ expected_warning : Optional [
208
+ Union ["Type[Warning]" , Tuple ["Type[Warning]" , ...]]
209
+ ] = None ,
210
+ match_expr : Optional [Union [str , Pattern ]] = None ,
211
+ ) -> None :
157
212
super ().__init__ ()
158
213
159
214
msg = "exceptions must be derived from Warning, not %s"
160
- if isinstance (expected_warning , tuple ):
215
+ if expected_warning is None :
216
+ expected_warning_tup = None
217
+ elif isinstance (expected_warning , tuple ):
161
218
for exc in expected_warning :
162
- if not inspect . isclass (exc ):
219
+ if not issubclass (exc , Warning ):
163
220
raise TypeError (msg % type (exc ))
164
- elif inspect .isclass (expected_warning ):
165
- expected_warning = (expected_warning ,)
166
- elif expected_warning is not None :
221
+ expected_warning_tup = expected_warning
222
+ elif issubclass (expected_warning , Warning ):
223
+ expected_warning_tup = (expected_warning ,)
224
+ else :
167
225
raise TypeError (msg % type (expected_warning ))
168
226
169
- self .expected_warning = expected_warning
227
+ self .expected_warning = expected_warning_tup
170
228
self .match_expr = match_expr
171
229
172
- def __exit__ (self , * exc_info ):
173
- super ().__exit__ (* exc_info )
230
+ def __exit__ (
231
+ self ,
232
+ exc_type : Optional ["Type[BaseException]" ],
233
+ exc_val : Optional [BaseException ],
234
+ exc_tb : Optional [TracebackType ],
235
+ ) -> bool :
236
+ super ().__exit__ (exc_type , exc_val , exc_tb )
174
237
175
238
__tracebackhide__ = True
176
239
177
240
# only check if we're not currently handling an exception
178
- if all ( a is None for a in exc_info ) :
241
+ if exc_type is None and exc_val is None and exc_tb is None :
179
242
if self .expected_warning is not None :
180
243
if not any (issubclass (r .category , self .expected_warning ) for r in self ):
181
244
__tracebackhide__ = True
@@ -200,3 +263,4 @@ def __exit__(self, *exc_info):
200
263
[each .message for each in self ],
201
264
)
202
265
)
266
+ return False
0 commit comments