Skip to content

Commit 92eb83c

Browse files
refactor visit_conditional_expr
1 parent 5b7279b commit 92eb83c

File tree

4 files changed

+154
-54
lines changed

4 files changed

+154
-54
lines changed

mypy/checkexpr.py

Lines changed: 14 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@
141141
get_type_vars,
142142
is_literal_type_like,
143143
make_simplified_union,
144-
simple_literal_type,
145144
true_only,
146145
try_expanding_sum_type_to_union,
147146
try_getting_str_literals,
@@ -5888,63 +5887,26 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F
58885887
elif else_map is None:
58895888
self.msg.redundant_condition_in_if(True, e.cond)
58905889

5890+
if ctx is None:
5891+
# When no context is provided, compute each branch individually, and
5892+
# use the union of the results as artificial context. Important for:
5893+
# - testUnificationDict
5894+
# - testConditionalExpressionWithEmpty
5895+
ctx_if_type = self.analyze_cond_branch(
5896+
if_map, e.if_expr, context=ctx, allow_none_return=allow_none_return
5897+
)
5898+
ctx_else_type = self.analyze_cond_branch(
5899+
else_map, e.else_expr, context=ctx, allow_none_return=allow_none_return
5900+
)
5901+
ctx = make_simplified_union([ctx_if_type, ctx_else_type])
5902+
58915903
if_type = self.analyze_cond_branch(
58925904
if_map, e.if_expr, context=ctx, allow_none_return=allow_none_return
58935905
)
5894-
5895-
# we want to keep the narrowest value of if_type for union'ing the branches
5896-
# however, it would be silly to pass a literal as a type context. Pass the
5897-
# underlying fallback type instead.
5898-
if_type_fallback = simple_literal_type(get_proper_type(if_type)) or if_type
5899-
5900-
# Analyze the right branch using full type context and store the type
5901-
full_context_else_type = self.analyze_cond_branch(
5906+
else_type = self.analyze_cond_branch(
59025907
else_map, e.else_expr, context=ctx, allow_none_return=allow_none_return
59035908
)
59045909

5905-
if not mypy.checker.is_valid_inferred_type(if_type, self.chk.options):
5906-
# Analyze the right branch disregarding the left branch.
5907-
else_type = full_context_else_type
5908-
# we want to keep the narrowest value of else_type for union'ing the branches
5909-
# however, it would be silly to pass a literal as a type context. Pass the
5910-
# underlying fallback type instead.
5911-
else_type_fallback = simple_literal_type(get_proper_type(else_type)) or else_type
5912-
5913-
# If it would make a difference, re-analyze the left
5914-
# branch using the right branch's type as context.
5915-
if ctx is None or not is_equivalent(else_type_fallback, ctx):
5916-
# TODO: If it's possible that the previous analysis of
5917-
# the left branch produced errors that are avoided
5918-
# using this context, suppress those errors.
5919-
if_type = self.analyze_cond_branch(
5920-
if_map,
5921-
e.if_expr,
5922-
context=else_type_fallback,
5923-
allow_none_return=allow_none_return,
5924-
)
5925-
5926-
elif if_type_fallback == ctx:
5927-
# There is no point re-running the analysis if if_type is equal to ctx.
5928-
# That would be an exact duplicate of the work we just did.
5929-
# This optimization is particularly important to avoid exponential blowup with nested
5930-
# if/else expressions: https://github.com/python/mypy/issues/9591
5931-
# TODO: would checking for is_proper_subtype also work and cover more cases?
5932-
else_type = full_context_else_type
5933-
else:
5934-
# Analyze the right branch in the context of the left
5935-
# branch's type.
5936-
else_type = self.analyze_cond_branch(
5937-
else_map,
5938-
e.else_expr,
5939-
context=if_type_fallback,
5940-
allow_none_return=allow_none_return,
5941-
)
5942-
5943-
# In most cases using if_type as a context for right branch gives better inferred types.
5944-
# This is however not the case for literal types, so use the full context instead.
5945-
if is_literal_type_like(full_context_else_type) and not is_literal_type_like(else_type):
5946-
else_type = full_context_else_type
5947-
59485910
res: Type = make_simplified_union([if_type, else_type])
59495911
if has_uninhabited_component(res) and not isinstance(
59505912
get_proper_type(self.type_context[-1]), UnionType

mypyc/irbuild/statement.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,9 @@ def make_entry(type: Expression) -> tuple[ValueGenFunc, int]:
599599
(make_entry(type) if type else None, var, make_handler(body))
600600
for type, var, body in zip(t.types, t.vars, t.handlers)
601601
]
602-
else_body = (lambda: builder.accept(t.else_body)) if t.else_body else None
602+
603+
_else_body = t.else_body
604+
else_body = (lambda: builder.accept(_else_body)) if _else_body else None
603605
transform_try_except(builder, body, handlers, else_body, t.line)
604606

605607

test-data/unit/check-literal.test

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2946,6 +2946,140 @@ reveal_type(C().collection) # N: Revealed type is "builtins.list[Literal['word'
29462946
reveal_type(C().word) # N: Revealed type is "Literal['word']"
29472947
[builtins fixtures/tuple.pyi]
29482948

2949+
[case testStringLiteralTernary]
2950+
def test(b: bool) -> None:
2951+
l = "foo" if b else "bar"
2952+
reveal_type(l) # N: Revealed type is "builtins.str"
2953+
[builtins fixtures/tuple.pyi]
2954+
2955+
[case testintLiteralTernary]
2956+
def test(b: bool) -> None:
2957+
l = 0 if b else 1
2958+
reveal_type(l) # N: Revealed type is "builtins.int"
2959+
[builtins fixtures/tuple.pyi]
2960+
2961+
[case testStringIntUnionTernary]
2962+
def test(b: bool) -> None:
2963+
l = 1 if b else "a"
2964+
reveal_type(l) # N: Revealed type is "Union[builtins.int, builtins.str]"
2965+
[builtins fixtures/tuple.pyi]
2966+
2967+
[case testListComprehensionTernary]
2968+
# gh-19534
2969+
def test(b: bool) -> None:
2970+
l = [1] if b else ["a"]
2971+
reveal_type(l) # N: Revealed type is "Union[builtins.list[builtins.int], builtins.list[builtins.str]]"
2972+
[builtins fixtures/list.pyi]
2973+
2974+
[case testSetComprehensionTernary]
2975+
# gh-19534
2976+
def test(b: bool) -> None:
2977+
s = {1} if b else {"a"}
2978+
reveal_type(s) # N: Revealed type is "Union[builtins.set[builtins.int], builtins.set[builtins.str]]"
2979+
[builtins fixtures/set.pyi]
2980+
2981+
[case testDictComprehensionTernary]
2982+
# gh-19534
2983+
def test(b: bool) -> None:
2984+
d = {1:1} if "" else {"a": "a"}
2985+
reveal_type(d) # N: Revealed type is "Union[builtins.dict[builtins.int, builtins.int], builtins.dict[builtins.str, builtins.str]]"
2986+
[builtins fixtures/dict.pyi]
2987+
2988+
[case testLambdaTernary]
2989+
from typing import TypeVar, Union, Callable, reveal_type
2990+
2991+
NOOP = lambda: None
2992+
class A: pass
2993+
class B:
2994+
attr: Union[A, None]
2995+
2996+
def test_static(x: Union[A, None]) -> None:
2997+
def foo(t: A) -> None: ...
2998+
2999+
l1: Callable[[], object] = (lambda: foo(x)) if x is not None else NOOP
3000+
r1: Callable[[], object] = NOOP if x is None else (lambda: foo(x))
3001+
l2 = (lambda: foo(x)) if x is not None else NOOP
3002+
r2 = NOOP if x is None else (lambda: foo(x))
3003+
reveal_type(l2) # N: Revealed type is "def ()"
3004+
reveal_type(r2) # N: Revealed type is "def ()"
3005+
3006+
def test_generic(x: Union[A, None]) -> None:
3007+
T = TypeVar("T")
3008+
def bar(t: T) -> T: return t
3009+
3010+
l1: Callable[[], None] = (lambda: bar(x)) if x is None else NOOP
3011+
r1: Callable[[], None] = NOOP if x is not None else (lambda: bar(x))
3012+
l2 = (lambda: bar(x)) if x is None else NOOP
3013+
r2 = NOOP if x is not None else (lambda: bar(x))
3014+
reveal_type(l2) # N: Revealed type is "def ()"
3015+
reveal_type(r2) # N: Revealed type is "def ()"
3016+
3017+
3018+
[case testLambdaTernaryIndirectAttribute]
3019+
# fails due to binder issue inside `check_func_def`
3020+
# gh-19561
3021+
from typing import TypeVar, Union, Callable, reveal_type
3022+
3023+
NOOP = lambda: None
3024+
class A: pass
3025+
class B:
3026+
attr: Union[A, None]
3027+
3028+
def test_static_with_attr(x: B) -> None:
3029+
def foo(t: A) -> None: ...
3030+
3031+
l1: Callable[[], None] = (lambda: foo(x.attr)) if x.attr is not None else NOOP # E: Argument 1 to "foo" has incompatible type "Optional[A]"; expected "A"
3032+
r1: Callable[[], None] = NOOP if x.attr is None else (lambda: foo(x.attr)) # E: Argument 1 to "foo" has incompatible type "Optional[A]"; expected "A"
3033+
l2 = (lambda: foo(x.attr)) if x.attr is not None else NOOP # E: Argument 1 to "foo" has incompatible type "Optional[A]"; expected "A"
3034+
r2 = NOOP if x.attr is None else (lambda: foo(x.attr)) # E: Argument 1 to "foo" has incompatible type "Optional[A]"; expected "A"
3035+
reveal_type(l2) # N: Revealed type is "def ()"
3036+
reveal_type(r2) # N: Revealed type is "def ()"
3037+
3038+
def test_generic_with_attr(x: B) -> None:
3039+
T = TypeVar("T")
3040+
def bar(t: T) -> T: return t
3041+
3042+
l1: Callable[[], None] = (lambda: bar(x.attr)) if x.attr is None else NOOP # E: Incompatible types in assignment (expression has type "Callable[[], Optional[A]]", variable has type "Callable[[], None]")
3043+
r1: Callable[[], None] = NOOP if x.attr is not None else (lambda: bar(x.attr)) # E: Incompatible types in assignment (expression has type "Callable[[], Optional[A]]", variable has type "Callable[[], None]")
3044+
l2 = (lambda: bar(x.attr)) if x.attr is None else NOOP
3045+
r2 = NOOP if x.attr is not None else (lambda: bar(x.attr))
3046+
reveal_type(l2) # N: Revealed type is "def () -> Union[__main__.A, None]"
3047+
reveal_type(r2) # N: Revealed type is "def () -> Union[__main__.A, None]"
3048+
3049+
[case testLambdaTernaryDoubleIndirectAttribute]
3050+
# fails due to binder issue inside `check_func_def`
3051+
# gh-19561
3052+
from typing import TypeVar, Union, Callable, reveal_type
3053+
3054+
NOOP = lambda: None
3055+
class A: pass
3056+
class B:
3057+
attr: Union[A, None]
3058+
class C:
3059+
attr: B
3060+
3061+
def test_static_with_attr(x: C) -> None:
3062+
def foo(t: A) -> None: ...
3063+
3064+
l1: Callable[[], None] = (lambda: foo(x.attr.attr)) if x.attr.attr is not None else NOOP # E: Argument 1 to "foo" has incompatible type "Optional[A]"; expected "A"
3065+
r1: Callable[[], None] = NOOP if x.attr.attr is None else (lambda: foo(x.attr.attr)) # E: Argument 1 to "foo" has incompatible type "Optional[A]"; expected "A"
3066+
l2 = (lambda: foo(x.attr.attr)) if x.attr.attr is not None else NOOP # E: Argument 1 to "foo" has incompatible type "Optional[A]"; expected "A"
3067+
r2 = NOOP if x.attr.attr is None else (lambda: foo(x.attr.attr)) # E: Argument 1 to "foo" has incompatible type "Optional[A]"; expected "A"
3068+
reveal_type(l2) # N: Revealed type is "def ()"
3069+
reveal_type(r2) # N: Revealed type is "def ()"
3070+
3071+
def test_generic_with_attr(x: C) -> None:
3072+
T = TypeVar("T")
3073+
def bar(t: T) -> T: return t
3074+
3075+
l1: Callable[[], None] = (lambda: bar(x.attr.attr)) if x.attr.attr is None else NOOP # E: Incompatible types in assignment (expression has type "Callable[[], Optional[A]]", variable has type "Callable[[], None]")
3076+
r1: Callable[[], None] = NOOP if x.attr.attr is not None else (lambda: bar(x.attr.attr)) # E: Incompatible types in assignment (expression has type "Callable[[], Optional[A]]", variable has type "Callable[[], None]")
3077+
l2 = (lambda: bar(x.attr.attr)) if x.attr.attr is None else NOOP
3078+
r2 = NOOP if x.attr.attr is not None else (lambda: bar(x.attr.attr))
3079+
reveal_type(l2) # N: Revealed type is "def () -> Union[__main__.A, None]"
3080+
reveal_type(r2) # N: Revealed type is "def () -> Union[__main__.A, None]"
3081+
3082+
29493083
[case testLiteralTernaryUnionNarrowing]
29503084
from typing import Literal, Optional
29513085

test-data/unit/check-optional.test

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,9 @@ reveal_type(l) # N: Revealed type is "builtins.list[typing.Generator[builtins.s
427427
[builtins fixtures/list.pyi]
428428

429429
[case testNoneListTernary]
430-
x = [None] if "" else [1] # E: List item 0 has incompatible type "int"; expected "None"
430+
# gh-19534
431+
x = [None] if "" else [1]
432+
reveal_type(x) # N: Revealed type is "Union[builtins.list[None], builtins.list[builtins.int]]"
431433
[builtins fixtures/list.pyi]
432434

433435
[case testListIncompatibleErrorMessage]

0 commit comments

Comments
 (0)