Skip to content

Commit 83bb7fc

Browse files
committed
Support immediate and nested walruses when they are part of if/else clauses
1 parent e7405c9 commit 83bb7fc

File tree

2 files changed

+113
-2
lines changed

2 files changed

+113
-2
lines changed

mypy/checker.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6526,7 +6526,7 @@ def refine_parent_types(self, expr: Expression, expr_type: Type) -> Mapping[Expr
65266526
# and create function that will try replaying the same lookup
65276527
# operation against arbitrary types.
65286528
if isinstance(expr, MemberExpr):
6529-
parent_expr = collapse_walrus(expr.expr)
6529+
parent_expr = self._propagate_walrus_assignments(expr.expr, output)
65306530
parent_type = self.lookup_type_or_none(parent_expr)
65316531
member_name = expr.name
65326532

@@ -6549,9 +6549,10 @@ def replay_lookup(new_parent_type: ProperType) -> Type | None:
65496549
return member_type
65506550

65516551
elif isinstance(expr, IndexExpr):
6552-
parent_expr = collapse_walrus(expr.base)
6552+
parent_expr = self._propagate_walrus_assignments(expr.base, output)
65536553
parent_type = self.lookup_type_or_none(parent_expr)
65546554

6555+
self._propagate_walrus_assignments(expr.index, output)
65556556
index_type = self.lookup_type_or_none(expr.index)
65566557
if index_type is None:
65576558
return output
@@ -6625,6 +6626,24 @@ def replay_lookup(new_parent_type: ProperType) -> Type | None:
66256626
expr = parent_expr
66266627
expr_type = output[parent_expr] = make_simplified_union(new_parent_types)
66276628

6629+
def _propagate_walrus_assignments(
6630+
self, expr: Expression, type_map: dict[Expression, Type]
6631+
) -> Expression:
6632+
"""Add assignments from walrus expressions to inferred types.
6633+
6634+
Only considers nested assignment exprs, does not recurse into other types.
6635+
This may be added later if necessary by implementing a dedicated visitor.
6636+
"""
6637+
if isinstance(expr, AssignmentExpr):
6638+
if isinstance(expr.value, AssignmentExpr):
6639+
self._propagate_walrus_assignments(expr.value, type_map)
6640+
assigned_type = self.lookup_type_or_none(expr.value)
6641+
parent_expr = collapse_walrus(expr)
6642+
if assigned_type is not None:
6643+
type_map[parent_expr] = assigned_type
6644+
return parent_expr
6645+
return expr
6646+
66286647
def refine_identity_comparison_expression(
66296648
self,
66306649
operands: list[Expression],

test-data/unit/check-inference.test

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3963,3 +3963,95 @@ def f() -> None:
39633963

39643964
# The type below should not be Any.
39653965
reveal_type(x) # N: Revealed type is "builtins.int"
3966+
3967+
[case testInferWalrusAssignmentAttrInCondition]
3968+
class Foo:
3969+
def __init__(self, value: bool) -> None:
3970+
self.value = value
3971+
3972+
def check_and(maybe: bool) -> None:
3973+
foo = None
3974+
if maybe and (foo := Foo(True)).value:
3975+
reveal_type(foo) # N: Revealed type is "__main__.Foo"
3976+
else:
3977+
reveal_type(foo) # N: Revealed type is "Union[__main__.Foo, None]"
3978+
3979+
def check_and_nested(maybe: bool) -> None:
3980+
foo = None
3981+
bar = None
3982+
baz = None
3983+
if maybe and (foo := (bar := (baz := Foo(True)))).value:
3984+
reveal_type(foo) # N: Revealed type is "__main__.Foo"
3985+
reveal_type(bar) # N: Revealed type is "__main__.Foo"
3986+
reveal_type(baz) # N: Revealed type is "__main__.Foo"
3987+
else:
3988+
reveal_type(foo) # N: Revealed type is "Union[__main__.Foo, None]"
3989+
reveal_type(bar) # N: Revealed type is "Union[__main__.Foo, None]"
3990+
reveal_type(baz) # N: Revealed type is "Union[__main__.Foo, None]"
3991+
3992+
def check_or(maybe: bool) -> None:
3993+
foo = None
3994+
if maybe or (foo := Foo(True)).value:
3995+
reveal_type(foo) # N: Revealed type is "Union[__main__.Foo, None]"
3996+
else:
3997+
reveal_type(foo) # N: Revealed type is "__main__.Foo"
3998+
3999+
def check_or_nested(maybe: bool) -> None:
4000+
foo = None
4001+
bar = None
4002+
baz = None
4003+
if maybe and (foo := (bar := (baz := Foo(True)))).value:
4004+
reveal_type(foo) # N: Revealed type is "__main__.Foo"
4005+
reveal_type(bar) # N: Revealed type is "__main__.Foo"
4006+
reveal_type(baz) # N: Revealed type is "__main__.Foo"
4007+
else:
4008+
reveal_type(foo) # N: Revealed type is "Union[__main__.Foo, None]"
4009+
reveal_type(bar) # N: Revealed type is "Union[__main__.Foo, None]"
4010+
reveal_type(baz) # N: Revealed type is "Union[__main__.Foo, None]"
4011+
4012+
[case testInferWalrusAssignmentIndexInCondition]
4013+
def check_and(maybe: bool) -> None:
4014+
foo = None
4015+
bar = None
4016+
if maybe and (foo := [1])[bar := 0]:
4017+
reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]"
4018+
reveal_type(bar) # N: Revealed type is "builtins.int"
4019+
else:
4020+
reveal_type(foo) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
4021+
reveal_type(bar) # N: Revealed type is "Union[builtins.int, None]"
4022+
4023+
def check_and_nested(maybe: bool) -> None:
4024+
foo = None
4025+
bar = None
4026+
baz = None
4027+
if maybe and (foo := (bar := (baz := [1])))[0]:
4028+
reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]"
4029+
reveal_type(bar) # N: Revealed type is "builtins.list[builtins.int]"
4030+
reveal_type(baz) # N: Revealed type is "builtins.list[builtins.int]"
4031+
else:
4032+
reveal_type(foo) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
4033+
reveal_type(bar) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
4034+
reveal_type(baz) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
4035+
4036+
def check_or(maybe: bool) -> None:
4037+
foo = None
4038+
bar = None
4039+
if maybe or (foo := [1])[bar := 0]:
4040+
reveal_type(foo) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
4041+
reveal_type(bar) # N: Revealed type is "Union[builtins.int, None]"
4042+
else:
4043+
reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]"
4044+
reveal_type(bar) # N: Revealed type is "builtins.int"
4045+
4046+
def check_or_nested(maybe: bool) -> None:
4047+
foo = None
4048+
bar = None
4049+
baz = None
4050+
if maybe or (foo := (bar := (baz := [1])))[0]:
4051+
reveal_type(foo) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
4052+
reveal_type(bar) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
4053+
reveal_type(baz) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
4054+
else:
4055+
reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]"
4056+
reveal_type(bar) # N: Revealed type is "builtins.list[builtins.int]"
4057+
reveal_type(baz) # N: Revealed type is "builtins.list[builtins.int]"

0 commit comments

Comments
 (0)