Skip to content

Commit 117d5de

Browse files
committed
track if typing.TYPE_CHECKING to warn about non runtime bindings
When importing or defining values in ``if typing.TYPE_CHECKING`` blocks the bound names will not be available at runtime and may cause errors when used in the following way:: import typing if typing.TYPE_CHECKING: from module import Type # some slow import or circular reference def method(value) -> Type: # the import is needed by the type checker assert isinstance(value, Type) # this is a runtime error This change allows pyflakes to track what names are bound for runtime use, and allows it to warn when a non runtime name is used in a runtime context.
1 parent 95fe313 commit 117d5de

File tree

2 files changed

+155
-29
lines changed

2 files changed

+155
-29
lines changed

pyflakes/checker.py

Lines changed: 87 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -321,10 +321,11 @@ class Binding(object):
321321
the node that this binding was last used.
322322
"""
323323

324-
def __init__(self, name, source):
324+
def __init__(self, name, source, runtime=True):
325325
self.name = name
326326
self.source = source
327327
self.used = False
328+
self.runtime = runtime
328329

329330
def __str__(self):
330331
return self.name
@@ -391,10 +392,10 @@ class Importation(Definition):
391392
@type fullName: C{str}
392393
"""
393394

394-
def __init__(self, name, source, full_name=None):
395+
def __init__(self, name, source, full_name=None, runtime=True):
395396
self.fullName = full_name or name
396397
self.redefined = []
397-
super(Importation, self).__init__(name, source)
398+
super(Importation, self).__init__(name, source, runtime=runtime)
398399

399400
def redefines(self, other):
400401
if isinstance(other, SubmoduleImportation):
@@ -439,11 +440,12 @@ class SubmoduleImportation(Importation):
439440
name is also the same, to avoid false positives.
440441
"""
441442

442-
def __init__(self, name, source):
443+
def __init__(self, name, source, runtime=True):
443444
# A dot should only appear in the name when it is a submodule import
444445
assert '.' in name and (not source or isinstance(source, ast.Import))
445446
package_name = name.split('.')[0]
446-
super(SubmoduleImportation, self).__init__(package_name, source)
447+
super(SubmoduleImportation, self).__init__(
448+
package_name, source, runtime=runtime)
447449
self.fullName = name
448450

449451
def redefines(self, other):
@@ -461,7 +463,8 @@ def source_statement(self):
461463

462464
class ImportationFrom(Importation):
463465

464-
def __init__(self, name, source, module, real_name=None):
466+
def __init__(
467+
self, name, source, module, real_name=None, runtime=True):
465468
self.module = module
466469
self.real_name = real_name or name
467470

@@ -470,7 +473,8 @@ def __init__(self, name, source, module, real_name=None):
470473
else:
471474
full_name = module + '.' + self.real_name
472475

473-
super(ImportationFrom, self).__init__(name, source, full_name)
476+
super(ImportationFrom, self).__init__(
477+
name, source, full_name, runtime=runtime)
474478

475479
def __str__(self):
476480
"""Return import full name with alias."""
@@ -492,8 +496,8 @@ def source_statement(self):
492496
class StarImportation(Importation):
493497
"""A binding created by a 'from x import *' statement."""
494498

495-
def __init__(self, name, source):
496-
super(StarImportation, self).__init__('*', source)
499+
def __init__(self, name, source, runtime=True):
500+
super(StarImportation, self).__init__('*', source, runtime=runtime)
497501
# Each star importation needs a unique name, and
498502
# may not be the module name otherwise it will be deemed imported
499503
self.name = name + '.*'
@@ -576,7 +580,7 @@ class ExportBinding(Binding):
576580
C{__all__} will not have an unused import warning reported for them.
577581
"""
578582

579-
def __init__(self, name, source, scope):
583+
def __init__(self, name, source, scope, runtime=True):
580584
if '__all__' in scope and isinstance(source, ast.AugAssign):
581585
self.names = list(scope['__all__'].names)
582586
else:
@@ -607,7 +611,7 @@ def _add_to_names(container):
607611
# If not list concatenation
608612
else:
609613
break
610-
super(ExportBinding, self).__init__(name, source)
614+
super(ExportBinding, self).__init__(name, source, runtime=runtime)
611615

612616

613617
class Scope(dict):
@@ -871,6 +875,7 @@ class Checker(object):
871875
traceTree = False
872876
_in_annotation = AnnotationState.NONE
873877
_in_deferred = False
878+
_in_type_check_guard = False
874879

875880
builtIns = set(builtin_vars).union(_MAGIC_GLOBALS)
876881
_customBuiltIns = os.environ.get('PYFLAKES_BUILTINS')
@@ -1144,9 +1149,11 @@ def addBinding(self, node, value):
11441149
# then assume the rebound name is used as a global or within a loop
11451150
value.used = self.scope[value.name].used
11461151

