Skip to content

Commit 53c3070

Browse files
committed
track if typing.TYPE_CHECKING to warn about non runtime bindings
When importing or defining values in ``if typing.TYPE_CHECKING`` blocks the bound names will not be available at runtime and may cause errors when used in the following way:: import typing if typing.TYPE_CHECKING: from module import Type # some slow import or circular reference def method(value) -> Type: # the import is needed by the type checker assert isinstance(value, Type) # this is a runtime error This change allows pyflakes to track what names are bound for runtime use, and allows it to warn when a non runtime name is used in a runtime context.
1 parent 2a6e36b commit 53c3070

File tree

2 files changed

+152
-28
lines changed

2 files changed

+152
-28
lines changed

pyflakes/checker.py

Lines changed: 84 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -262,10 +262,11 @@ class Binding:
262262
the node that this binding was last used.
263263
"""
264264

265-
def __init__(self, name, source):
265+
def __init__(self, name, source, runtime=True):
266266
self.name = name
267267
self.source = source
268268
self.used = False
269+
self.runtime = runtime
269270

270271
def __str__(self):
271272
return self.name
@@ -336,10 +337,10 @@ class Importation(Definition):
336337
@type fullName: C{str}
337338
"""
338339

339-
def __init__(self, name, source, full_name=None):
340+
def __init__(self, name, source, full_name=None, runtime=True):
340341
self.fullName = full_name or name
341342
self.redefined = []
342-
super().__init__(name, source)
343+
super().__init__(name, source, runtime=runtime)
343344

344345
def redefines(self, other):
345346
if isinstance(other, SubmoduleImportation):
@@ -384,11 +385,11 @@ class SubmoduleImportation(Importation):
384385
name is also the same, to avoid false positives.
385386
"""
386387

387-
def __init__(self, name, source):
388+
def __init__(self, name, source, runtime=True):
388389
# A dot should only appear in the name when it is a submodule import
389390
assert '.' in name and (not source or isinstance(source, ast.Import))
390391
package_name = name.split('.')[0]
391-
super().__init__(package_name, source)
392+
super().__init__(package_name, source, runtime=runtime)
392393
self.fullName = name
393394

394395
def redefines(self, other):
@@ -406,7 +407,8 @@ def source_statement(self):
406407

407408
class ImportationFrom(Importation):
408409

409-
def __init__(self, name, source, module, real_name=None):
410+
def __init__(
411+
self, name, source, module, real_name=None, runtime=True):
410412
self.module = module
411413
self.real_name = real_name or name
412414

@@ -415,7 +417,7 @@ def __init__(self, name, source, module, real_name=None):
415417
else:
416418
full_name = module + '.' + self.real_name
417419

418-
super().__init__(name, source, full_name)
420+
super().__init__(name, source, full_name, runtime=runtime)
419421

420422
def __str__(self):
421423
"""Return import full name with alias."""
@@ -435,8 +437,8 @@ def source_statement(self):
435437
class StarImportation(Importation):
436438
"""A binding created by a 'from x import *' statement."""
437439

438-
def __init__(self, name, source):
439-
super().__init__('*', source)
440+
def __init__(self, name, source, runtime=True):
441+
super().__init__('*', source, runtime=runtime)
440442
# Each star importation needs a unique name, and
441443
# may not be the module name otherwise it will be deemed imported
442444
self.name = name + '.*'
@@ -525,7 +527,7 @@ class ExportBinding(Binding):
525527
C{__all__} will not have an unused import warning reported for them.
526528
"""
527529

528-
def __init__(self, name, source, scope):
530+
def __init__(self, name, source, scope, runtime=True):
529531
if '__all__' in scope and isinstance(source, ast.AugAssign):
530532
self.names = list(scope['__all__'].names)
531533
else:
@@ -556,7 +558,7 @@ def _add_to_names(container):
556558
# If not list concatenation
557559
else:
558560
break
559-
super().__init__(name, source)
561+
super().__init__(name, source, runtime=runtime)
560562

561563

562564
class Scope(dict):
@@ -827,6 +829,7 @@ class Checker:
827829
offset = None
828830
_in_annotation = AnnotationState.NONE
829831
_in_deferred = False
832+
_in_type_check_guard = False
830833

831834
builtIns = set(builtin_vars).union(_MAGIC_GLOBALS)
832835
_customBuiltIns = os.environ.get('PYFLAKES_BUILTINS')
@@ -1097,9 +1100,11 @@ def addBinding(self, node, value):
10971100
# then assume the rebound name is used as a global or within a loop
10981101
value.used = self.scope[value.name].used
10991102

