Skip to content

Commit 58506b3

Browse files
committed
track if typing.TYPE_CHECKING to warn about non runtime bindings
When importing or defining values in ``if typing.TYPE_CHECKING`` blocks the bound names will not be available at runtime and may cause errors when used in the following way:: import typing if typing.TYPE_CHECKING: from module import Type # some slow import or circular reference def method(value) -> Type: # the import is needed by the type checker assert isinstance(value, Type) # this is a runtime error This change allows pyflakes to track what names are bound for runtime use, and allows it to warn when a non runtime name is used in a runtime context.
1 parent e19886e commit 58506b3

File tree

2 files changed

+153
-30
lines changed

2 files changed

+153
-30
lines changed

pyflakes/checker.py

Lines changed: 85 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -239,10 +239,11 @@ class Binding:
239239
the node that this binding was last used.
240240
"""
241241

242-
def __init__(self, name, source):
242+
def __init__(self, name, source, runtime=True):
243243
self.name = name
244244
self.source = source
245245
self.used = False
246+
self.runtime = runtime
246247

247248
def __str__(self):
248249
return self.name
@@ -273,8 +274,8 @@ def redefines(self, other):
273274
class Builtin(Definition):
274275
"""A definition created for all Python builtins."""
275276

276-
def __init__(self, name):
277-
super().__init__(name, None)
277+
def __init__(self, name, runtime=True):
278+
super().__init__(name, None, runtime=runtime)
278279

279280
def __repr__(self):
280281
return '<{} object {!r} at 0x{:x}>'.format(
@@ -318,10 +319,10 @@ class Importation(Definition):
318319
@type fullName: C{str}
319320
"""
320321

321-
def __init__(self, name, source, full_name=None):
322+
def __init__(self, name, source, full_name=None, runtime=True):
322323
self.fullName = full_name or name
323324
self.redefined = []
324-
super().__init__(name, source)
325+
super().__init__(name, source, runtime=runtime)
325326

326327
def redefines(self, other):
327328
if isinstance(other, SubmoduleImportation):
@@ -366,11 +367,11 @@ class SubmoduleImportation(Importation):
366367
name is also the same, to avoid false positives.
367368
"""
368369

369-
def __init__(self, name, source):
370+
def __init__(self, name, source, runtime=True):
370371
# A dot should only appear in the name when it is a submodule import
371372
assert '.' in name and (not source or isinstance(source, ast.Import))
372373
package_name = name.split('.')[0]
373-
super().__init__(package_name, source)
374+
super().__init__(package_name, source, runtime=runtime)
374375
self.fullName = name
375376

376377
def redefines(self, other):
@@ -388,7 +389,8 @@ def source_statement(self):
388389

389390
class ImportationFrom(Importation):
390391

391-
def __init__(self, name, source, module, real_name=None):
392+
def __init__(
393+
self, name, source, module, real_name=None, runtime=True):
392394
self.module = module
393395
self.real_name = real_name or name
394396

@@ -397,7 +399,7 @@ def __init__(self, name, source, module, real_name=None):
397399
else:
398400
full_name = module + '.' + self.real_name
399401

400-
super().__init__(name, source, full_name)
402+
super().__init__(name, source, full_name, runtime=runtime)
401403

402404
def __str__(self):
403405
"""Return import full name with alias."""
@@ -417,8 +419,8 @@ def source_statement(self):
417419
class StarImportation(Importation):
418420
"""A binding created by a 'from x import *' statement."""
419421

420-
def __init__(self, name, source):
421-
super().__init__('*', source)
422+
def __init__(self, name, source, runtime=True):
423+
super().__init__('*', source, runtime=runtime)
422424
# Each star importation needs a unique name, and
423425
# may not be the module name otherwise it will be deemed imported
424426
self.name = name + '.*'
@@ -507,7 +509,7 @@ class ExportBinding(Binding):
507509
C{__all__} will not have an unused import warning reported for them.
508510
"""
509511

510-
def __init__(self, name, source, scope):
512+
def __init__(self, name, source, scope, runtime=True):
511513
if '__all__' in scope and isinstance(source, ast.AugAssign):
512514
self.names = list(scope['__all__'].names)
513515
else:
@@ -538,7 +540,7 @@ def _add_to_names(container):
538540
# If not list concatenation
539541
else:
540542
break
541-
super().__init__(name, source)
543+
super().__init__(name, source, runtime=runtime)
542544

543545

544546
class Scope(dict):
@@ -741,6 +743,7 @@ class Checker:
741743
nodeDepth = 0
742744
offset = None
743745
_in_annotation = AnnotationState.NONE
746+
_in_type_check_guard = False
744747

745748
builtIns = set(builtin_vars).union(_MAGIC_GLOBALS)
746749
_customBuiltIns = os.environ.get('PYFLAKES_BUILTINS')
@@ -1009,9 +1012,11 @@ def addBinding(self, node, value):
10091012
# then assume the rebound name is used as a global or within a loop
10101013
value.used = self.scope[value.name].used
10111014

