Skip to content

Commit 3bbd41a

Browse files
committed
track if typing.TYPE_CHECKING to warn about non runtime bindings
When importing or defining values in ``if typing.TYPE_CHECKING`` blocks the bound names will not be available at runtime and may cause errors when used in the following way:: import typing if typing.TYPE_CHECKING: from module import Type # some slow import or circular reference def method(value) -> Type: # the import is needed by the type checker assert isinstance(value, Type) # this is a runtime error This change allows pyflakes to track what names are bound for runtime use, and allows it to warn when a non runtime name is used in a runtime context.
1 parent 78d8859 commit 3bbd41a

File tree

2 files changed

+155
-29
lines changed

2 files changed

+155
-29
lines changed

pyflakes/checker.py

Lines changed: 87 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -322,10 +322,11 @@ class Binding(object):
322322
the node that this binding was last used.
323323
"""
324324

325-
def __init__(self, name, source):
325+
def __init__(self, name, source, runtime=True):
326326
self.name = name
327327
self.source = source
328328
self.used = False
329+
self.runtime = runtime
329330

330331
def __str__(self):
331332
return self.name
@@ -392,10 +393,10 @@ class Importation(Definition):
392393
@type fullName: C{str}
393394
"""
394395

395-
def __init__(self, name, source, full_name=None):
396+
def __init__(self, name, source, full_name=None, runtime=True):
396397
self.fullName = full_name or name
397398
self.redefined = []
398-
super(Importation, self).__init__(name, source)
399+
super(Importation, self).__init__(name, source, runtime=runtime)
399400

400401
def redefines(self, other):
401402
if isinstance(other, SubmoduleImportation):
@@ -440,11 +441,12 @@ class SubmoduleImportation(Importation):
440441
name is also the same, to avoid false positives.
441442
"""
442443

443-
def __init__(self, name, source):
444+
def __init__(self, name, source, runtime=True):
444445
# A dot should only appear in the name when it is a submodule import
445446
assert '.' in name and (not source or isinstance(source, ast.Import))
446447
package_name = name.split('.')[0]
447-
super(SubmoduleImportation, self).__init__(package_name, source)
448+
super(SubmoduleImportation, self).__init__(
449+
package_name, source, runtime=runtime)
448450
self.fullName = name
449451

450452
def redefines(self, other):
@@ -462,7 +464,8 @@ def source_statement(self):
462464

463465
class ImportationFrom(Importation):
464466

465-
def __init__(self, name, source, module, real_name=None):
467+
def __init__(
468+
self, name, source, module, real_name=None, runtime=True):
466469
self.module = module
467470
self.real_name = real_name or name
468471

@@ -471,7 +474,8 @@ def __init__(self, name, source, module, real_name=None):
471474
else:
472475
full_name = module + '.' + self.real_name
473476

474-
super(ImportationFrom, self).__init__(name, source, full_name)
477+
super(ImportationFrom, self).__init__(
478+
name, source, full_name, runtime=runtime)
475479

476480
def __str__(self):
477481
"""Return import full name with alias."""
@@ -493,8 +497,8 @@ def source_statement(self):
493497
class StarImportation(Importation):
494498
"""A binding created by a 'from x import *' statement."""
495499

496-
def __init__(self, name, source):
497-
super(StarImportation, self).__init__('*', source)
500+
def __init__(self, name, source, runtime=True):
501+
super(StarImportation, self).__init__('*', source, runtime=runtime)
498502
# Each star importation needs a unique name, and
499503
# may not be the module name otherwise it will be deemed imported
500504
self.name = name + '.*'
@@ -577,7 +581,7 @@ class ExportBinding(Binding):
577581
C{__all__} will not have an unused import warning reported for them.
578582
"""
579583

580-
def __init__(self, name, source, scope):
584+
def __init__(self, name, source, scope, runtime=True):
581585
if '__all__' in scope and isinstance(source, ast.AugAssign):
582586
self.names = list(scope['__all__'].names)
583587
else:
@@ -608,7 +612,7 @@ def _add_to_names(container):
608612
# If not list concatenation
609613
else:
610614
break
611-
super(ExportBinding, self).__init__(name, source)
615+
super(ExportBinding, self).__init__(name, source, runtime=runtime)
612616

613617

614618
class Scope(dict):
@@ -883,6 +887,7 @@ class Checker(object):
883887
offset = None
884888
_in_annotation = AnnotationState.NONE
885889
_in_deferred = False
890+
_in_type_check_guard = False
886891

887892
builtIns = set(builtin_vars).union(_MAGIC_GLOBALS)
888893
_customBuiltIns = os.environ.get('PYFLAKES_BUILTINS')
@@ -1156,9 +1161,11 @@ def addBinding(self, node, value):
11561161
# then assume the rebound name is used as a global or within a loop
11571162
value.used = self.scope[value.name].used
11581163

