10
10
from contextlib import contextmanager
11
11
from inspect import Parameter
12
12
from inspect import signature
13
+ from typing import Any
13
14
from typing import Callable
14
15
from typing import Generic
15
16
from typing import Optional
16
17
from typing import overload
18
+ from typing import Tuple
17
19
from typing import TypeVar
20
+ from typing import Union
18
21
19
22
import attr
20
23
import py
40
43
41
44
42
45
if sys .version_info >= (3 , 8 ):
43
- from importlib import metadata as importlib_metadata # noqa: F401
46
+ # Type ignored until next mypy release.
47
+ from importlib import metadata as importlib_metadata # type: ignore
44
48
else :
45
49
import importlib_metadata # noqa: F401
46
50
47
51
48
- def _format_args (func ) :
52
+ def _format_args (func : Callable [..., Any ]) -> str :
49
53
return str (signature (func ))
50
54
51
55
@@ -66,12 +70,12 @@ def fspath(p):
66
70
fspath = os .fspath
67
71
68
72
69
- def is_generator (func ) :
73
+ def is_generator (func : object ) -> bool :
70
74
genfunc = inspect .isgeneratorfunction (func )
71
75
return genfunc and not iscoroutinefunction (func )
72
76
73
77
74
- def iscoroutinefunction (func ) :
78
+ def iscoroutinefunction (func : object ) -> bool :
75
79
"""
76
80
Return True if func is a coroutine function (a function defined with async
77
81
def syntax, and doesn't contain yield), or a function decorated with
@@ -84,7 +88,7 @@ def syntax, and doesn't contain yield), or a function decorated with
84
88
return inspect .iscoroutinefunction (func ) or getattr (func , "_is_coroutine" , False )
85
89
86
90
87
- def getlocation (function , curdir = None ):
91
+ def getlocation (function , curdir = None ) -> str :
88
92
function = get_real_func (function )
89
93
fn = py .path .local (inspect .getfile (function ))
90
94
lineno = function .__code__ .co_firstlineno
@@ -93,7 +97,7 @@ def getlocation(function, curdir=None):
93
97
return "%s:%d" % (fn , lineno + 1 )
94
98
95
99
96
- def num_mock_patch_args (function ):
100
+ def num_mock_patch_args (function ) -> int :
97
101
""" return number of arguments used up by mock arguments (if any) """
98
102
patchings = getattr (function , "patchings" , None )
99
103
if not patchings :
@@ -112,7 +116,13 @@ def num_mock_patch_args(function):
112
116
)
113
117
114
118
115
- def getfuncargnames (function , * , name : str = "" , is_method = False , cls = None ):
119
+ def getfuncargnames (
120
+ function : Callable [..., Any ],
121
+ * ,
122
+ name : str = "" ,
123
+ is_method : bool = False ,
124
+ cls : Optional [type ] = None
125
+ ) -> Tuple [str , ...]:
116
126
"""Returns the names of a function's mandatory arguments.
117
127
118
128
This should return the names of all function arguments that:
@@ -180,7 +190,7 @@ def nullcontext():
180
190
from contextlib import nullcontext # noqa
181
191
182
192
183
- def get_default_arg_names (function ) :
193
+ def get_default_arg_names (function : Callable [..., Any ]) -> Tuple [ str , ...] :
184
194
# Note: this code intentionally mirrors the code at the beginning of getfuncargnames,
185
195
# to get the arguments which were excluded from its result because they had default values
186
196
return tuple (
@@ -199,18 +209,18 @@ def get_default_arg_names(function):
199
209
)
200
210
201
211
202
- def _translate_non_printable (s ) :
212
+ def _translate_non_printable (s : str ) -> str :
203
213
return s .translate (_non_printable_ascii_translate_table )
204
214
205
215
206
216
STRING_TYPES = bytes , str
207
217
208
218
209
- def _bytes_to_ascii (val ) :
219
+ def _bytes_to_ascii (val : bytes ) -> str :
210
220
return val .decode ("ascii" , "backslashreplace" )
211
221
212
222
213
- def ascii_escaped (val ):
223
+ def ascii_escaped (val : Union [ bytes , str ] ):
214
224
"""If val is pure ascii, returns it as a str(). Otherwise, escapes
215
225
bytes objects into a sequence of escaped bytes:
216
226
@@ -307,7 +317,7 @@ def getimfunc(func):
307
317
return func
308
318
309
319
310
- def safe_getattr (object , name , default ) :
320
+ def safe_getattr (object : Any , name : str , default : Any ) -> Any :
311
321
""" Like getattr but return default upon any Exception or any OutcomeException.
312
322
313
323
Attribute access can potentially fail for 'evil' Python objects.
@@ -321,7 +331,7 @@ def safe_getattr(object, name, default):
321
331
return default
322
332
323
333
324
- def safe_isclass (obj ) :
334
+ def safe_isclass (obj : object ) -> bool :
325
335
"""Ignore any exception via isinstance on Python 3."""
326
336
try :
327
337
return inspect .isclass (obj )
@@ -342,39 +352,26 @@ def safe_isclass(obj):
342
352
)
343
353
344
354
345
- def _setup_collect_fakemodule ():
355
+ def _setup_collect_fakemodule () -> None :
346
356
from types import ModuleType
347
357
import pytest
348
358
349
- pytest .collect = ModuleType ("pytest.collect" )
350
- pytest .collect .__all__ = [] # used for setns
359
+ # Types ignored because the module is created dynamically.
360
+ pytest .collect = ModuleType ("pytest.collect" ) # type: ignore
361
+ pytest .collect .__all__ = [] # type: ignore # used for setns
351
362
for attr_name in COLLECT_FAKEMODULE_ATTRIBUTES :
352
- setattr (pytest .collect , attr_name , getattr (pytest , attr_name ))
363
+ setattr (pytest .collect , attr_name , getattr (pytest , attr_name )) # type: ignore
353
364
354
365
355
366
class CaptureIO (io .TextIOWrapper ):
356
- def __init__ (self ):
367
+ def __init__ (self ) -> None :
357
368
super ().__init__ (io .BytesIO (), encoding = "UTF-8" , newline = "" , write_through = True )
358
369
359
- def getvalue (self ):
370
+ def getvalue (self ) -> str :
371
+ assert isinstance (self .buffer , io .BytesIO )
360
372
return self .buffer .getvalue ().decode ("UTF-8" )
361
373
362
374
363
- class FuncargnamesCompatAttr :
364
- """ helper class so that Metafunc, Function and FixtureRequest
365
- don't need to each define the "funcargnames" compatibility attribute.
366
- """
367
-
368
- @property
369
- def funcargnames (self ):
370
- """ alias attribute for ``fixturenames`` for pre-2.3 compatibility"""
371
- import warnings
372
- from _pytest .deprecated import FUNCARGNAMES
373
-
374
- warnings .warn (FUNCARGNAMES , stacklevel = 2 )
375
- return self .fixturenames
376
-
377
-
378
375
if sys .version_info < (3 , 5 , 2 ): # pragma: no cover
379
376
380
377
def overload (f ): # noqa: F811
@@ -407,7 +404,9 @@ def __get__(
407
404
raise NotImplementedError ()
408
405
409
406
@overload # noqa: F811
410
- def __get__ (self , instance : _S , owner : Optional ["Type[_S]" ] = ...) -> _T :
407
+ def __get__ ( # noqa: F811
408
+ self , instance : _S , owner : Optional ["Type[_S]" ] = ...
409
+ ) -> _T :
411
410
raise NotImplementedError ()
412
411
413
412
def __get__ (self , instance , owner = None ): # noqa: F811
0 commit comments