Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 69 additions & 7 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections import defaultdict
from collections.abc import Iterable, Iterator, Sequence
from contextlib import contextmanager
from typing import Callable, ClassVar, Final, Optional, cast, overload
from typing import Any, Callable, ClassVar, Final, Optional, cast, overload
from typing_extensions import TypeAlias as _TypeAlias, assert_never

import mypy.checker
Expand Down Expand Up @@ -212,6 +212,10 @@
# see https://github.com/python/mypy/pull/5255#discussion_r196896335 for discussion.
MAX_UNIONS: Final = 5

# Use fallback type if literal addition of unions results in too many literal
# values. Explicitly set on the safe side to prevent accidental issues.
MAX_LITERAL_ADDITION_VALUES: Final = 15


# Types considered safe for comparisons with --strict-equality due to known behaviour of __eq__.
# NOTE: All these types are subtypes of AbstractSet.
Expand Down Expand Up @@ -3487,12 +3491,13 @@ def visit_op_expr(self, e: OpExpr) -> Type:
if isinstance(e.left, StrExpr):
return self.strfrm_checker.check_str_interpolation(e.left, e.right)
left_type = self.accept(e.left)

right_type = self.accept(e.right)
proper_left_type = get_proper_type(left_type)
proper_right_type = get_proper_type(right_type)

if isinstance(proper_left_type, TupleType) and e.op == "+":
left_add_method = proper_left_type.partial_fallback.type.get("__add__")
if left_add_method and left_add_method.fullname == "builtins.tuple.__add__":
proper_right_type = get_proper_type(self.accept(e.right))
if isinstance(proper_right_type, TupleType):
right_radd_method = proper_right_type.partial_fallback.type.get("__radd__")
if right_radd_method is None:
Expand Down Expand Up @@ -3520,20 +3525,21 @@ def visit_op_expr(self, e: OpExpr) -> Type:
items=proper_left_type.items + [UnpackType(mapped)]
)

if e.op == "+" and (result := self.literal_expression_addition(e, left_type, right_type)):
return result

use_reverse: UseReverse = USE_REVERSE_DEFAULT
if e.op == "|":
if is_named_instance(proper_left_type, "builtins.dict"):
# This is a special case for `dict | TypedDict`.
# 1. Find `dict | TypedDict` case
# 2. Switch `dict.__or__` to `TypedDict.__ror__` (the same from both runtime and typing perspective)
proper_right_type = get_proper_type(self.accept(e.right))
if isinstance(proper_right_type, TypedDictType):
use_reverse = USE_REVERSE_ALWAYS
if isinstance(proper_left_type, TypedDictType):
# This is the reverse case: `TypedDict | dict`,
# simply do not allow the reverse checking:
# do not call `__dict__.__ror__`.
proper_right_type = get_proper_type(self.accept(e.right))
if is_named_instance(proper_right_type, "builtins.dict"):
use_reverse = USE_REVERSE_NEVER