1159-
# don't treat annotations as assignments if there is an existing value
1160-
# in scope
1161-
if value.name not in self.scope or not isinstance(value, Annotation):
1164+
# always allow the first assignment or if not already a runtime value,
1165+
# but do not shadow an existing assignment with an annotation or non
1166+
# runtime value.
1167+
if (not existing or not existing.runtime or (
1168+
not isinstance(value, Annotation) and value.runtime)):
11621169
self.scope[value.name] = value
11631170

11641171
def _unknown_handler(self, node):
@@ -1217,12 +1224,18 @@ def handleNodeLoad(self, node):
12171224
self.report(messages.InvalidPrintSyntax, node)
12181225

12191226
try:
1220-
scope[name].used = (self.scope, node)
1227+
n = scope[name]
1228+
if (not n.runtime and not (
1229+
self._in_type_check_guard
1230+
or self._in_annotation)):
1231+
self.report(messages.UndefinedName, node, name)
1232+
return
1233+
1234+
n.used = (self.scope, node)
12211235

12221236
# if the name of SubImportation is same as
12231237
# alias of other Importation and the alias
12241238
# is used, SubImportation also should be marked as used.
1225-
n = scope[name]
12261239
if isinstance(n, Importation) and n._has_alias():
12271240
try:
12281241
scope[n.fullName].used = (self.scope, node)
@@ -1285,18 +1298,20 @@ def handleNodeStore(self, node):
12851298
break
12861299

12871300
parent_stmt = self.getParent(node)
1301+
runtime = not self._in_type_check_guard
12881302
if isinstance(parent_stmt, ANNASSIGN_TYPES) and parent_stmt.value is None:
1289-
binding = Annotation(name, node)
1303+
binding = Annotation(name, node, runtime=runtime)
12901304
elif isinstance(parent_stmt, (FOR_TYPES, ast.comprehension)) or (
12911305
parent_stmt != node._pyflakes_parent and
12921306
not self.isLiteralTupleUnpacking(parent_stmt)):
1293-
binding = Binding(name, node)
1307+
binding = Binding(name, node, runtime=runtime)
12941308
elif name == '__all__' and isinstance(self.scope, ModuleScope):
1295-
binding = ExportBinding(name, node._pyflakes_parent, self.scope)
1309+
binding = ExportBinding(
1310+
name, node._pyflakes_parent, self.scope, runtime=runtime)
12961311
elif PY2 and isinstance(getattr(node, 'ctx', None), ast.Param):
1297-
binding = Argument(name, self.getScopeNode(node))
1312+
binding = Argument(name, self.getScopeNode(node), runtime=runtime)
12981313
else:
1299-
binding = Assignment(name, node)
1314+
binding = Assignment(name, node, runtime=runtime)
13001315
self.addBinding(node, binding)
13011316

13021317
def handleNodeDelete(self, node):
@@ -1981,7 +1996,40 @@ def DICT(self, node):
19811996
def IF(self, node):
19821997
if isinstance(node.test, ast.Tuple) and node.test.elts != []:
19831998
self.report(messages.IfTuple, node)
1984-
self.handleChildren(node)
1999+
2000+
self._handle_type_comments(node)
2001+
self.handleNode(node.test, node)
2002+
2003+
# check if the body/orelse should be handled specially because it is
2004+
# a if TYPE_CHECKING guard.
2005+
test = node.test
2006+
reverse = False
2007+
if isinstance(test, ast.UnaryOp) and isinstance(test.op, ast.Not):
2008+
test = test.operand
2009+
reverse = True
2010+
2011+
type_checking = _is_typing(test, 'TYPE_CHECKING', self.scopeStack)
2012+
orig = self._in_type_check_guard
2013+
2014+
# normalize body and orelse to a list
2015+
body, orelse = (
2016+
i if isinstance(i, list) else [i]
2017+
for i in (node.body, node.orelse))
2018+
2019+
# set the guard and handle the body
2020+
if type_checking and not reverse:
2021+
self._in_type_check_guard = True
2022+
2023+
for n in body:
2024+
self.handleNode(n, node)
2025+
2026+
# set the guard and handle the orelse
2027+
if type_checking:
2028+
self._in_type_check_guard = True if reverse else orig
2029+
2030+
for n in orelse:
2031+
self.handleNode(n, node)
2032+
self._in_type_check_guard = orig
19852033

19862034
IFEXP = IF
19872035

