diff --git a/mypy/checker.py b/mypy/checker.py index 9c389cccd95f..c7b092dfd30e 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6512,7 +6512,7 @@ def refine_parent_types(self, expr: Expression, expr_type: Type) -> Mapping[Expr # and create function that will try replaying the same lookup # operation against arbitrary types. if isinstance(expr, MemberExpr): - parent_expr = collapse_walrus(expr.expr) + parent_expr = self._propagate_walrus_assignments(expr.expr, output) parent_type = self.lookup_type_or_none(parent_expr) member_name = expr.name @@ -6535,9 +6535,10 @@ def replay_lookup(new_parent_type: ProperType) -> Type | None: return member_type elif isinstance(expr, IndexExpr): - parent_expr = collapse_walrus(expr.base) + parent_expr = self._propagate_walrus_assignments(expr.base, output) parent_type = self.lookup_type_or_none(parent_expr) + self._propagate_walrus_assignments(expr.index, output) index_type = self.lookup_type_or_none(expr.index) if index_type is None: return output @@ -6611,6 +6612,24 @@ def replay_lookup(new_parent_type: ProperType) -> Type | None: expr = parent_expr expr_type = output[parent_expr] = make_simplified_union(new_parent_types) + def _propagate_walrus_assignments( + self, expr: Expression, type_map: dict[Expression, Type] + ) -> Expression: + """Add assignments from walrus expressions to inferred types. + + Only considers nested assignment exprs, does not recurse into other types. + This may be added later if necessary by implementing a dedicated visitor. + """ + if isinstance(expr, AssignmentExpr): + if isinstance(expr.value, AssignmentExpr): + self._propagate_walrus_assignments(expr.value, type_map) + assigned_type = self.lookup_type_or_none(expr.value) + parent_expr = collapse_walrus(expr) + if assigned_type is not None: + type_map[parent_expr] = assigned_type + return parent_expr + return expr + def refine_identity_comparison_expression( self, operands: list[Expression], diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 25565946158e..140774781a5a 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -3979,3 +3979,95 @@ def check(mapping: Mapping[str, _T]) -> None: reveal_type(ok1) # N: Revealed type is "Union[_T`-1, builtins.str]" ok2: Union[_T, str] = mapping.get("", "") [builtins fixtures/tuple.pyi] + +[case testInferWalrusAssignmentAttrInCondition] +class Foo: + def __init__(self, value: bool) -> None: + self.value = value + +def check_and(maybe: bool) -> None: + foo = None + if maybe and (foo := Foo(True)).value: + reveal_type(foo) # N: Revealed type is "__main__.Foo" + else: + reveal_type(foo) # N: Revealed type is "Union[__main__.Foo, None]" + +def check_and_nested(maybe: bool) -> None: + foo = None + bar = None + baz = None + if maybe and (foo := (bar := (baz := Foo(True)))).value: + reveal_type(foo) # N: Revealed type is "__main__.Foo" + reveal_type(bar) # N: Revealed type is "__main__.Foo" + reveal_type(baz) # N: Revealed type is "__main__.Foo" + else: + reveal_type(foo) # N: Revealed type is "Union[__main__.Foo, None]" + reveal_type(bar) # N: Revealed type is "Union[__main__.Foo, None]" + reveal_type(baz) # N: Revealed type is "Union[__main__.Foo, None]" + +def check_or(maybe: bool) -> None: + foo = None + if maybe or (foo := Foo(True)).value: + reveal_type(foo) # N: Revealed type is "Union[__main__.Foo, None]" + else: + reveal_type(foo) # N: Revealed type is "__main__.Foo" + +def check_or_nested(maybe: bool) -> None: + foo = None + bar = None + baz = None + if maybe and (foo := (bar := (baz := Foo(True)))).value: + reveal_type(foo) # N: Revealed type is "__main__.Foo" + reveal_type(bar) # N: Revealed type is "__main__.Foo" + reveal_type(baz) # N: Revealed type is "__main__.Foo" + else: + reveal_type(foo) # N: Revealed type is "Union[__main__.Foo, None]" + reveal_type(bar) # N: Revealed type is "Union[__main__.Foo, None]" + reveal_type(baz) # N: Revealed type is "Union[__main__.Foo, None]" + +[case testInferWalrusAssignmentIndexInCondition] +def check_and(maybe: bool) -> None: + foo = None + bar = None + if maybe and (foo := [1])[(bar := 0)]: + reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]" + reveal_type(bar) # N: Revealed type is "builtins.int" + else: + reveal_type(foo) # N: Revealed type is "Union[builtins.list[builtins.int], None]" + reveal_type(bar) # N: Revealed type is "Union[builtins.int, None]" + +def check_and_nested(maybe: bool) -> None: + foo = None + bar = None + baz = None + if maybe and (foo := (bar := (baz := [1])))[0]: + reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]" + reveal_type(bar) # N: Revealed type is "builtins.list[builtins.int]" + reveal_type(baz) # N: Revealed type is "builtins.list[builtins.int]" + else: + reveal_type(foo) # N: Revealed type is "Union[builtins.list[builtins.int], None]" + reveal_type(bar) # N: Revealed type is "Union[builtins.list[builtins.int], None]" + reveal_type(baz) # N: Revealed type is "Union[builtins.list[builtins.int], None]" + +def check_or(maybe: bool) -> None: + foo = None + bar = None + if maybe or (foo := [1])[(bar := 0)]: + reveal_type(foo) # N: Revealed type is "Union[builtins.list[builtins.int], None]" + reveal_type(bar) # N: Revealed type is "Union[builtins.int, None]" + else: + reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]" + reveal_type(bar) # N: Revealed type is "builtins.int" + +def check_or_nested(maybe: bool) -> None: + foo = None + bar = None + baz = None + if maybe or (foo := (bar := (baz := [1])))[0]: + reveal_type(foo) # N: Revealed type is "Union[builtins.list[builtins.int], None]" + reveal_type(bar) # N: Revealed type is "Union[builtins.list[builtins.int], None]" + reveal_type(baz) # N: Revealed type is "Union[builtins.list[builtins.int], None]" + else: + reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]" + reveal_type(bar) # N: Revealed type is "builtins.list[builtins.int]" + reveal_type(baz) # N: Revealed type is "builtins.list[builtins.int]"