1147-
# don't treat annotations as assignments if there is an existing value
1148-
# in scope
1149-
if value.name not in self.scope or not isinstance(value, Annotation):
1152+
# always allow the first assignment or if not already a runtime value,
1153+
# but do not shadow an existing assignment with an annotation or non
1154+
# runtime value.
1155+
if (not existing or not existing.runtime or (
1156+
not isinstance(value, Annotation) and value.runtime)):
11501157
self.scope[value.name] = value
11511158

11521159
def _unknown_handler(self, node):
@@ -1205,12 +1212,18 @@ def handleNodeLoad(self, node):
12051212
self.report(messages.InvalidPrintSyntax, node)
12061213

12071214
try:
1208-
scope[name].used = (self.scope, node)
1215+
n = scope[name]
1216+
if (not n.runtime and not (
1217+
self._in_type_check_guard
1218+
or self._in_annotation)):
1219+
self.report(messages.UndefinedName, node, name)
1220+
return
1221+
1222+
n.used = (self.scope, node)
12091223

12101224
# if the name of SubImportation is same as
12111225
# alias of other Importation and the alias
12121226
# is used, SubImportation also should be marked as used.
1213-
n = scope[name]
12141227
if isinstance(n, Importation) and n._has_alias():
12151228
try:
12161229
scope[n.fullName].used = (self.scope, node)
@@ -1273,18 +1286,20 @@ def handleNodeStore(self, node):
12731286
break
12741287

12751288
parent_stmt = self.getParent(node)
1289+
runtime = not self._in_type_check_guard
12761290
if isinstance(parent_stmt, ANNASSIGN_TYPES) and parent_stmt.value is None:
1277-
binding = Annotation(name, node)
1291+
binding = Annotation(name, node, runtime=runtime)
12781292
elif isinstance(parent_stmt, (FOR_TYPES, ast.comprehension)) or (
12791293
parent_stmt != node._pyflakes_parent and
12801294
not self.isLiteralTupleUnpacking(parent_stmt)):
1281-
binding = Binding(name, node)
1295+
binding = Binding(name, node, runtime=runtime)
12821296
elif name == '__all__' and isinstance(self.scope, ModuleScope):
1283-
binding = ExportBinding(name, node._pyflakes_parent, self.scope)
1297+
binding = ExportBinding(
1298+
name, node._pyflakes_parent, self.scope, runtime=runtime)
12841299
elif PY2 and isinstance(getattr(node, 'ctx', None), ast.Param):
1285-
binding = Argument(name, self.getScopeNode(node))
1300+
binding = Argument(name, self.getScopeNode(node), runtime=runtime)
12861301
else:
1287-
binding = Assignment(name, node)
1302+
binding = Assignment(name, node, runtime=runtime)
12881303
self.addBinding(node, binding)
12891304

12901305
def handleNodeDelete(self, node):
@@ -1973,7 +1988,40 @@ def DICT(self, node):
19731988
def IF(self, node):
19741989
if isinstance(node.test, ast.Tuple) and node.test.elts != []:
19751990
self.report(messages.IfTuple, node)
1976-
self.handleChildren(node)
1991+
1992+
self._handle_type_comments(node)
1993+
self.handleNode(node.test, node)
1994+
1995+
# check if the body/orelse should be handled specially because it is
1996+
# a if TYPE_CHECKING guard.
1997+
test = node.test
1998+
reverse = False
1999+
if isinstance(test, ast.UnaryOp) and isinstance(test.op, ast.Not):
2000+
test = test.operand
2001+
reverse = True
2002+
2003+
type_checking = _is_typing(test, 'TYPE_CHECKING', self.scopeStack)
2004+
orig = self._in_type_check_guard
2005+
2006+
# normalize body and orelse to a list
2007+
body, orelse = (
2008+
i if isinstance(i, list) else [i]
2009+
for i in (node.body, node.orelse))
2010+
2011+
# set the guard and handle the body
2012+
if type_checking and not reverse:
2013+
self._in_type_check_guard = True
2014+
2015+
for n in body:
2016+
self.handleNode(n, node)
2017+
2018+
# set the guard and handle the orelse
2019+
if type_checking:
2020+
self._in_type_check_guard = True if reverse else orig
2021+
2022+
for n in orelse:
2023+
self.handleNode(n, node)
2024+
self._in_type_check_guard = orig
19772025

19782026
IFEXP = IF
19792027

