Skip to content

Commit c2f7624

Browse files
authored
Merge pull request #5673 from bluetech/type-annotations-3
1/X Fix check_untyped_defs = True mypy errors
2 parents f05ca74 + 7259c45 commit c2f7624

File tree

13 files changed

+196
-105
lines changed

13 files changed

+196
-105
lines changed

src/_pytest/_code/code.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,15 @@
55
from inspect import CO_VARARGS
66
from inspect import CO_VARKEYWORDS
77
from traceback import format_exception_only
8+
from types import CodeType
89
from types import TracebackType
10+
from typing import Any
11+
from typing import Dict
912
from typing import Generic
13+
from typing import List
1014
from typing import Optional
1115
from typing import Pattern
16+
from typing import Set
1217
from typing import Tuple
1318
from typing import TypeVar
1419
from typing import Union
@@ -29,7 +34,7 @@
2934
class Code:
3035
""" wrapper around Python code objects """
3136

32-
def __init__(self, rawcode):
37+
def __init__(self, rawcode) -> None:
3338
if not hasattr(rawcode, "co_filename"):
3439
rawcode = getrawcode(rawcode)
3540
try:
@@ -38,7 +43,7 @@ def __init__(self, rawcode):
3843
self.name = rawcode.co_name
3944
except AttributeError:
4045
raise TypeError("not a code object: {!r}".format(rawcode))
41-
self.raw = rawcode
46+
self.raw = rawcode # type: CodeType
4247

4348
def __eq__(self, other):
4449
return self.raw == other.raw
@@ -351,7 +356,7 @@ def recursionindex(self):
351356
""" return the index of the frame/TracebackEntry where recursion
352357
originates if appropriate, None if no recursion occurred
353358
"""
354-
cache = {}
359+
cache = {} # type: Dict[Tuple[Any, int, int], List[Dict[str, Any]]]
355360
for i, entry in enumerate(self):
356361
# id for the code.raw is needed to work around
357362
# the strange metaprogramming in the decorator lib from pypi
@@ -650,7 +655,7 @@ def repr_args(self, entry):
650655
args.append((argname, saferepr(argvalue)))
651656
return ReprFuncArgs(args)
652657

653-
def get_source(self, source, line_index=-1, excinfo=None, short=False):
658+
def get_source(self, source, line_index=-1, excinfo=None, short=False) -> List[str]:
654659
""" return formatted and marked up source lines. """
655660
import _pytest._code
656661

@@ -722,7 +727,7 @@ def repr_traceback_entry(self, entry, excinfo=None):
722727
else:
723728
line_index = entry.lineno - entry.getfirstlinesource()
724729

725-
lines = []
730+
lines = [] # type: List[str]
726731
style = entry._repr_style
727732
if style is None:
728733
style = self.style
@@ -799,7 +804,7 @@ def _truncate_recursive_traceback(self, traceback):
799804
exc_msg=str(e),
800805
max_frames=max_frames,
801806
total=len(traceback),
802-
)
807+
) # type: Optional[str]
803808
traceback = traceback[:max_frames] + traceback[-max_frames:]
804809
else:
805810
if recursionindex is not None:
@@ -812,10 +817,12 @@ def _truncate_recursive_traceback(self, traceback):
812817

813818
def repr_excinfo(self, excinfo):
814819

815-
repr_chain = []
820+
repr_chain = (
821+
[]
822+
) # type: List[Tuple[ReprTraceback, Optional[ReprFileLocation], Optional[str]]]
816823
e = excinfo.value
817824
descr = None
818-
seen = set()
825+
seen = set() # type: Set[int]
819826
while e is not None and id(e) not in seen:
820827
seen.add(id(e))
821828
if excinfo:
@@ -868,8 +875,8 @@ def __repr__(self):
868875

869876

870877
class ExceptionRepr(TerminalRepr):
871-
def __init__(self):
872-
self.sections = []
878+
def __init__(self) -> None:
879+
self.sections = [] # type: List[Tuple[str, str, str]]
873880

