Skip to content

Commit 562d481

Browse files
committed
Add type annotations to _pytest.compat
1 parent a649f15 commit 562d481

File tree

1 file changed

+29
-18
lines changed

1 file changed

+29
-18
lines changed

src/_pytest/compat.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
from contextlib import contextmanager
1111
from inspect import Parameter
1212
from inspect import signature
13+
from typing import Any
1314
from typing import Callable
1415
from typing import Generic
1516
from typing import Optional
1617
from typing import overload
18+
from typing import Tuple
1719
from typing import TypeVar
20+
from typing import Union
1821

1922
import attr
2023
import py
@@ -46,7 +49,7 @@
4649
import importlib_metadata # noqa: F401
4750

4851

49-
def _format_args(func):
52+
def _format_args(func: Callable[..., Any]) -> str:
5053
return str(signature(func))
5154

5255

@@ -67,12 +70,12 @@ def fspath(p):
6770
fspath = os.fspath
6871

6972

70-
def is_generator(func):
73+
def is_generator(func: object) -> bool:
7174
genfunc = inspect.isgeneratorfunction(func)
7275
return genfunc and not iscoroutinefunction(func)
7376

7477

75-
def iscoroutinefunction(func):
78+
def iscoroutinefunction(func: object) -> bool:
7679
"""
7780
Return True if func is a coroutine function (a function defined with async
7881
def syntax, and doesn't contain yield), or a function decorated with
@@ -85,7 +88,7 @@ def syntax, and doesn't contain yield), or a function decorated with
8588
return inspect.iscoroutinefunction(func) or getattr(func, "_is_coroutine", False)
8689

8790

88-
def getlocation(function, curdir=None):
91+
def getlocation(function, curdir=None) -> str:
8992
function = get_real_func(function)
9093
fn = py.path.local(inspect.getfile(function))
9194
lineno = function.__code__.co_firstlineno
@@ -94,7 +97,7 @@ def getlocation(function, curdir=None):
9497
return "%s:%d" % (fn, lineno + 1)
9598

9699

97-
def num_mock_patch_args(function):
100+
def num_mock_patch_args(function) -> int:
98101
""" return number of arguments used up by mock arguments (if any) """
99102
patchings = getattr(function, "patchings", None)
100103
if not patchings:
@@ -113,7 +116,13 @@ def num_mock_patch_args(function):
113116
)
114117

115118

116-
def getfuncargnames(function, *, name: str = "", is_method=False, cls=None):
119+
def getfuncargnames(
120+
function: Callable[..., Any],
121+
*,
122+
name: str = "",
123+
is_method: bool = False,
124+
cls: Optional[type] = None
125+
) -> Tuple[str, ...]:
117126
"""Returns the names of a function's mandatory arguments.
118127
119128
This should return the names of all function arguments that:
@@ -181,7 +190,7 @@ def nullcontext():
181190
from contextlib import nullcontext # noqa
182191

183192

184-
def get_default_arg_names(function):
193+
def get_default_arg_names(function: Callable[..., Any]) -> Tuple[str, ...]:
185194
# Note: this code intentionally mirrors the code at the beginning of getfuncargnames,
186195
# to get the arguments which were excluded from its result because they had default values
187196
return tuple(
@@ -200,18 +209,18 @@ def get_default_arg_names(function):
200209
)
201210

202211

203-
def _translate_non_printable(s):
212+
def _translate_non_printable(s: str) -> str:
204213
return s.translate(_non_printable_ascii_translate_table)
205214

206215

207216
STRING_TYPES = bytes, str
208217

209218

210-
def _bytes_to_ascii(val):
219+
def _bytes_to_ascii(val: bytes) -> str:
211220
return val.decode("ascii", "backslashreplace")
212221

213222

214-
def ascii_escaped(val):
223+
def ascii_escaped(val: Union[bytes, str]):
215224
"""If val is pure ascii, returns it as a str(). Otherwise, escapes
216225
bytes objects into a sequence of escaped bytes:
217226
@@ -308,7 +317,7 @@ def getimfunc(func):
308317
return func
309318

310319

311-
def safe_getattr(object, name, default):
320+
def safe_getattr(object: Any, name: str, default: Any) -> Any:
312321
""" Like getattr but return default upon any Exception or any OutcomeException.
313322
314323
Attribute access can potentially fail for 'evil' Python objects.
@@ -322,7 +331,7 @@ def safe_getattr(object, name, default):
322331
return default
323332

324333

325-
def safe_isclass(obj):
334+
def safe_isclass(obj: object) -> bool:
326335
"""Ignore any exception via isinstance on Python 3."""
327336
try:
328337
return inspect.isclass(obj)
@@ -343,21 +352,23 @@ def safe_isclass(obj):
343352
)
344353

345354

346-
def _setup_collect_fakemodule():
355+
def _setup_collect_fakemodule() -> None:
347356
from types import ModuleType
348357
import pytest
349358

350-
pytest.collect = ModuleType("pytest.collect")
351-
pytest.collect.__all__ = [] # used for setns
359+
# Types ignored because the module is created dynamically.
360+
pytest.collect = ModuleType("pytest.collect") # type: ignore
361+
pytest.collect.__all__ = [] # type: ignore # used for setns
352362
for attr_name in COLLECT_FAKEMODULE_ATTRIBUTES:
353-
setattr(pytest.collect, attr_name, getattr(pytest, attr_name))
363+
setattr(pytest.collect, attr_name, getattr(pytest, attr_name)) # type: ignore
354364

355365

356366
class CaptureIO(io.TextIOWrapper):
357-
def __init__(self):
367+
def __init__(self) -> None:
358368
super().__init__(io.BytesIO(), encoding="UTF-8", newline="", write_through=True)
359369

360-
def getvalue(self):
370+
def getvalue(self) -> str:
371+
assert isinstance(self.buffer, io.BytesIO)
361372
return self.buffer.getvalue().decode("UTF-8")
362373

363374

0 commit comments

Comments
 (0)