2
2
import ast
3
3
import errno
4
4
import functools
5
+ import importlib .abc
5
6
import importlib .machinery
6
7
import importlib .util
7
8
import io
16
17
from typing import List
17
18
from typing import Optional
18
19
from typing import Set
20
+ from typing import Tuple
19
21
20
22
import atomicwrites
21
23
37
39
AST_NONE = ast .NameConstant (None )
38
40
39
41
40
- class AssertionRewritingHook :
42
+ class AssertionRewritingHook ( importlib . abc . MetaPathFinder ) :
41
43
"""PEP302/PEP451 import hook which rewrites asserts."""
42
44
43
45
def __init__ (self , config ):
@@ -47,13 +49,13 @@ def __init__(self, config):
47
49
except ValueError :
48
50
self .fnpats = ["test_*.py" , "*_test.py" ]
49
51
self .session = None
50
- self ._rewritten_names = set ()
51
- self ._must_rewrite = set ()
52
+ self ._rewritten_names = set () # type: Set[str]
53
+ self ._must_rewrite = set () # type: Set[str]
52
54
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
53
55
# which might result in infinite recursion (#3506)
54
56
self ._writing_pyc = False
55
57
self ._basenames_to_check_rewrite = {"conftest" }
56
- self ._marked_for_rewrite_cache = {}
58
+ self ._marked_for_rewrite_cache = {} # type: Dict[str, bool]
57
59
self ._session_paths_checked = False
58
60
59
61
def set_session (self , session ):
@@ -202,7 +204,7 @@ def _should_rewrite(self, name, fn, state):
202
204
203
205
return self ._is_marked_for_rewrite (name , state )
204
206
205
- def _is_marked_for_rewrite (self , name , state ):
207
+ def _is_marked_for_rewrite (self , name : str , state ):
206
208
try :
207
209
return self ._marked_for_rewrite_cache [name ]
208
210
except KeyError :
@@ -217,7 +219,7 @@ def _is_marked_for_rewrite(self, name, state):
217
219
self ._marked_for_rewrite_cache [name ] = False
218
220
return False
219
221
220
- def mark_rewrite (self , * names ) :
222
+ def mark_rewrite (self , * names : str ) -> None :
221
223
"""Mark import names as needing to be rewritten.
222
224
223
225
The named module or package as well as any nested modules will
@@ -384,6 +386,7 @@ def _format_boolop(explanations, is_or):
384
386
385
387
386
388
def _call_reprcompare (ops , results , expls , each_obj ):
389
+ # type: (Tuple[str, ...], Tuple[bool, ...], Tuple[str, ...], Tuple[object, ...]) -> str
387
390
for i , res , expl in zip (range (len (ops )), results , expls ):
388
391
try :
389
392
done = not res
@@ -399,11 +402,13 @@ def _call_reprcompare(ops, results, expls, each_obj):
399
402
400
403
401
404
def _call_assertion_pass (lineno , orig , expl ):
405
+ # type: (int, str, str) -> None
402
406
if util ._assertion_pass is not None :
403
- util ._assertion_pass (lineno = lineno , orig = orig , expl = expl )
407
+ util ._assertion_pass (lineno , orig , expl )
404
408
405
409
406
410
def _check_if_assertion_pass_impl ():
411
+ # type: () -> bool
407
412
"""Checks if any plugins implement the pytest_assertion_pass hook
408
413
in order not to generate explanation unecessarily (might be expensive)"""
409
414
return True if util ._assertion_pass else False
@@ -577,7 +582,7 @@ def __init__(self, module_path, config, source):
577
582
def _assert_expr_to_lineno (self ):
578
583
return _get_assertion_exprs (self .source )
579
584
580
- def run (self , mod ) :
585
+ def run (self , mod : ast . Module ) -> None :
581
586
"""Find all assert statements in *mod* and rewrite them."""
582
587
if not mod .body :
583
588
# Nothing to do.
@@ -619,12 +624,12 @@ def run(self, mod):
619
624
]
620
625
mod .body [pos :pos ] = imports
621
626
# Collect asserts.
622
- nodes = [mod ]
627
+ nodes = [mod ] # type: List[ast.AST]
623
628
while nodes :
624
629
node = nodes .pop ()
625
630
for name , field in ast .iter_fields (node ):
626
631
if isinstance (field , list ):
627
- new = []
632
+ new = [] # type: List
628
633
for i , child in enumerate (field ):
629
634
if isinstance (child , ast .Assert ):
630
635
# Transform assert.
@@ -698,7 +703,7 @@ def push_format_context(self):
698
703
.explanation_param().
699
704
700
705
"""
701
- self .explanation_specifiers = {}
706
+ self .explanation_specifiers = {} # type: Dict[str, ast.expr]
702
707
self .stack .append (self .explanation_specifiers )
703
708
704
709
def pop_format_context (self , expl_expr ):
@@ -741,7 +746,8 @@ def visit_Assert(self, assert_):
741
746
from _pytest .warning_types import PytestAssertRewriteWarning
742
747
import warnings
743
748
744
- warnings .warn_explicit (
749
+ # Ignore type: typeshed bug https://github.com/python/typeshed/pull/3121
750
+ warnings .warn_explicit ( # type: ignore
745
751
PytestAssertRewriteWarning (
746
752
"assertion is always true, perhaps remove parentheses?"
747
753
),
@@ -750,15 +756,15 @@ def visit_Assert(self, assert_):
750
756
lineno = assert_ .lineno ,
751
757
)
752
758
753
- self .statements = []
754
- self .variables = []
759
+ self .statements = [] # type: List[ast.stmt]
760
+ self .variables = [] # type: List[str]
755
761
self .variable_counter = itertools .count ()
756
762
757
763
if self .enable_assertion_pass_hook :
758
- self .format_variables = []
764
+ self .format_variables = [] # type: List[str]
759
765
760
- self .stack = []
761
- self .expl_stmts = []
766
+ self .stack = [] # type: List[Dict[str, ast.expr]]
767
+ self .expl_stmts = [] # type: List[ast.stmt]
762
768
self .push_format_context ()
763
769
# Rewrite assert into a bunch of statements.
764
770
top_condition , explanation = self .visit (assert_ .test )
@@ -896,7 +902,7 @@ def visit_BoolOp(self, boolop):
896
902
# Process each operand, short-circuiting if needed.
897
903
for i , v in enumerate (boolop .values ):
898
904
if i :
899
- fail_inner = []
905
+ fail_inner = [] # type: List[ast.stmt]
900
906
# cond is set in a prior loop iteration below
901
907
self .expl_stmts .append (ast .If (cond , fail_inner , [])) # noqa
902
908
self .expl_stmts = fail_inner
@@ -907,10 +913,10 @@ def visit_BoolOp(self, boolop):
907
913
call = ast .Call (app , [expl_format ], [])
908
914
self .expl_stmts .append (ast .Expr (call ))
909
915
if i < levels :
910
- cond = res
916
+ cond = res # type: ast.expr
911
917
if is_or :
912
918
cond = ast .UnaryOp (ast .Not (), cond )
913
- inner = []
919
+ inner = [] # type: List[ast.stmt]
914
920
self .statements .append (ast .If (cond , inner , []))
915
921
self .statements = body = inner
916
922
self .statements = save
@@ -976,7 +982,7 @@ def visit_Attribute(self, attr):
976
982
expl = pat % (res_expl , res_expl , value_expl , attr .attr )
977
983
return res , expl
978
984
979
- def visit_Compare (self , comp ):
985
+ def visit_Compare (self , comp : ast . Compare ):
980
986
self .push_format_context ()
981
987
left_res , left_expl = self .visit (comp .left )
982
988
if isinstance (comp .left , (ast .Compare , ast .BoolOp )):
@@ -1009,7 +1015,7 @@ def visit_Compare(self, comp):
1009
1015
ast .Tuple (results , ast .Load ()),
1010
1016
)
1011
1017
if len (comp .ops ) > 1 :
1012
- res = ast .BoolOp (ast .And (), load_names )
1018
+ res = ast .BoolOp (ast .And (), load_names ) # type: ast.expr
1013
1019
else :
1014
1020
res = load_names [0 ]
1015
1021
return res , self .explanation_param (self .pop_format_context (expl_call ))
0 commit comments