Skip to content

Commit 9efceb6

Browse files
committed
Use type context for for loops
1 parent e5546fe commit 9efceb6

File tree

6 files changed

+38
-19
lines changed

6 files changed

+38
-19
lines changed

mypy/checker.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5309,10 +5309,16 @@ def get_types_from_except_handler(self, typ: Type, n: Expression) -> list[Type]:
53095309

53105310
def visit_for_stmt(self, s: ForStmt) -> None:
53115311
"""Type check a for statement."""
5312+
lvalue_type, b, c = self.check_lvalue(s.index)
5313+
if lvalue_type is not None:
5314+
context: Type | None = self.named_generic_type("typing.Iterable", [lvalue_type])
5315+
else:
5316+
context = None
5317+
53125318
if s.is_async:
5313-
iterator_type, item_type = self.analyze_async_iterable_item_type(s.expr)
5319+
iterator_type, item_type = self.analyze_async_iterable_item_type(s.expr, context)
53145320
else:
5315-
iterator_type, item_type = self.analyze_iterable_item_type(s.expr)
5321+
iterator_type, item_type = self.analyze_iterable_item_type(s.expr, context)
53165322
s.inferred_item_type = item_type
53175323
s.inferred_iterator_type = iterator_type
53185324

@@ -5324,21 +5330,25 @@ def visit_for_stmt(self, s: ForStmt) -> None:
53245330
),
53255331
)
53265332

5327-
def analyze_async_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]:
5333+
def analyze_async_iterable_item_type(
5334+
self, expr: Expression, context: Type | None = None
5335+
) -> tuple[Type, Type]:
53285336
"""Analyse async iterable expression and return iterator and iterator item types."""
53295337
echk = self.expr_checker
5330-
iterable = echk.accept(expr)
5338+
iterable = echk.accept(expr, context)
53315339
iterator = echk.check_method_call_by_name("__aiter__", iterable, [], [], expr)[0]
53325340
awaitable = echk.check_method_call_by_name("__anext__", iterator, [], [], expr)[0]
53335341
item_type = echk.check_awaitable_expr(
53345342
awaitable, expr, message_registry.INCOMPATIBLE_TYPES_IN_ASYNC_FOR
53355343
)
53365344
return iterator, item_type
53375345

5338-
def analyze_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]:
5346+
def analyze_iterable_item_type(
5347+
self, expr: Expression, context: Type | None = None
5348+
) -> tuple[Type, Type]:
53395349
"""Analyse iterable expression and return iterator and iterator item types."""
53405350
iterator, iterable = self.analyze_iterable_item_type_without_expression(
5341-
self.expr_checker.accept(expr), context=expr
5351+
self.expr_checker.accept(expr, context), context=expr
53425352
)
53435353
int_type = self.analyze_range_native_int_type(expr)
53445354
if int_type:

mypy/checkpattern.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,9 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
307307
narrowed_inner_types = []
308308
inner_rest_types = []
309309
for inner_type, new_inner_type in zip(inner_types, new_inner_types):
310+
# TODO: for loop type context should narrow on "assignment"?
311+
assert inner_type is not None
312+
310313
(narrowed_inner_type, inner_rest_type) = (
311314
self.chk.conditional_types_with_intersection(
312315
inner_type, [get_type_range(new_inner_type)], o, default=inner_type

mypyc/irbuild/ll_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2317,15 +2317,15 @@ def decompose_union_helper(
23172317
rest_items.append(item)
23182318
exit_block = BasicBlock()
23192319
result = Register(result_type)
2320-
for i, item in enumerate(fast_items):
2320+
for i, inst in enumerate(fast_items):
23212321
more_types = i < len(fast_items) - 1 or rest_items
23222322
if more_types:
23232323
# We are not at the final item so we need one more branch
2324-
op = self.isinstance_native(obj, item.class_ir, line)
2324+
op = self.isinstance_native(obj, inst.class_ir, line)
23252325
true_block, false_block = BasicBlock(), BasicBlock()
23262326
self.add_bool_branch(op, true_block, false_block)
23272327
self.activate_block(true_block)
2328-
coerced = self.coerce(obj, item, line)
2328+
coerced = self.coerce(obj, inst, line)
23292329
temp = process_item(coerced)
23302330
temp2 = self.coerce(temp, result_type, line)
23312331
self.add(Assign(result, temp2))

mypyc/test-data/irbuild-set.test

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -736,23 +736,20 @@ def not_precomputed() -> None:
736736
[out]
737737
def precomputed():
738738
r0 :: set
739-
r1, r2 :: object
740-
r3 :: str
741-
_ :: object
742-
r4 :: bit
739+
r1, r2, _ :: object
740+
r3 :: bit
743741
L0:
744742
r0 = frozenset({'False', 'None', 'True'})
745743
r1 = PyObject_GetIter(r0)
746744
L1:
747745
r2 = PyIter_Next(r1)
748746
if is_error(r2) goto L4 else goto L2
749747
L2:
750-
r3 = cast(str, r2)
751-
_ = r3
748+
_ = r2
752749
L3:
753750
goto L1
754751
L4:
755-
r4 = CPy_NoErrOccurred()
752+
r3 = CPy_NoErrOccurred()
756753
L5:
757754
return 1
758755
def precomputed2():

test-data/unit/check-inference-context.test

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,3 +1495,12 @@ def g(b: Optional[str]) -> None:
14951495
z: Callable[[], str] = lambda: reveal_type(b) # N: Revealed type is "builtins.str"
14961496
f2(lambda: reveal_type(b)) # N: Revealed type is "builtins.str"
14971497
lambda: reveal_type(b) # N: Revealed type is "builtins.str"
1498+
1499+
[case testInferenceForForLoops]
1500+
from typing import Literal
1501+
1502+
def func2() -> None:
1503+
b: Literal["foo", "bar", "baz"]
1504+
for b in ["foo", "bar"]:
1505+
# TODO: this should narrow to "foo" | "bar" ideally?
1506+
reveal_type(b) # N: Revealed type is "Union[Literal['foo'], Literal['bar'], Literal['baz']]"

test-data/unit/check-inference.test

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,11 +1149,11 @@ for x, (y, z) in [(A(), (B(), C()))]:
11491149
a = x
11501150
b = y
11511151
c = z
1152-
for xx, yy, zz in [(A(), B())]: # E: Need more than 2 values to unpack (3 expected)
1152+
for x2, y2, z2 in [(A(), B())]: # E: Need more than 2 values to unpack (3 expected)
11531153
pass
1154-
for xx, (yy, zz) in [(A(), B())]: # E: "B" object is not iterable
1154+
for x3, (y3, z3) in [(A(), B())]: # E: "B" object is not iterable
11551155
pass
1156-
for xxx, yyy in [(None, None)]:
1156+
for x4, y4 in [(None, None)]:
11571157
pass
11581158
[builtins fixtures/for.pyi]
11591159

0 commit comments

Comments
 (0)