17
17
from typing import List
18
18
from typing import Optional
19
19
from typing import Set
20
+ from typing import Tuple
20
21
21
22
import atomicwrites
22
23
@@ -48,13 +49,13 @@ def __init__(self, config):
48
49
except ValueError :
49
50
self .fnpats = ["test_*.py" , "*_test.py" ]
50
51
self .session = None
51
- self ._rewritten_names = set ()
52
- self ._must_rewrite = set ()
52
+ self ._rewritten_names = set () # type: Set[str]
53
+ self ._must_rewrite = set () # type: Set[str]
53
54
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
54
55
# which might result in infinite recursion (#3506)
55
56
self ._writing_pyc = False
56
57
self ._basenames_to_check_rewrite = {"conftest" }
57
- self ._marked_for_rewrite_cache = {}
58
+ self ._marked_for_rewrite_cache = {} # type: Dict[str, bool]
58
59
self ._session_paths_checked = False
59
60
60
61
def set_session (self , session ):
@@ -203,7 +204,7 @@ def _should_rewrite(self, name, fn, state):
203
204
204
205
return self ._is_marked_for_rewrite (name , state )
205
206
206
- def _is_marked_for_rewrite (self , name , state ):
207
+ def _is_marked_for_rewrite (self , name : str , state ):
207
208
try :
208
209
return self ._marked_for_rewrite_cache [name ]
209
210
except KeyError :
@@ -218,7 +219,7 @@ def _is_marked_for_rewrite(self, name, state):
218
219
self ._marked_for_rewrite_cache [name ] = False
219
220
return False
220
221
221
- def mark_rewrite (self , * names ) :
222
+ def mark_rewrite (self , * names : str ) -> None :
222
223
"""Mark import names as needing to be rewritten.
223
224
224
225
The named module or package as well as any nested modules will
@@ -385,6 +386,7 @@ def _format_boolop(explanations, is_or):
385
386
386
387
387
388
def _call_reprcompare (ops , results , expls , each_obj ):
389
+ # type: (Tuple[str, ...], Tuple[bool, ...], Tuple[str, ...], Tuple[object, ...]) -> str
388
390
for i , res , expl in zip (range (len (ops )), results , expls ):
389
391
try :
390
392
done = not res
@@ -400,11 +402,13 @@ def _call_reprcompare(ops, results, expls, each_obj):
400
402
401
403
402
404
def _call_assertion_pass (lineno , orig , expl ):
405
+ # type: (int, str, str) -> None
403
406
if util ._assertion_pass is not None :
404
- util ._assertion_pass (lineno = lineno , orig = orig , expl = expl )
407
+ util ._assertion_pass (lineno , orig , expl )
405
408
406
409
407
410
def _check_if_assertion_pass_impl ():
411
+ # type: () -> bool
408
412
"""Checks if any plugins implement the pytest_assertion_pass hook
409
413
in order not to generate explanation unecessarily (might be expensive)"""
410
414
return True if util ._assertion_pass else False
@@ -578,7 +582,7 @@ def __init__(self, module_path, config, source):
578
582
def _assert_expr_to_lineno (self ):
579
583
return _get_assertion_exprs (self .source )
580
584
581
- def run (self , mod ) :
585
+ def run (self , mod : ast . Module ) -> None :
582
586
"""Find all assert statements in *mod* and rewrite them."""
583
587
if not mod .body :
584
588
# Nothing to do.
@@ -620,12 +624,12 @@ def run(self, mod):
620
624
]
621
625
mod .body [pos :pos ] = imports
622
626
# Collect asserts.
623
- nodes = [mod ]
627
+ nodes = [mod ] # type: List[ast.AST]
624
628
while nodes :
625
629
node = nodes .pop ()
626
630
for name , field in ast .iter_fields (node ):
627
631
if isinstance (field , list ):
628
- new = []
632
+ new = [] # type: List
629
633
for i , child in enumerate (field ):
630
634
if isinstance (child , ast .Assert ):
631
635
# Transform assert.
@@ -699,7 +703,7 @@ def push_format_context(self):
699
703
.explanation_param().
700
704
701
705
"""
702
- self .explanation_specifiers = {}
706
+ self .explanation_specifiers = {} # type: Dict[str, ast.expr]
703
707
self .stack .append (self .explanation_specifiers )
704
708
705
709
def pop_format_context (self , expl_expr ):
@@ -742,7 +746,8 @@ def visit_Assert(self, assert_):
742
746
from _pytest .warning_types import PytestAssertRewriteWarning
743
747
import warnings
744
748
745
- warnings .warn_explicit (
749
+ # Ignore type: typeshed bug https://github.com/python/typeshed/pull/3121
750
+ warnings .warn_explicit ( # type: ignore
746
751
PytestAssertRewriteWarning (
747
752
"assertion is always true, perhaps remove parentheses?"
748
753
),
@@ -751,15 +756,15 @@ def visit_Assert(self, assert_):
751
756
lineno = assert_ .lineno ,
752
757
)
753
758
754
- self .statements = []
755
- self .variables = []
759
+ self .statements = [] # type: List[ast.stmt]
760
+ self .variables = [] # type: List[str]
756
761
self .variable_counter = itertools .count ()
757
762
758
763
if self .enable_assertion_pass_hook :
759
- self .format_variables = []
764
+ self .format_variables = [] # type: List[str]
760
765
761
- self .stack = []
762
- self .expl_stmts = []
766
+ self .stack = [] # type: List[Dict[str, ast.expr]]
767
+ self .expl_stmts = [] # type: List[ast.stmt]
763
768
self .push_format_context ()
764
769
# Rewrite assert into a bunch of statements.
765
770
top_condition , explanation = self .visit (assert_ .test )
@@ -897,7 +902,7 @@ def visit_BoolOp(self, boolop):
897
902
# Process each operand, short-circuiting if needed.
898
903
for i , v in enumerate (boolop .values ):
899
904
if i :
900
- fail_inner = []
905
+ fail_inner = [] # type: List[ast.stmt]
901
906
# cond is set in a prior loop iteration below
902
907
self .expl_stmts .append (ast .If (cond , fail_inner , [])) # noqa
903
908
self .expl_stmts = fail_inner
@@ -908,10 +913,10 @@ def visit_BoolOp(self, boolop):
908
913
call = ast .Call (app , [expl_format ], [])
909
914
self .expl_stmts .append (ast .Expr (call ))
910
915
if i < levels :
911
- cond = res
916
+ cond = res # type: ast.expr
912
917
if is_or :
913
918
cond = ast .UnaryOp (ast .Not (), cond )
914
- inner = []
919
+ inner = [] # type: List[ast.stmt]
915
920
self .statements .append (ast .If (cond , inner , []))
916
921
self .statements = body = inner
917
922
self .statements = save
@@ -977,7 +982,7 @@ def visit_Attribute(self, attr):
977
982
expl = pat % (res_expl , res_expl , value_expl , attr .attr )
978
983
return res , expl
979
984
980
- def visit_Compare (self , comp ):
985
+ def visit_Compare (self , comp : ast . Compare ):
981
986
self .push_format_context ()
982
987
left_res , left_expl = self .visit (comp .left )
983
988
if isinstance (comp .left , (ast .Compare , ast .BoolOp )):
@@ -1010,7 +1015,7 @@ def visit_Compare(self, comp):
1010
1015
ast .Tuple (results , ast .Load ()),
1011
1016
)
1012
1017
if len (comp .ops ) > 1 :
1013
- res = ast .BoolOp (ast .And (), load_names )
1018
+ res = ast .BoolOp (ast .And (), load_names ) # type: ast.expr
1014
1019
else :
1015
1020
res = load_names [0 ]
1016
1021
return res , self .explanation_param (self .pop_format_context (expl_call ))
0 commit comments