@@ -2096,7 +2144,10 @@ def FUNCTIONDEF(self, node):
20962144
for deco in node.decorator_list:
20972145
self.handleNode(deco, node)
20982146
self.LAMBDA(node)
2099-
self.addBinding(node, FunctionDefinition(node.name, node))
2147+
self.addBinding(
2148+
node,
2149+
FunctionDefinition(
2150+
node.name, node, runtime=not self._in_type_check_guard))
21002151
# doctest does not process doctest within a doctest,
21012152
# or in nested functions.
21022153
if (self.withDoctest and
@@ -2221,7 +2272,10 @@ def CLASSDEF(self, node):
22212272
for stmt in node.body:
22222273
self.handleNode(stmt, node)
22232274
self.popScope()
2224-
self.addBinding(node, ClassDefinition(node.name, node))
2275+
self.addBinding(
2276+
node,
2277+
ClassDefinition(
2278+
node.name, node, runtime=not self._in_type_check_guard))
22252279

22262280
def AUGASSIGN(self, node):
22272281
self.handleNodeLoad(node.target)
@@ -2254,12 +2308,15 @@ def TUPLE(self, node):
22542308
LIST = TUPLE
22552309

22562310
def IMPORT(self, node):
2311+
runtime = not self._in_type_check_guard
22572312
for alias in node.names:
22582313
if '.' in alias.name and not alias.asname:
2259-
importation = SubmoduleImportation(alias.name, node)
2314+
importation = SubmoduleImportation(
2315+
alias.name, node, runtime=runtime)
22602316
else:
22612317
name = alias.asname or alias.name
2262-
importation = Importation(name, node, alias.name)
2318+
importation = Importation(
2319+
name, node, alias.name, runtime=runtime)
22632320
self.addBinding(node, importation)
22642321

22652322
def IMPORTFROM(self, node):
@@ -2272,6 +2329,7 @@ def IMPORTFROM(self, node):
22722329

22732330
module = ('.' * node.level) + (node.module or '')
22742331

2332+
runtime = not self._in_type_check_guard
22752333
for alias in node.names:
22762334
name = alias.asname or alias.name
22772335
if node.module == '__future__':
@@ -2290,10 +2348,10 @@ def IMPORTFROM(self, node):
22902348

22912349
self.scope.importStarred = True
22922350
self.report(messages.ImportStarUsed, node, module)
2293-
importation = StarImportation(module, node)
2351+
importation = StarImportation(module, node, runtime=runtime)
22942352
else:
2295-
importation = ImportationFrom(name, node,
2296-
module, alias.name)
2353+
importation = ImportationFrom(
2354+
name, node, module, alias.name, runtime=runtime)
22972355
self.addBinding(node, importation)
22982356

22992357
def TRY(self, node):

pyflakes/test/test_type_annotations.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,57 @@ def f(): # type: () -> T
690690
pass
691691
""")
692692

693+
@skipIf(version_info < (3,), 'new in Python 3')
694+
def test_typing_guard_import(self):
695+
# T is imported for runtime use
696+
self.flakes("""
697+
from typing import TYPE_CHECKING
698+
699+
if TYPE_CHECKING:
700+
from t import T
701+
702+
def f(x) -> T:
703+
from t import T
704+
705+
assert isinstance(x, T)
706+
return x
707+
""")
708+
# T is defined at runtime in one side of the if/else block
709+
self.flakes("""
710+
from typing import TYPE_CHECKING, Union
711+
712+
if TYPE_CHECKING:
713+
from t import T
714+
else:
715+
T = object
716+
717+
if not TYPE_CHECKING:
718+
U = object
719+
else:
720+
from t import U
721+
722+
def f(x) -> Union[T, U]:
723+
assert isinstance(x, (T, U))
724+
return x
725+
""")
726+
727+
@skipIf(version_info < (3,), 'new in Python 3')
728+
def test_typing_guard_import_runtime_error(self):
729+
# T and U are not bound for runtime use
730+
self.flakes("""
731+
from typing import TYPE_CHECKING, Union
732+
733+
if TYPE_CHECKING:
734+
from t import T
735+
736+
class U:
737+
pass
738+
739+
def f(x) -> Union[T, U]:
740+
assert isinstance(x, (T, U))
741+
return x
742+
""", m.UndefinedName, m.UndefinedName)
743+
693744
def test_typing_guard_for_protocol(self):
694745
self.flakes("""
695746
from typing import TYPE_CHECKING
@@ -704,6 +755,23 @@ def f(): # type: () -> int
704755
pass
705756
""")
706757

758+
def test_typing_guard_with_elif_branch(self):
759+
# This test will not raise an error even though Protocol is not
760+
# defined outside TYPE_CHECKING because Pyflakes does not do case
761+
# analysis.
762+
self.flakes("""
763+
from typing import TYPE_CHECKING
764+
if TYPE_CHECKING:
765+
from typing import Protocol
766+
elif False:
767+
Protocol = object
768+
else:
769+
pass
770+
class C(Protocol):
771+
def f(): # type: () -> int
772+
pass
773+
""")
774+
707775
def test_typednames_correct_forward_ref(self):
708776
self.flakes("""
709777
from typing import TypedDict, List, NamedTuple

0 commit comments

Comments
 (0)