Expand All @@ -3544,7 +3550,6 @@ def visit_op_expr(self, e: OpExpr) -> Type:
and isinstance(proper_left_type, Instance)
and proper_left_type.type.fullname == "builtins.tuple"
):
proper_right_type = get_proper_type(self.accept(e.right))
if (
isinstance(proper_right_type, TupleType)
and proper_right_type.partial_fallback.type.fullname == "builtins.tuple"
Expand All @@ -3568,7 +3573,7 @@ def visit_op_expr(self, e: OpExpr) -> Type:
result, method_type = self.check_op(
# The reverse operator here gives better error messages:
operators.reverse_op_methods[method],
base_type=self.accept(e.right),
base_type=right_type,
arg=e.left,
context=e,
allow_reverse=False,
Expand All @@ -3580,6 +3585,63 @@ def visit_op_expr(self, e: OpExpr) -> Type:
else:
raise RuntimeError(f"Unknown operator {e.op}")

def literal_value_from_expr(
    self, expr: Expression, typ: Type
) -> tuple[list[Any], str, bool] | None:
    """Collect the literal values an operand of ``+`` may evaluate to.

    Args:
        expr: The operand expression.
        typ: The type that was inferred for ``expr``.

    Returns:
        A ``(values, fallback_fullname, from_literal_type)`` triple, or ``None``
        when the operand has no statically known str/int/bytes values.
        ``from_literal_type`` is ``False`` when the values were read directly
        from a literal expression and ``True`` when they come from a
        ``LiteralType`` (or a union made up only of ``LiteralType`` items).
    """
    # Literal expressions carry their value directly; no type inspection needed.
    if isinstance(expr, StrExpr):
        return [expr.value], "builtins.str", False
    if isinstance(expr, IntExpr):
        return [expr.value], "builtins.int", False
    if isinstance(expr, BytesExpr):
        return [expr.value], "builtins.bytes", False

    # Only plain str/int/bytes literals may be folded. An enum literal also
    # stores a ``str`` value (the member name) but uses the enum class as its
    # fallback; "adding" member names would fabricate a non-existent member and
    # bypass the enum's own ``__add__``, so every other fallback is rejected.
    addable_fallbacks = ("builtins.str", "builtins.int", "builtins.bytes")

    ptype = get_proper_type(typ)

    if isinstance(ptype, LiteralType):
        fullname = ptype.fallback.type.fullname
        if fullname in addable_fallbacks and not isinstance(ptype.value, (bool, float)):
            return [ptype.value], fullname, True
        return None

    if isinstance(ptype, UnionType):
        fallback: str | None = None
        values: list[str | int] = []
        for item in ptype.items:
            pitem = get_proper_type(item)
            # Every union item has to be a foldable literal ...
            if not isinstance(pitem, LiteralType) or isinstance(pitem.value, (float, bool)):
                return None
            item_fallback = pitem.fallback.type.fullname
            if item_fallback not in addable_fallbacks:
                return None
            # ... and all items have to share one fallback type.
            if fallback is None:
                fallback = item_fallback
            elif fallback != item_fallback:
                return None
            values.append(pitem.value)
        if fallback is None:
            # Defensive: a union without items yields no values to combine.
            return None
        return values, fallback, True

    return None

def literal_expression_addition(
    self, e: OpExpr, left_type: Type, right_type: Type
) -> Type | None:
    """Try to fold ``left + right`` into a literal type.

    Both operands are resolved through ``literal_value_from_expr``. Folding
    happens only if both yield values, both share the same fallback type and
    at least one of them originates from a ``LiteralType``.

    Returns:
        A single ``LiteralType`` when exactly one distinct result exists, a
        simplified union of ``LiteralType`` items when there are several, or
        ``None`` when folding is not applicable or would produce more than
        ``MAX_LITERAL_ADDITION_VALUES`` distinct values.
    """
    assert e.op == "+"

    left = self.literal_value_from_expr(e.left, left_type)
    if not left:
        return None
    right = self.literal_value_from_expr(e.right, right_type)
    if not right:
        return None

    left_values, left_fallback, left_from_literal = left
    right_values, right_fallback, right_from_literal = right

    # Operands with different fallback types cannot be combined.
    if left_fallback != right_fallback:
        return None
    # Two plain literal expressions keep their regular (non-literal) inference.
    if not left_from_literal and not right_from_literal:
        return None

    # Every pairwise sum, de-duplicated and put into a deterministic order.
    distinct_sums = {lhs + rhs for lhs in left_values for rhs in right_values}
    ordered: list[int | str] = sorted(distinct_sums)

    if len(ordered) == 1:
        return LiteralType(ordered[0], self.named_type(left_fallback))
    if len(ordered) > MAX_LITERAL_ADDITION_VALUES:
        # Too many combinations; let the caller fall back to the regular check.
        return None

    literal_items = []
    for value in ordered:
        literal_items.append(LiteralType(value, self.named_type(left_fallback)))
    return make_simplified_union(literal_items)

def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
"""Type check a comparison expression.

Expand Down
138 changes: 130 additions & 8 deletions test-data/unit/check-literal.test
Original file line number Diff line number Diff line change
Expand Up @@ -1407,22 +1407,23 @@ c: Literal[4]
d: Literal['foo']
e: str

reveal_type(a + a) # N: Revealed type is "builtins.int"
reveal_type(a + a) # N: Revealed type is "Literal[6]"
reveal_type(a + b) # N: Revealed type is "builtins.int"
reveal_type(b + a) # N: Revealed type is "builtins.int"
reveal_type(a + 1) # N: Revealed type is "builtins.int"
reveal_type(1 + a) # N: Revealed type is "builtins.int"
reveal_type(a + c) # N: Revealed type is "builtins.int"
reveal_type(c + a) # N: Revealed type is "builtins.int"
reveal_type(a + 1) # N: Revealed type is "Literal[4]"
reveal_type(1 + a) # N: Revealed type is "Literal[4]"
reveal_type(a + c) # N: Revealed type is "Literal[7]"
reveal_type(c + a) # N: Revealed type is "Literal[7]"

reveal_type(d + d) # N: Revealed type is "builtins.str"
reveal_type(d + d) # N: Revealed type is "Literal['foofoo']"
reveal_type(d + e) # N: Revealed type is "builtins.str"
reveal_type(e + d) # N: Revealed type is "builtins.str"
reveal_type(d + 'foo') # N: Revealed type is "builtins.str"
reveal_type('foo' + d) # N: Revealed type is "builtins.str"
reveal_type(d + 'bar') # N: Revealed type is "Literal['foobar']"
reveal_type('bar' + d) # N: Revealed type is "Literal['barfoo']"

reveal_type(a.__add__(b)) # N: Revealed type is "builtins.int"
reveal_type(b.__add__(a)) # N: Revealed type is "builtins.int"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check a.__add__(a). It would be nice if it was consistent with a + a, though it's probably not important.

Copy link
Collaborator Author

@cdce8p cdce8p Nov 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The result will be int even if a: Literal[3]. The narrowing is only applied for the a + a case currently. The method call to __add__ is handled by self.check_op as the fallback case which I haven't modified here.

Test case added in 36f5e2e

reveal_type(a.__add__(a)) # N: Revealed type is "builtins.int"

a *= b # E: Incompatible types in assignment (expression has type "int", variable has type "Literal[3]")
b *= a
Expand Down Expand Up @@ -2976,3 +2977,124 @@ x: Type[Literal[1]] # E: Type[...] can't contain "Literal[...]"
y: Type[Union[Literal[1], Literal[2]]] # E: Type[...] can't contain "Union[Literal[...], Literal[...]]"
z: Type[Literal[1, 2]] # E: Type[...] can't contain "Union[Literal[...], Literal[...]]"
[builtins fixtures/tuple.pyi]

[case testLiteralAddition]
from typing import Any, Union
from typing_extensions import Literal

class A:
def __add__(self, other: str) -> str: ...
def __radd__(self, other: str) -> str: ...

str_a: Literal["a"]
str_b: Literal["b"]
str_union_1: Literal["a", "b"]
str_union_2: Literal["d", "c"]
str_union_mixed_1: Union[Literal["a"], Any]
str_union_mixed_2: Union[Literal["a"], A]
s: str
int_1: Literal[1]
int_2: Literal[2]
int_union_1: Literal[1, 2]
int_union_2: Literal[4, 3]
i: int
bytes_a: Literal[b"a"]
bytes_b: Literal[b"b"]
bytes_union_1: Literal[b"a", b"b"]
bytes_union_2: Literal[b"d", b"c"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test union with literal type and non-literal type -- e.g. user-defined class that defines __add__ that accepts str. Also test with union of literal type and Any.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The narrowing is only applied if all Union items are themselves Literals. Thus the inference for these cases didn't change.

Test cases added in 36f5e2e

b: bytes

misc_union: Literal["a", 1]

reveal_type("a" + "b") # N: Revealed type is "builtins.str"
reveal_type(str_a + str_b) # N: Revealed type is "Literal['ab']"
reveal_type(str_a + "b") # N: Revealed type is "Literal['ab']"
reveal_type("a" + str_b) # N: Revealed type is "Literal['ab']"
reveal_type(str_union_1 + "b") # N: Revealed type is "Union[Literal['ab'], Literal['bb']]"
reveal_type(str_union_1 + str_b) # N: Revealed type is "Union[Literal['ab'], Literal['bb']]"
reveal_type("a" + str_union_1) # N: Revealed type is "Union[Literal['aa'], Literal['ab']]"
reveal_type(str_a + str_union_1) # N: Revealed type is "Union[Literal['aa'], Literal['ab']]"
reveal_type(str_union_1 + str_union_2) # N: Revealed type is "Union[Literal['ac'], Literal['ad'], Literal['bc'], Literal['bd']]"
reveal_type(str_a + s) # N: Revealed type is "builtins.str"
reveal_type(s + str_a) # N: Revealed type is "builtins.str"
reveal_type(str_union_1 + s) # N: Revealed type is "builtins.str"
reveal_type(s + str_union_1) # N: Revealed type is "builtins.str"
reveal_type(str_a + str_union_mixed_1) # N: Revealed type is "builtins.str"
reveal_type(str_union_mixed_1 + str_a) # N: Revealed type is "Union[builtins.str, Any]"
reveal_type(str_a + str_union_mixed_2) # N: Revealed type is "builtins.str"
reveal_type(str_union_mixed_2 + str_a) # N: Revealed type is "builtins.str"

reveal_type(1 + 2) # N: Revealed type is "builtins.int"
reveal_type(int_1 + int_2) # N: Revealed type is "Literal[3]"
reveal_type(int_1 + 1) # N: Revealed type is "Literal[2]"
reveal_type(1 + int_1) # N: Revealed type is "Literal[2]"
reveal_type(int_union_1 + 1) # N: Revealed type is "Union[Literal[2], Literal[3]]"
reveal_type(int_union_1 + int_1) # N: Revealed type is "Union[Literal[2], Literal[3]]"
reveal_type(1 + int_union_1) # N: Revealed type is "Union[Literal[2], Literal[3]]"
reveal_type(int_1 + int_union_1) # N: Revealed type is "Union[Literal[2], Literal[3]]"
reveal_type(int_union_1 + int_union_2) # N: Revealed type is "Union[Literal[4], Literal[5], Literal[6]]"
reveal_type(int_1 + i) # N: Revealed type is "builtins.int"
reveal_type(i + int_1) # N: Revealed type is "builtins.int"
reveal_type(int_union_1 + i) # N: Revealed type is "builtins.int"
reveal_type(i + int_union_1) # N: Revealed type is "builtins.int"

reveal_type(b"a" + b"b") # N: Revealed type is "builtins.bytes"
reveal_type(bytes_a + bytes_b) # N: Revealed type is "Literal[b'ab']"
reveal_type(bytes_a + b"b") # N: Revealed type is "Literal[b'ab']"
reveal_type(b"a" + bytes_b) # N: Revealed type is "Literal[b'ab']"
reveal_type(bytes_union_1 + b"b") # N: Revealed type is "Union[Literal[b'ab'], Literal[b'bb']]"
reveal_type(bytes_union_1 + bytes_b) # N: Revealed type is "Union[Literal[b'ab'], Literal[b'bb']]"
reveal_type(b"a" + bytes_union_1) # N: Revealed type is "Union[Literal[b'aa'], Literal[b'ab']]"
reveal_type(bytes_a + bytes_union_1) # N: Revealed type is "Union[Literal[b'aa'], Literal[b'ab']]"
reveal_type(bytes_union_1 + bytes_union_2) # N: Revealed type is "Union[Literal[b'ac'], Literal[b'ad'], Literal[b'bc'], Literal[b'bd']]"
reveal_type(bytes_a + b) # N: Revealed type is "builtins.bytes"
reveal_type(b + bytes_a) # N: Revealed type is "builtins.bytes"
reveal_type(bytes_union_1 + b) # N: Revealed type is "builtins.bytes"
reveal_type(b + bytes_union_1) # N: Revealed type is "builtins.bytes"

reveal_type(misc_union + "a") # N: Revealed type is "Union[builtins.str, builtins.int]" \
# E: Unsupported operand types for + ("Literal[1]" and "str") \
# N: Left operand is of type "Literal['a', 1]"
reveal_type("a" + misc_union) # E: Unsupported operand types for + ("str" and "Literal[1]") \
# N: Right operand is of type "Literal['a', 1]" \
# N: Revealed type is "builtins.str"
[builtins fixtures/primitives.pyi]

[case testLiteralAdditionInheritance]
class A:
a = ""

class B(A):
a = "a" + "b"

class C:
a = "a" + "b"

reveal_type(A.a) # N: Revealed type is "builtins.str"
reveal_type(B.a) # N: Revealed type is "builtins.str"
reveal_type(C.a) # N: Revealed type is "builtins.str"
[builtins fixtures/primitives.pyi]

[case testLiteralAdditionTypedDict]
from typing import TypedDict
from typing_extensions import Literal

class LookupDict(TypedDict):
top_var: str
bottom_var: str
var: str

def func(d: LookupDict, pos: Literal["top_", "bottom_", ""]) -> str:
return d[pos + "var"]

[builtins fixtures/dict.pyi]
[typing fixtures/typing-typeddict.pyi]

[case testLiteralAdditionGuardMaxValues]
from typing_extensions import Literal

HexDigit = Literal["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F"]

def foo(a: HexDigit, b: HexDigit, c: HexDigit) -> None:
reveal_type(a + b + c) # N: Revealed type is "builtins.str"
[builtins fixtures/primitives.pyi]
1 change: 1 addition & 0 deletions test-data/unit/fixtures/primitives.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class str(Sequence[str]):
def __getitem__(self, item: int) -> str: pass
def format(self, *args: object, **kwargs: object) -> str: pass
class bytes(Sequence[int]):
def __add__(self, x: bytes) -> bytes: pass
def __iter__(self) -> Iterator[int]: pass
def __contains__(self, other: object) -> bool: pass
def __getitem__(self, item: int) -> int: pass
Expand Down