1012-
# don't treat annotations as assignments if there is an existing value
1013-
# in scope
1014-
if value.name not in self.scope or not isinstance(value, Annotation):
1015+
# always allow the first assignment or if not already a runtime value,
1016+
# but do not shadow an existing assignment with an annotation or non
1017+
# runtime value.
1018+
if (not existing or not existing.runtime or (
1019+
not isinstance(value, Annotation) and value.runtime)):
10151020
cur_scope_pos = -1
10161021
# As per PEP 572, use scope in which outermost generator is defined
10171022
while (
@@ -1077,12 +1082,18 @@ def handleNodeLoad(self, node, parent):
10771082
self.report(messages.InvalidPrintSyntax, node)
10781083

10791084
try:
1080-
scope[name].used = (self.scope, node)
1085+
n = scope[name]
1086+
if (not n.runtime and not (
1087+
self._in_type_check_guard
1088+
or self._in_annotation)):
1089+
self.report(messages.UndefinedName, node, name)
1090+
return
1091+
1092+
n.used = (self.scope, node)
10811093

10821094
# if the name of SubImportation is same as
10831095
# alias of other Importation and the alias
10841096
# is used, SubImportation also should be marked as used.
1085-
n = scope[name]
10861097
if isinstance(n, Importation) and n._has_alias():
10871098
try:
10881099
scope[n.fullName].used = (self.scope, node)
@@ -1145,12 +1156,13 @@ def handleNodeStore(self, node):
11451156
break
11461157

11471158
parent_stmt = self.getParent(node)
1159+
runtime = not self._in_type_check_guard
11481160
if isinstance(parent_stmt, ast.AnnAssign) and parent_stmt.value is None:
11491161
binding = Annotation(name, node)
11501162
elif isinstance(parent_stmt, (FOR_TYPES, ast.comprehension)) or (
11511163
parent_stmt != node._pyflakes_parent and
11521164
not self.isLiteralTupleUnpacking(parent_stmt)):
1153-
binding = Binding(name, node)
1165+
binding = Binding(name, node, runtime=runtime)
11541166
elif (
11551167
name == '__all__' and
11561168
isinstance(self.scope, ModuleScope) and
@@ -1159,11 +1171,12 @@ def handleNodeStore(self, node):
11591171
(ast.Assign, ast.AugAssign, ast.AnnAssign)
11601172
)
11611173
):
1162-
binding = ExportBinding(name, node._pyflakes_parent, self.scope)
1174+
binding = ExportBinding(
1175+
name, node._pyflakes_parent, self.scope, runtime=runtime)
11631176
elif isinstance(parent_stmt, ast.NamedExpr):
1164-
binding = NamedExprAssignment(name, node)
1177+
binding = NamedExprAssignment(name, node, runtime=runtime)
11651178
else:
1166-
binding = Assignment(name, node)
1179+
binding = Assignment(name, node, runtime=runtime)
11671180
self.addBinding(node, binding)
11681181

11691182
def handleNodeDelete(self, node):
@@ -1791,7 +1804,39 @@ def DICT(self, node):
17911804
def IF(self, node):
17921805
if isinstance(node.test, ast.Tuple) and node.test.elts != []:
17931806
self.report(messages.IfTuple, node)
1794-
self.handleChildren(node)
1807+
1808+
self.handleNode(node.test, node)
1809+
1810+
# check if the body/orelse should be handled specially because it is
1811+
# a if TYPE_CHECKING guard.
1812+
test = node.test
1813+
reverse = False
1814+
if isinstance(test, ast.UnaryOp) and isinstance(test.op, ast.Not):
1815+
test = test.operand
1816+
reverse = True
1817+
1818+
type_checking = _is_typing(test, 'TYPE_CHECKING', self.scopeStack)
1819+
orig = self._in_type_check_guard
1820+
1821+
# normalize body and orelse to a list
1822+
body, orelse = (
1823+
i if isinstance(i, list) else [i]
1824+
for i in (node.body, node.orelse))
1825+
1826+
# set the guard and handle the body
1827+
if type_checking and not reverse:
1828+
self._in_type_check_guard = True
1829+
1830+
for n in body:
1831+
self.handleNode(n, node)
1832+
1833+
# set the guard and handle the orelse
1834+
if type_checking:
1835+
self._in_type_check_guard = True if reverse else orig
1836+
1837+
for n in orelse:
1838+
self.handleNode(n, node)
1839+
self._in_type_check_guard = orig
17951840

17961841
IFEXP = IF
17971842