874881
def addsection(self, name, content, sep="-"):
875882
self.sections.append((name, content, sep))

src/_pytest/_code/source.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import warnings
88
from ast import PyCF_ONLY_AST as _AST_FLAG
99
from bisect import bisect_right
10+
from typing import List
1011

1112
import py
1213

@@ -19,11 +20,11 @@ class Source:
1920
_compilecounter = 0
2021

2122
def __init__(self, *parts, **kwargs):
22-
self.lines = lines = []
23+
self.lines = lines = [] # type: List[str]
2324
de = kwargs.get("deindent", True)
2425
for part in parts:
2526
if not part:
26-
partlines = []
27+
partlines = [] # type: List[str]
2728
elif isinstance(part, Source):
2829
partlines = part.lines
2930
elif isinstance(part, (tuple, list)):
@@ -157,8 +158,7 @@ def compile(
157158
source = "\n".join(self.lines) + "\n"
158159
try:
159160
co = compile(source, filename, mode, flag)
160-
except SyntaxError:
161-
ex = sys.exc_info()[1]
161+
except SyntaxError as ex:
162162
# re-represent syntax errors from parsing python strings
163163
msglines = self.lines[: ex.lineno]
164164
if ex.offset:
@@ -173,7 +173,8 @@ def compile(
173173
if flag & _AST_FLAG:
174174
return co
175175
lines = [(x + "\n") for x in self.lines]
176-
linecache.cache[filename] = (1, None, lines, filename)
176+
# Type ignored because linecache.cache is private.
177+
linecache.cache[filename] = (1, None, lines, filename) # type: ignore
177178
return co
178179

179180

@@ -282,7 +283,7 @@ def get_statement_startend2(lineno, node):
282283
return start, end
283284

284285

285-
def getstatementrange_ast(lineno, source, assertion=False, astnode=None):
286+
def getstatementrange_ast(lineno, source: Source, assertion=False, astnode=None):
286287
if astnode is None:
287288
content = str(source)
288289
# See #4260:

src/_pytest/assertion/__init__.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
support for presenting detailed information in failing assertions.
33
"""
44
import sys
5+
from typing import Optional
56

67
from _pytest.assertion import rewrite
78
from _pytest.assertion import truncate
@@ -52,7 +53,9 @@ def register_assert_rewrite(*names):
5253
importhook = hook
5354
break
5455
else:
55-
importhook = DummyRewriteHook()
56+
# TODO(typing): Add a protocol for mark_rewrite() and use it
57+
# for importhook and for PytestPluginManager.rewrite_hook.
58+
importhook = DummyRewriteHook() # type: ignore
5659
importhook.mark_rewrite(*names)
5760

5861

@@ -69,7 +72,7 @@ class AssertionState:
6972
def __init__(self, config, mode):
7073
self.mode = mode
7174
self.trace = config.trace.root.get("assertion")
72-
self.hook = None
75+
self.hook = None # type: Optional[rewrite.AssertionRewritingHook]
7376

7477

7578
def install_importhook(config):
@@ -108,6 +111,7 @@ def pytest_runtest_setup(item):
108111
"""
109112

110113
def callbinrepr(op, left, right):
114+
# type: (str, object, object) -> Optional[str]
111115
"""Call the pytest_assertrepr_compare hook and prepare the result
112116
113117
This uses the first result from the hook and then ensures the
@@ -133,12 +137,13 @@ def callbinrepr(op, left, right):
133137
if item.config.getvalue("assertmode") == "rewrite":
134138
res = res.replace("%", "%%")
135139
return res
140+
return None
136141

137142
util._reprcompare = callbinrepr
138143

139144
if item.ihook.pytest_assertion_pass.get_hookimpls():
140145

141-
def call_assertion_pass_hook(lineno, expl, orig):
146+
def call_assertion_pass_hook(lineno, orig, expl):
142147
item.ihook.pytest_assertion_pass(
143148
item=item, lineno=lineno, orig=orig, expl=expl
144149
)

src/_pytest/assertion/rewrite.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import ast
33
import errno
44
import functools
5+
import importlib.abc
56
import importlib.machinery
67
import importlib.util
78
import io
@@ -16,6 +17,7 @@
1617
from typing import List
1718
from typing import Optional
1819
from typing import Set
20+
from typing import Tuple
1921

2022
import atomicwrites
2123

@@ -37,7 +39,7 @@
3739
AST_NONE = ast.NameConstant(None)
3840

3941

40-
class AssertionRewritingHook:
42+
class AssertionRewritingHook(importlib.abc.MetaPathFinder):
4143
"""PEP302/PEP451 import hook which rewrites asserts."""
4244

4345
def __init__(self, config):
@@ -47,13 +49,13 @@ def __init__(self, config):
4749
except ValueError:
4850
self.fnpats = ["test_*.py", "*_test.py"]
4951
self.session = None
50-
self._rewritten_names = set()
51-
self._must_rewrite = set()
52+
self._rewritten_names = set() # type: Set[str]
53+
self._must_rewrite = set() # type: Set[str]
5254
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
5355
# which might result in infinite recursion (#3506)
5456
self._writing_pyc = False
5557
self._basenames_to_check_rewrite = {"conftest"}
56-
self._marked_for_rewrite_cache = {}
58+
self._marked_for_rewrite_cache = {} # type: Dict[str, bool]
5759
self._session_paths_checked = False
5860

5961
def set_session(self, session):
@@ -202,7 +204,7 @@ def _should_rewrite(self, name, fn, state):
202204

203205
return self._is_marked_for_rewrite(name, state)
204206

205-
def _is_marked_for_rewrite(self, name, state):
207+
def _is_marked_for_rewrite(self, name: str, state):
206208
try:
207209
return self._marked_for_rewrite_cache[name]
208210
except KeyError:
@@ -217,7 +219,7 @@ def _is_marked_for_rewrite(self, name, state):
217219
self._marked_for_rewrite_cache[name] = False
218220
return False
219221

220-
def mark_rewrite(self, *names):
222+
def mark_rewrite(self, *names: str) -> None:
221223
"""Mark import names as needing to be rewritten.
222224
223225
The named module or package as well as any nested modules will
@@ -384,6 +386,7 @@ def _format_boolop(explanations, is_or):
384386

385387

386388
def _call_reprcompare(ops, results, expls, each_obj):
389+
# type: (Tuple[str, ...], Tuple[bool, ...], Tuple[str, ...], Tuple[object, ...]) -> str
387390
for i, res, expl in zip(range(len(ops)), results, expls):
388391
try:
389392
done = not res
@@ -399,11 +402,13 @@ def _call_reprcompare(ops, results, expls, each_obj):
399402

400403

401404
def _call_assertion_pass(lineno, orig, expl):
405+
# type: (int, str, str) -> None
402406
if util._assertion_pass is not None:
403-
util._assertion_pass(lineno=lineno, orig=orig, expl=expl)
407+
util._assertion_pass(lineno, orig, expl)
404408

405409

406410
def _check_if_assertion_pass_impl():
411+
# type: () -> bool
407412
"""Checks if any plugins implement the pytest_assertion_pass hook
408413
in order not to generate explanation unecessarily (might be expensive)"""
409414
return True if util._assertion_pass else False
@@ -577,7 +582,7 @@ def __init__(self, module_path, config, source):
577582
def _assert_expr_to_lineno(self):
578583
return _get_assertion_exprs(self.source)
579584

580-
def run(self, mod):
585+
def run(self, mod: ast.Module) -> None:
581586
"""Find all assert statements in *mod* and rewrite them."""
582587
if not mod.body:
583588
# Nothing to do.
@@ -619,12 +624,12 @@ def run(self, mod):
619624
]
620625
mod.body[pos:pos] = imports
621626
# Collect asserts.
622-
nodes = [mod]
627+
nodes = [mod] # type: List[ast.AST]
623628
while nodes:
624629
node = nodes.pop()
625630
for name, field in ast.iter_fields(node):
626631
if isinstance(field, list):
627-
new = []
632+
new = [] # type: List
628633
for i, child in enumerate(field):
629634
if isinstance(child, ast.Assert):
630635
# Transform assert.
@@ -698,7 +703,7 @@ def push_format_context(self):
698703
.explanation_param().
699704
700705
"""
701-
self.explanation_specifiers = {}
706+
self.explanation_specifiers = {} # type: Dict[str, ast.expr]
702707
self.stack.append(self.explanation_specifiers)
703708

704709
def pop_format_context(self, expl_expr):
@@ -741,7 +746,8 @@ def visit_Assert(self, assert_):
741746
from _pytest.warning_types import PytestAssertRewriteWarning
742747
import warnings
743748

744-
warnings.warn_explicit(
749+
# Ignore type: typeshed bug https://github.com/python/typeshed/pull/3121
750+
warnings.warn_explicit( # type: ignore
745751
PytestAssertRewriteWarning(
746752
"assertion is always true, perhaps remove parentheses?"
747753
),
@@ -750,15 +756,15 @@ def visit_Assert(self, assert_):
750756
lineno=assert_.lineno,
751757
)
752758

753-
self.statements = []
754-
self.variables = []
759+
self.statements = [] # type: List[ast.stmt]
760+
self.variables = [] # type: List[str]
755761
self.variable_counter = itertools.count()
756762

757763
if self.enable_assertion_pass_hook:
758-
self.format_variables = []
764+
self.format_variables = [] # type: List[str]
759765

760-
self.stack = []
761-
self.expl_stmts = []
766+
self.stack = [] # type: List[Dict[str, ast.expr]]
767+
self.expl_stmts = [] # type: List[ast.stmt]
762768
self.push_format_context()
763769
# Rewrite assert into a bunch of statements.
764770
top_condition, explanation = self.visit(assert_.test)
@@ -896,7 +902,7 @@ def visit_BoolOp(self, boolop):
896902
# Process each operand, short-circuiting if needed.
897903
for i, v in enumerate(boolop.values):
898904
if i:
899-
fail_inner = []
905+
fail_inner = [] # type: List[ast.stmt]
900906
# cond is set in a prior loop iteration below
901907
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
902908
self.expl_stmts = fail_inner
@@ -907,10 +913,10 @@ def visit_BoolOp(self, boolop):
907913
call = ast.Call(app, [expl_format], [])
908914
self.expl_stmts.append(ast.Expr(call))
909915
if i < levels:
910-
cond = res
916+
cond = res # type: ast.expr
911917
if is_or:
912918
cond = ast.UnaryOp(ast.Not(), cond)
913-
inner = []
919+
inner = [] # type: List[ast.stmt]
914920
self.statements.append(ast.If(cond, inner, []))
915921
self.statements = body = inner
916922
self.statements = save
@@ -976,7 +982,7 @@ def visit_Attribute(self, attr):
976982
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
977983
return res, expl
978984

979-
def visit_Compare(self, comp):
985+
def visit_Compare(self, comp: ast.Compare):
980986
self.push_format_context()
981987
left_res, left_expl = self.visit(comp.left)
982988
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
@@ -1009,7 +1015,7 @@ def visit_Compare(self, comp):
10091015
ast.Tuple(results, ast.Load()),
10101016
)
10111017
if len(comp.ops) > 1:
1012-
res = ast.BoolOp(ast.And(), load_names)
1018+
res = ast.BoolOp(ast.And(), load_names) # type: ast.expr
10131019
else:
10141020
res = load_names[0]
10151021
return res, self.explanation_param(self.pop_format_context(expl_call))

0 commit comments

Comments
 (0)