1100-
# don't treat annotations as assignments if there is an existing value
1101-
# in scope
1102-
if value.name not in self.scope or not isinstance(value, Annotation):
1103+
# always allow the first assignment or if not already a runtime value,
1104+
# but do not shadow an existing assignment with an annotation or non
1105+
# runtime value.
1106+
if (not existing or not existing.runtime or (
1107+
not isinstance(value, Annotation) and value.runtime)):
11031108
cur_scope_pos = -1
11041109
# As per PEP 572, use scope in which outermost generator is defined
11051110
while (
@@ -1165,12 +1170,18 @@ def handleNodeLoad(self, node):
11651170
self.report(messages.InvalidPrintSyntax, node)
11661171

11671172
try:
1168-
scope[name].used = (self.scope, node)
1173+
n = scope[name]
1174+
if (not n.runtime and not (
1175+
self._in_type_check_guard
1176+
or self._in_annotation)):
1177+
self.report(messages.UndefinedName, node, name)
1178+
return
1179+
1180+
n.used = (self.scope, node)
11691181

11701182
# if the name of SubImportation is same as
11711183
# alias of other Importation and the alias
11721184
# is used, SubImportation also should be marked as used.
1173-
n = scope[name]
11741185
if isinstance(n, Importation) and n._has_alias():
11751186
try:
11761187
scope[n.fullName].used = (self.scope, node)
@@ -1233,12 +1244,13 @@ def handleNodeStore(self, node):
12331244
break
12341245

12351246
parent_stmt = self.getParent(node)
1247+
runtime = not self._in_type_check_guard
12361248
if isinstance(parent_stmt, ast.AnnAssign) and parent_stmt.value is None:
12371249
binding = Annotation(name, node)
12381250
elif isinstance(parent_stmt, (FOR_TYPES, ast.comprehension)) or (
12391251
parent_stmt != node._pyflakes_parent and
12401252
not self.isLiteralTupleUnpacking(parent_stmt)):
1241-
binding = Binding(name, node)
1253+
binding = Binding(name, node, runtime=runtime)
12421254
elif (
12431255
name == '__all__' and
12441256
isinstance(self.scope, ModuleScope) and
@@ -1247,11 +1259,12 @@ def handleNodeStore(self, node):
12471259
(ast.Assign, ast.AugAssign, ast.AnnAssign)
12481260
)
12491261
):
1250-
binding = ExportBinding(name, node._pyflakes_parent, self.scope)
1262+
binding = ExportBinding(
1263+
name, node._pyflakes_parent, self.scope, runtime=runtime)
12511264
elif PY38_PLUS and isinstance(parent_stmt, ast.NamedExpr):
1252-
binding = NamedExprAssignment(name, node)
1265+
binding = NamedExprAssignment(name, node, runtime=runtime)
12531266
else:
1254-
binding = Assignment(name, node)
1267+
binding = Assignment(name, node, runtime=runtime)
12551268
self.addBinding(node, binding)
12561269

12571270
def handleNodeDelete(self, node):
@@ -1912,7 +1925,40 @@ def DICT(self, node):
19121925
def IF(self, node):
19131926
if isinstance(node.test, ast.Tuple) and node.test.elts != []:
19141927
self.report(messages.IfTuple, node)
1915-
self.handleChildren(node)
1928+
1929+
self._handle_type_comments(node)
1930+
self.handleNode(node.test, node)
1931+
1932+
# check if the body/orelse should be handled specially because it is
1933+
# a if TYPE_CHECKING guard.
1934+
test = node.test
1935+
reverse = False
1936+
if isinstance(test, ast.UnaryOp) and isinstance(test.op, ast.Not):
1937+
test = test.operand
1938+
reverse = True
1939+
1940+
type_checking = _is_typing(test, 'TYPE_CHECKING', self.scopeStack)
1941+
orig = self._in_type_check_guard
1942+
1943+
# normalize body and orelse to a list
1944+
body, orelse = (
1945+
i if isinstance(i, list) else [i]
1946+
for i in (node.body, node.orelse))
1947+
1948+
# set the guard and handle the body
1949+
if type_checking and not reverse:
1950+
self._in_type_check_guard = True
1951+
1952+
for n in body:
1953+
self.handleNode(n, node)
1954+
1955+
# set the guard and handle the orelse
1956+
if type_checking:
1957+
self._in_type_check_guard = True if reverse else orig
1958+
1959+
for n in orelse:
1960+
self.handleNode(n, node)
1961+
self._in_type_check_guard = orig
19161962

19171963
IFEXP = IF
19181964