@@ -1903,7 +1948,10 @@ def FUNCTIONDEF(self, node):
19031948
for deco in node.decorator_list:
19041949
self.handleNode(deco, node)
19051950
self.LAMBDA(node)
1906-
self.addBinding(node, FunctionDefinition(node.name, node))
1951+
self.addBinding(
1952+
node,
1953+
FunctionDefinition(
1954+
node.name, node, runtime=not self._in_type_check_guard))
19071955
# doctest does not process doctest within a doctest,
19081956
# or in nested functions.
19091957
if (self.withDoctest and
@@ -1982,7 +2030,10 @@ def CLASSDEF(self, node):
19822030
self.deferFunction(lambda: self.handleDoctests(node))
19832031
for stmt in node.body:
19842032
self.handleNode(stmt, node)
1985-
self.addBinding(node, ClassDefinition(node.name, node))
2033+
self.addBinding(
2034+
node,
2035+
ClassDefinition(
2036+
node.name, node, runtime=not self._in_type_check_guard))
19862037

19872038
def AUGASSIGN(self, node):
19882039
self.handleNodeLoad(node.target, node)
@@ -2015,12 +2066,15 @@ def TUPLE(self, node):
20152066
LIST = TUPLE
20162067

20172068
def IMPORT(self, node):
2069+
runtime = not self._in_type_check_guard
20182070
for alias in node.names:
20192071
if '.' in alias.name and not alias.asname:
2020-
importation = SubmoduleImportation(alias.name, node)
2072+
importation = SubmoduleImportation(
2073+
alias.name, node, runtime=runtime)
20212074
else:
20222075
name = alias.asname or alias.name
2023-
importation = Importation(name, node, alias.name)
2076+
importation = Importation(
2077+
name, node, alias.name, runtime=runtime)
20242078
self.addBinding(node, importation)
20252079

20262080
def IMPORTFROM(self, node):
@@ -2032,6 +2086,7 @@ def IMPORTFROM(self, node):
20322086

20332087
module = ('.' * node.level) + (node.module or '')
20342088

2089+
runtime = not self._in_type_check_guard
20352090
for alias in node.names:
20362091
name = alias.asname or alias.name
20372092
if node.module == '__future__':
@@ -2049,10 +2104,10 @@ def IMPORTFROM(self, node):
20492104

20502105
self.scope.importStarred = True
20512106
self.report(messages.ImportStarUsed, node, module)
2052-
importation = StarImportation(module, node)
2107+
importation = StarImportation(module, node, runtime=runtime)
20532108
else:
2054-
importation = ImportationFrom(name, node,
2055-
module, alias.name)
2109+
importation = ImportationFrom(
2110+
name, node, module, alias.name, runtime=runtime)
20562111
self.addBinding(node, importation)
20572112

20582113
def TRY(self, node):

pyflakes/test/test_type_annotations.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,57 @@ def f() -> T:
645645
pass
646646
""")
647647

648+
@skipIf(version_info < (3,), 'new in Python 3')
649+
def test_typing_guard_import(self):
650+
# T is imported for runtime use
651+
self.flakes("""
652+
from typing import TYPE_CHECKING
653+
654+
if TYPE_CHECKING:
655+
from t import T
656+
657+
def f(x) -> T:
658+
from t import T
659+
660+
assert isinstance(x, T)
661+
return x
662+
""")
663+
# T is defined at runtime in one side of the if/else block
664+
self.flakes("""
665+
from typing import TYPE_CHECKING, Union
666+
667+
if TYPE_CHECKING:
668+
from t import T
669+
else:
670+
T = object
671+
672+
if not TYPE_CHECKING:
673+
U = object
674+
else:
675+
from t import U
676+
677+
def f(x) -> Union[T, U]:
678+
assert isinstance(x, (T, U))
679+
return x
680+
""")
681+
682+
@skipIf(version_info < (3,), 'new in Python 3')
683+
def test_typing_guard_import_runtime_error(self):
684+
# T and U are not bound for runtime use
685+
self.flakes("""
686+
from typing import TYPE_CHECKING, Union
687+
688+
if TYPE_CHECKING:
689+
from t import T
690+
691+
class U:
692+
pass
693+
694+
def f(x) -> Union[T, U]:
695+
assert isinstance(x, (T, U))
696+
return x
697+
""", m.UndefinedName, m.UndefinedName)
698+
648699
def test_typing_guard_for_protocol(self):
649700
self.flakes("""
650701
from typing import TYPE_CHECKING
@@ -659,6 +710,23 @@ def f() -> int:
659710
pass
660711
""")
661712

713+
def test_typing_guard_with_elif_branch(self):
714+
# This test will not raise an error even though Protocol is not
715+
# defined outside TYPE_CHECKING because Pyflakes does not do case
716+
# analysis.
717+
self.flakes("""
718+
from typing import TYPE_CHECKING
719+
if TYPE_CHECKING:
720+
from typing import Protocol
721+
elif False:
722+
Protocol = object
723+
else:
724+
pass
725+
class C(Protocol):
726+
def f(): # type: () -> int
727+
pass
728+
""")
729+
662730
def test_typednames_correct_forward_ref(self):
663731
self.flakes("""
664732
from typing import TypedDict, List, NamedTuple

0 commit comments

Comments
 (0)