10
10
from contextlib import contextmanager
11
11
from inspect import Parameter
12
12
from inspect import signature
13
+ from typing import Any
13
14
from typing import Callable
14
15
from typing import Generic
15
16
from typing import Optional
16
17
from typing import overload
18
+ from typing import Tuple
17
19
from typing import TypeVar
20
+ from typing import Union
18
21
19
22
import attr
20
23
import py
46
49
import importlib_metadata # noqa: F401
47
50
48
51
49
- def _format_args (func ) :
52
+ def _format_args (func : Callable [..., Any ]) -> str :
50
53
return str (signature (func ))
51
54
52
55
@@ -67,12 +70,12 @@ def fspath(p):
67
70
fspath = os .fspath
68
71
69
72
70
- def is_generator (func ) :
73
+ def is_generator (func : object ) -> bool :
71
74
genfunc = inspect .isgeneratorfunction (func )
72
75
return genfunc and not iscoroutinefunction (func )
73
76
74
77
75
- def iscoroutinefunction (func ) :
78
+ def iscoroutinefunction (func : object ) -> bool :
76
79
"""
77
80
Return True if func is a coroutine function (a function defined with async
78
81
def syntax, and doesn't contain yield), or a function decorated with
@@ -85,7 +88,7 @@ def syntax, and doesn't contain yield), or a function decorated with
85
88
return inspect .iscoroutinefunction (func ) or getattr (func , "_is_coroutine" , False )
86
89
87
90
88
- def getlocation (function , curdir = None ):
91
+ def getlocation (function , curdir = None ) -> str :
89
92
function = get_real_func (function )
90
93
fn = py .path .local (inspect .getfile (function ))
91
94
lineno = function .__code__ .co_firstlineno
@@ -94,7 +97,7 @@ def getlocation(function, curdir=None):
94
97
return "%s:%d" % (fn , lineno + 1 )
95
98
96
99
97
- def num_mock_patch_args (function ):
100
+ def num_mock_patch_args (function ) -> int :
98
101
""" return number of arguments used up by mock arguments (if any) """
99
102
patchings = getattr (function , "patchings" , None )
100
103
if not patchings :
@@ -113,7 +116,13 @@ def num_mock_patch_args(function):
113
116
)
114
117
115
118
116
- def getfuncargnames (function , * , name : str = "" , is_method = False , cls = None ):
119
+ def getfuncargnames (
120
+ function : Callable [..., Any ],
121
+ * ,
122
+ name : str = "" ,
123
+ is_method : bool = False ,
124
+ cls : Optional [type ] = None
125
+ ) -> Tuple [str , ...]:
117
126
"""Returns the names of a function's mandatory arguments.
118
127
119
128
This should return the names of all function arguments that:
@@ -181,7 +190,7 @@ def nullcontext():
181
190
from contextlib import nullcontext # noqa
182
191
183
192
184
- def get_default_arg_names (function ) :
193
+ def get_default_arg_names (function : Callable [..., Any ]) -> Tuple [ str , ...] :
185
194
# Note: this code intentionally mirrors the code at the beginning of getfuncargnames,
186
195
# to get the arguments which were excluded from its result because they had default values
187
196
return tuple (
@@ -200,18 +209,18 @@ def get_default_arg_names(function):
200
209
)
201
210
202
211
203
- def _translate_non_printable (s ) :
212
+ def _translate_non_printable (s : str ) -> str :
204
213
return s .translate (_non_printable_ascii_translate_table )
205
214
206
215
207
216
STRING_TYPES = bytes , str
208
217
209
218
210
- def _bytes_to_ascii (val ) :
219
+ def _bytes_to_ascii (val : bytes ) -> str :
211
220
return val .decode ("ascii" , "backslashreplace" )
212
221
213
222
214
- def ascii_escaped (val ):
223
+ def ascii_escaped (val : Union [ bytes , str ] ):
215
224
"""If val is pure ascii, returns it as a str(). Otherwise, escapes
216
225
bytes objects into a sequence of escaped bytes:
217
226
@@ -308,7 +317,7 @@ def getimfunc(func):
308
317
return func
309
318
310
319
311
- def safe_getattr (object , name , default ) :
320
+ def safe_getattr (object : Any , name : str , default : Any ) -> Any :
312
321
""" Like getattr but return default upon any Exception or any OutcomeException.
313
322
314
323
Attribute access can potentially fail for 'evil' Python objects.
@@ -322,7 +331,7 @@ def safe_getattr(object, name, default):
322
331
return default
323
332
324
333
325
- def safe_isclass (obj ) :
334
+ def safe_isclass (obj : object ) -> bool :
326
335
"""Ignore any exception via isinstance on Python 3."""
327
336
try :
328
337
return inspect .isclass (obj )
@@ -343,21 +352,23 @@ def safe_isclass(obj):
343
352
)
344
353
345
354
346
- def _setup_collect_fakemodule ():
355
+ def _setup_collect_fakemodule () -> None :
347
356
from types import ModuleType
348
357
import pytest
349
358
350
- pytest .collect = ModuleType ("pytest.collect" )
351
- pytest .collect .__all__ = [] # used for setns
359
+ # Types ignored because the module is created dynamically.
360
+ pytest .collect = ModuleType ("pytest.collect" ) # type: ignore
361
+ pytest .collect .__all__ = [] # type: ignore # used for setns
352
362
for attr_name in COLLECT_FAKEMODULE_ATTRIBUTES :
353
- setattr (pytest .collect , attr_name , getattr (pytest , attr_name ))
363
+ setattr (pytest .collect , attr_name , getattr (pytest , attr_name )) # type: ignore
354
364
355
365
356
366
class CaptureIO (io .TextIOWrapper ):
357
- def __init__ (self ):
367
+ def __init__ (self ) -> None :
358
368
super ().__init__ (io .BytesIO (), encoding = "UTF-8" , newline = "" , write_through = True )
359
369
360
- def getvalue (self ):
370
+ def getvalue (self ) -> str :
371
+ assert isinstance (self .buffer , io .BytesIO )
361
372
return self .buffer .getvalue ().decode ("UTF-8" )
362
373
363
374
0 commit comments