@@ -2104,7 +2152,10 @@ def FUNCTIONDEF(self, node):
21042152
for deco in node.decorator_list:
21052153
self.handleNode(deco, node)
21062154
self.LAMBDA(node)
2107-
self.addBinding(node, FunctionDefinition(node.name, node))
2155+
self.addBinding(
2156+
node,
2157+
FunctionDefinition(
2158+
node.name, node, runtime=not self._in_type_check_guard))
21082159
# doctest does not process doctest within a doctest,
21092160
# or in nested functions.
21102161
if (self.withDoctest and
@@ -2229,7 +2280,10 @@ def CLASSDEF(self, node):
22292280
for stmt in node.body:
22302281
self.handleNode(stmt, node)
22312282
self.popScope()
2232-
self.addBinding(node, ClassDefinition(node.name, node))
2283+
self.addBinding(
2284+
node,
2285+
ClassDefinition(
2286+
node.name, node, runtime=not self._in_type_check_guard))
22332287

22342288
def AUGASSIGN(self, node):
22352289
self.handleNodeLoad(node.target)
@@ -2262,12 +2316,15 @@ def TUPLE(self, node):
22622316
LIST = TUPLE
22632317

22642318
def IMPORT(self, node):
2319+
runtime = not self._in_type_check_guard
22652320
for alias in node.names:
22662321
if '.' in alias.name and not alias.asname:
2267-
importation = SubmoduleImportation(alias.name, node)
2322+
importation = SubmoduleImportation(
2323+
alias.name, node, runtime=runtime)
22682324
else:
22692325
name = alias.asname or alias.name
2270-
importation = Importation(name, node, alias.name)
2326+
importation = Importation(
2327+
name, node, alias.name, runtime=runtime)
22712328
self.addBinding(node, importation)
22722329

22732330
def IMPORTFROM(self, node):
@@ -2280,6 +2337,7 @@ def IMPORTFROM(self, node):
22802337

22812338
module = ('.' * node.level) + (node.module or '')
22822339

2340+
runtime = not self._in_type_check_guard
22832341
for alias in node.names:
22842342
name = alias.asname or alias.name
22852343
if node.module == '__future__':
@@ -2298,10 +2356,10 @@ def IMPORTFROM(self, node):
22982356

22992357
self.scope.importStarred = True
23002358
self.report(messages.ImportStarUsed, node, module)
2301-
importation = StarImportation(module, node)
2359+
importation = StarImportation(module, node, runtime=runtime)
23022360
else:
2303-
importation = ImportationFrom(name, node,
2304-
module, alias.name)
2361+
importation = ImportationFrom(
2362+
name, node, module, alias.name, runtime=runtime)
23052363
self.addBinding(node, importation)
23062364

23072365
def TRY(self, node):

pyflakes/test/test_type_annotations.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,57 @@ def f(): # type: () -> T
707707
pass
708708
""")
709709

710+
@skipIf(version_info < (3,), 'new in Python 3')
711+
def test_typing_guard_import(self):
712+
# T is imported for runtime use
713+
self.flakes("""
714+
from typing import TYPE_CHECKING
715+
716+
if TYPE_CHECKING:
717+
from t import T
718+
719+
def f(x) -> T:
720+
from t import T
721+
722+
assert isinstance(x, T)
723+
return x
724+
""")
725+
# T is defined at runtime in one side of the if/else block
726+
self.flakes("""
727+
from typing import TYPE_CHECKING, Union
728+
729+
if TYPE_CHECKING:
730+
from t import T
731+
else:
732+
T = object
733+
734+
if not TYPE_CHECKING:
735+
U = object
736+
else:
737+
from t import U
738+
739+
def f(x) -> Union[T, U]:
740+
assert isinstance(x, (T, U))
741+
return x
742+
""")
743+
744+
@skipIf(version_info < (3,), 'new in Python 3')
745+
def test_typing_guard_import_runtime_error(self):
746+
# T and U are not bound for runtime use
747+
self.flakes("""
748+
from typing import TYPE_CHECKING, Union
749+
750+
if TYPE_CHECKING:
751+
from t import T
752+
753+
class U:
754+
pass
755+
756+
def f(x) -> Union[T, U]:
757+
assert isinstance(x, (T, U))
758+
return x
759+
""", m.UndefinedName, m.UndefinedName)
760+
710761
def test_typing_guard_for_protocol(self):
711762
self.flakes("""
712763
from typing import TYPE_CHECKING
@@ -721,6 +772,23 @@ def f(): # type: () -> int
721772
pass
722773
""")
723774

775+
def test_typing_guard_with_elif_branch(self):
776+
# This test will not raise an error even though Protocol is not
777+
# defined outside TYPE_CHECKING because Pyflakes does not do case
778+
# analysis.
779+
self.flakes("""
780+
from typing import TYPE_CHECKING
781+
if TYPE_CHECKING:
782+
from typing import Protocol
783+
elif False:
784+
Protocol = object
785+
else:
786+
pass
787+
class C(Protocol):
788+
def f(): # type: () -> int
789+
pass
790+
""")
791+
724792
def test_typednames_correct_forward_ref(self):
725793
self.flakes("""
726794
from typing import TypedDict, List, NamedTuple

0 commit comments

Comments
 (0)