@@ -2031,7 +2077,10 @@ def FUNCTIONDEF(self, node):
20312077
for deco in node.decorator_list:
20322078
self.handleNode(deco, node)
20332079
self.LAMBDA(node)
2034-
self.addBinding(node, FunctionDefinition(node.name, node))
2080+
self.addBinding(
2081+
node,
2082+
FunctionDefinition(
2083+
node.name, node, runtime=not self._in_type_check_guard))
20352084
# doctest does not process doctest within a doctest,
20362085
# or in nested functions.
20372086
if (self.withDoctest and
@@ -2124,7 +2173,10 @@ def CLASSDEF(self, node):
21242173
for stmt in node.body:
21252174
self.handleNode(stmt, node)
21262175
self.popScope()
2127-
self.addBinding(node, ClassDefinition(node.name, node))
2176+
self.addBinding(
2177+
node,
2178+
ClassDefinition(
2179+
node.name, node, runtime=not self._in_type_check_guard))
21282180

21292181
def AUGASSIGN(self, node):
21302182
self.handleNodeLoad(node.target)
@@ -2157,12 +2209,15 @@ def TUPLE(self, node):
21572209
LIST = TUPLE
21582210

21592211
def IMPORT(self, node):
2212+
runtime = not self._in_type_check_guard
21602213
for alias in node.names:
21612214
if '.' in alias.name and not alias.asname:
2162-
importation = SubmoduleImportation(alias.name, node)
2215+
importation = SubmoduleImportation(
2216+
alias.name, node, runtime=runtime)
21632217
else:
21642218
name = alias.asname or alias.name
2165-
importation = Importation(name, node, alias.name)
2219+
importation = Importation(
2220+
name, node, alias.name, runtime=runtime)
21662221
self.addBinding(node, importation)
21672222

21682223
def IMPORTFROM(self, node):
@@ -2174,6 +2229,7 @@ def IMPORTFROM(self, node):
21742229

21752230
module = ('.' * node.level) + (node.module or '')
21762231

2232+
runtime = not self._in_type_check_guard
21772233
for alias in node.names:
21782234
name = alias.asname or alias.name
21792235
if node.module == '__future__':
@@ -2191,10 +2247,10 @@ def IMPORTFROM(self, node):
21912247

21922248
self.scope.importStarred = True
21932249
self.report(messages.ImportStarUsed, node, module)
2194-
importation = StarImportation(module, node)
2250+
importation = StarImportation(module, node, runtime=runtime)
21952251
else:
2196-
importation = ImportationFrom(name, node,
2197-
module, alias.name)
2252+
importation = ImportationFrom(
2253+
name, node, module, alias.name, runtime=runtime)
21982254
self.addBinding(node, importation)
21992255

22002256
def TRY(self, node):

pyflakes/test/test_type_annotations.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,57 @@ def f(): # type: () -> T
726726
pass
727727
""")
728728

729+
@skipIf(version_info < (3,), 'new in Python 3')
730+
def test_typing_guard_import(self):
731+
# T is imported for runtime use
732+
self.flakes("""
733+
from typing import TYPE_CHECKING
734+
735+
if TYPE_CHECKING:
736+
from t import T
737+
738+
def f(x) -> T:
739+
from t import T
740+
741+
assert isinstance(x, T)
742+
return x
743+
""")
744+
# T is defined at runtime in one side of the if/else block
745+
self.flakes("""
746+
from typing import TYPE_CHECKING, Union
747+
748+
if TYPE_CHECKING:
749+
from t import T
750+
else:
751+
T = object
752+
753+
if not TYPE_CHECKING:
754+
U = object
755+
else:
756+
from t import U
757+
758+
def f(x) -> Union[T, U]:
759+
assert isinstance(x, (T, U))
760+
return x
761+
""")
762+
763+
@skipIf(version_info < (3,), 'new in Python 3')
764+
def test_typing_guard_import_runtime_error(self):
765+
# T and U are not bound for runtime use
766+
self.flakes("""
767+
from typing import TYPE_CHECKING, Union
768+
769+
if TYPE_CHECKING:
770+
from t import T
771+
772+
class U:
773+
pass
774+
775+
def f(x) -> Union[T, U]:
776+
assert isinstance(x, (T, U))
777+
return x
778+
""", m.UndefinedName, m.UndefinedName)
779+
729780
def test_typing_guard_for_protocol(self):
730781
self.flakes("""
731782
from typing import TYPE_CHECKING
@@ -740,6 +791,23 @@ def f(): # type: () -> int
740791
pass
741792
""")
742793

794+
def test_typing_guard_with_elif_branch(self):
795+
# This test will not raise an error even though Protocol is not
796+
# defined outside TYPE_CHECKING because Pyflakes does not do case
797+
# analysis.
798+
self.flakes("""
799+
from typing import TYPE_CHECKING
800+
if TYPE_CHECKING:
801+
from typing import Protocol
802+
elif False:
803+
Protocol = object
804+
else:
805+
pass
806+
class C(Protocol):
807+
def f(): # type: () -> int
808+
pass
809+
""")
810+
743811
def test_typednames_correct_forward_ref(self):
744812
self.flakes("""
745813
from typing import TypedDict, List, NamedTuple

0 commit comments

Comments
 (0)