diff --git a/mypy/checker.py b/mypy/checker.py index 4b3d6c3298b4..06087aa21d83 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3171,6 +3171,14 @@ def check_assignment( # Don't use type binder for definitions of special forms, like named tuples. if not (isinstance(lvalue, NameExpr) and lvalue.is_special_form): self.binder.assign_type(lvalue, rvalue_type, lvalue_type, False) + if ( + isinstance(lvalue, NameExpr) + and isinstance(lvalue.node, Var) + and lvalue.node.is_inferred + and lvalue.node.is_index_var + and lvalue_type is not None + ): + lvalue.node.type = remove_instance_last_known_values(lvalue_type) elif index_lvalue: self.check_indexed_assignment(index_lvalue, rvalue, lvalue) @@ -3180,6 +3188,7 @@ def check_assignment( rvalue_type = self.expr_checker.accept(rvalue, type_context=type_context) if not ( inferred.is_final + or inferred.is_index_var or (isinstance(lvalue, NameExpr) and lvalue.name == "__match_args__") ): rvalue_type = remove_instance_last_known_values(rvalue_type) diff --git a/mypy/nodes.py b/mypy/nodes.py index 4806a7b0be7d..7b620bd52008 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -969,6 +969,7 @@ def is_dynamic(self) -> bool: "is_classvar", "is_abstract_var", "is_final", + "is_index_var", "final_unset_in_class", "final_set_in_init", "explicit_self_type", @@ -1005,6 +1006,7 @@ class Var(SymbolNode): "is_classvar", "is_abstract_var", "is_final", + "is_index_var", "final_unset_in_class", "final_set_in_init", "is_suppressed_import", @@ -1039,6 +1041,7 @@ def __init__(self, name: str, type: mypy.types.Type | None = None) -> None: self.is_settable_property = False self.is_classvar = False self.is_abstract_var = False + self.is_index_var = False # Set to true when this variable refers to a module we were unable to # parse for some reason (eg a silenced module) self.is_suppressed_import = False diff --git a/mypy/semanal.py b/mypy/semanal.py index a0d4daffca0c..e182322cfbe6 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -4225,6 +4225,7 @@ def analyze_lvalue( is_final: bool = False, escape_comprehensions: bool = False, has_explicit_value: bool = False, + is_index_var: bool = False, ) -> None: """Analyze an lvalue or assignment target. @@ -4235,6 +4236,7 @@ def analyze_lvalue( escape_comprehensions: If we are inside a comprehension, set the variable in the enclosing scope instead. This implements https://www.python.org/dev/peps/pep-0572/#scope-of-the-target + is_index_var: If lval is the index variable in a for loop """ if escape_comprehensions: assert isinstance(lval, NameExpr), "assignment expression target must be NameExpr" @@ -4245,6 +4247,7 @@ def analyze_lvalue( is_final, escape_comprehensions, has_explicit_value=has_explicit_value, + is_index_var=is_index_var, ) elif isinstance(lval, MemberExpr): self.analyze_member_lvalue(lval, explicit_type, is_final, has_explicit_value) @@ -4271,6 +4274,7 @@ def analyze_name_lvalue( is_final: bool, escape_comprehensions: bool, has_explicit_value: bool, + is_index_var: bool, ) -> None: """Analyze an lvalue that targets a name expression. @@ -4309,7 +4313,9 @@ def analyze_name_lvalue( if (not existing or isinstance(existing.node, PlaceholderNode)) and not outer: # Define new variable. - var = self.make_name_lvalue_var(lvalue, kind, not explicit_type, has_explicit_value) + var = self.make_name_lvalue_var( + lvalue, kind, not explicit_type, has_explicit_value, is_index_var + ) added = self.add_symbol(name, var, lvalue, escape_comprehensions=escape_comprehensions) # Only bind expression if we successfully added name to symbol table. if added: @@ -4361,7 +4367,12 @@ def is_alias_for_final_name(self, name: str) -> bool: return existing is not None and is_final_node(existing.node) def make_name_lvalue_var( - self, lvalue: NameExpr, kind: int, inferred: bool, has_explicit_value: bool + self, + lvalue: NameExpr, + kind: int, + inferred: bool, + has_explicit_value: bool, + is_index_var: bool, ) -> Var: """Return a Var node for an lvalue that is a name expression.""" name = lvalue.name @@ -4380,6 +4391,7 @@ def make_name_lvalue_var( v._fullname = name v.is_ready = False # Type not inferred yet v.has_explicit_value = has_explicit_value + v.is_index_var = is_index_var return v def make_name_lvalue_point_to_existing_def( @@ -5290,7 +5302,7 @@ def visit_for_stmt(self, s: ForStmt) -> None: s.expr.accept(self) # Bind index variables and check if they define new names. - self.analyze_lvalue(s.index, explicit_type=s.index_type is not None) + self.analyze_lvalue(s.index, explicit_type=s.index_type is not None, is_index_var=True) if s.index_type: if self.is_classvar(s.index_type): self.fail_invalid_classvar(s.index) diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 0dbefbc774a3..b0ec590b54f9 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -1238,6 +1238,35 @@ class B: pass [builtins fixtures/for.pyi] [out] +[case testForStatementIndexNarrowing] +from typing_extensions import TypedDict + +class X(TypedDict): + hourly: int + daily: int + +x: X +for a in ("hourly", "daily"): + reveal_type(a) # N: Revealed type is "Union[Literal['hourly']?, Literal['daily']?]" + reveal_type(x[a]) # N: Revealed type is "builtins.int" + reveal_type(a.upper()) # N: Revealed type is "builtins.str" + c = a + reveal_type(c) # N: Revealed type is "builtins.str" + a = "monthly" + reveal_type(a) # N: Revealed type is "builtins.str" + a = "yearly" + reveal_type(a) # N: Revealed type is "builtins.str" + a = 1 # E: Incompatible types in assignment (expression has type "int", variable has type "str") + reveal_type(a) # N: Revealed type is "builtins.str" + d = a + reveal_type(d) # N: Revealed type is "builtins.str" + +b: str +for b in ("hourly", "daily"): + reveal_type(b) # N: Revealed type is "builtins.str" + reveal_type(b.upper()) # N: Revealed type is "builtins.str" +[builtins fixtures/for.pyi] + -- Regression tests -- ---------------- diff --git a/test-data/unit/fixtures/for.pyi b/test-data/unit/fixtures/for.pyi index 694f83e940b2..10f45e68cd7d 100644 --- a/test-data/unit/fixtures/for.pyi +++ b/test-data/unit/fixtures/for.pyi @@ -12,9 +12,11 @@ class type: pass class tuple(Generic[t]): def __iter__(self) -> Iterator[t]: pass class function: pass +class ellipsis: pass class bool: pass class int: pass # for convenience -class str: pass # for convenience +class str: # for convenience + def upper(self) -> str: ... class list(Iterable[t], Generic[t]): def __iter__(self) -> Iterator[t]: pass