diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index ec64669c1cd0..946b513ad12e 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -8,7 +8,7 @@ from collections import defaultdict from collections.abc import Iterable, Iterator, Sequence from contextlib import contextmanager -from typing import Callable, ClassVar, Final, Optional, cast, overload +from typing import Any, Callable, ClassVar, Final, Optional, cast, overload from typing_extensions import TypeAlias as _TypeAlias, assert_never import mypy.checker @@ -212,6 +212,10 @@ # see https://github.com/python/mypy/pull/5255#discussion_r196896335 for discussion. MAX_UNIONS: Final = 5 +# Use fallback type if literal addition of unions results in too many literal +# values. Explicitly set on the safe side to prevent accidental issues. +MAX_LITERAL_ADDITION_VALUES: Final = 15 + # Types considered safe for comparisons with --strict-equality due to known behaviour of __eq__. # NOTE: All these types are subtypes of AbstractSet. @@ -3487,12 +3491,13 @@ def visit_op_expr(self, e: OpExpr) -> Type: if isinstance(e.left, StrExpr): return self.strfrm_checker.check_str_interpolation(e.left, e.right) left_type = self.accept(e.left) - + right_type = self.accept(e.right) proper_left_type = get_proper_type(left_type) + proper_right_type = get_proper_type(right_type) + if isinstance(proper_left_type, TupleType) and e.op == "+": left_add_method = proper_left_type.partial_fallback.type.get("__add__") if left_add_method and left_add_method.fullname == "builtins.tuple.__add__": - proper_right_type = get_proper_type(self.accept(e.right)) if isinstance(proper_right_type, TupleType): right_radd_method = proper_right_type.partial_fallback.type.get("__radd__") if right_radd_method is None: @@ -3520,20 +3525,21 @@ def visit_op_expr(self, e: OpExpr) -> Type: items=proper_left_type.items + [UnpackType(mapped)] ) + if e.op == "+" and (result := self.literal_expression_addition(e, left_type, right_type)): + return result + use_reverse: 
def literal_value_from_expr(
    self, expr: Expression, typ: Type
) -> tuple[list[Any], str, bool] | None:
    """Extract the statically known value(s) of one operand of a literal addition.

    Args:
        expr: The operand expression.
        typ: The type already inferred for ``expr``.

    Returns:
        ``None`` if the operand has no usable literal value. Otherwise a tuple
        ``(values, fallback_fullname, from_literal_type)`` where ``values`` holds
        every possible value of the operand, ``fallback_fullname`` is one of
        ``builtins.str`` / ``builtins.int`` / ``builtins.bytes``, and
        ``from_literal_type`` is ``False`` when the value was read directly from a
        literal expression and ``True`` when it came from a ``LiteralType`` (or a
        union made up solely of ``LiteralType`` items).
    """
    if isinstance(expr, StrExpr):
        return [expr.value], "builtins.str", False
    if isinstance(expr, IntExpr):
        return [expr.value], "builtins.int", False
    if isinstance(expr, BytesExpr):
        return [expr.value], "builtins.bytes", False

    # Only literals backed by these builtins may be combined. PEP 586 also allows
    # enum members (and bools) as literals; adding the stored values of two enum
    # literals would fabricate a literal for a member that does not exist, so
    # anything with another fallback is rejected.
    addable_fallbacks = ("builtins.str", "builtins.int", "builtins.bytes")

    ptype = get_proper_type(typ)
    if isinstance(ptype, LiteralType):
        candidates: list[Type] = [ptype]
    elif isinstance(ptype, UnionType):
        candidates = list(ptype.items)
    else:
        return None

    fallback: str | None = None
    values: list[str | int] = []
    for item in candidates:
        pitem = get_proper_type(item)
        if not isinstance(pitem, LiteralType) or isinstance(pitem.value, (float, bool)):
            return None
        item_fallback = pitem.fallback.type.fullname
        if item_fallback not in addable_fallbacks:
            return None
        if fallback is None:
            fallback = item_fallback
        elif fallback != item_fallback:
            # Mixed unions such as Literal["a", 1] cannot be added as a unit.
            return None
        values.append(pitem.value)

    if fallback is None:
        # A union without any items carries no value; do not assert on it.
        return None
    return values, fallback, True

def literal_expression_addition(
    self, e: OpExpr, left_type: Type, right_type: Type
) -> Type | None:
    """Try to infer a precise literal type for ``left + right``.

    Args:
        e: The ``+`` operator expression being checked.
        left_type: The type already inferred for ``e.left``.
        right_type: The type already inferred for ``e.right``.

    Returns:
        A single ``LiteralType``, or a simplified union of them, when both operands
        have known literal values with the same fallback and at least one of them
        comes from a ``LiteralType``. ``None`` when the caller should fall back to
        regular operator checking, including when the number of distinct results
        exceeds ``MAX_LITERAL_ADDITION_VALUES``.
    """
    assert e.op == "+"
    lvalue = self.literal_value_from_expr(e.left, left_type)
    if lvalue is None:
        return None
    rvalue = self.literal_value_from_expr(e.right, right_type)
    if rvalue is None:
        return None

    left_values, fallback_name, left_from_literal_type = lvalue
    right_values, right_fallback_name, right_from_literal_type = rvalue
    if fallback_name != right_fallback_name:
        return None
    if not (left_from_literal_type or right_from_literal_type):
        # Two plain literal expressions (e.g. `"a" + "b"`) keep their regular type.
        return None

    # Collect distinct results and stop as soon as the cap is exceeded, instead of
    # materialising and sorting the full cartesian product of two large unions.
    sums: set[int | str] = set()
    for left, right in itertools.product(left_values, right_values):
        sums.add(left + right)
        if len(sums) > MAX_LITERAL_ADDITION_VALUES:
            return None

    # Sorted so that the resulting union has a deterministic item order.
    literals = [LiteralType(val, self.named_type(fallback_name)) for val in sorted(sums)]
    if len(literals) == 1:
        return literals[0]
    return make_simplified_union(literals)
diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index f36eff28f33f..dfeca8940a00 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -1407,22 +1407,23 @@ c: Literal[4] d: Literal['foo'] e: str -reveal_type(a + a) # N: Revealed type is "builtins.int" +reveal_type(a + a) # N: Revealed type is "Literal[6]" reveal_type(a + b) # N: Revealed type is "builtins.int" reveal_type(b + a) # N: Revealed type is "builtins.int" -reveal_type(a + 1) # N: Revealed type is "builtins.int" -reveal_type(1 + a) # N: Revealed type is "builtins.int" -reveal_type(a + c) # N: Revealed type is "builtins.int" -reveal_type(c + a) # N: Revealed type is "builtins.int" +reveal_type(a + 1) # N: Revealed type is "Literal[4]" +reveal_type(1 + a) # N: Revealed type is "Literal[4]" +reveal_type(a + c) # N: Revealed type is "Literal[7]" +reveal_type(c + a) # N: Revealed type is "Literal[7]" -reveal_type(d + d) # N: Revealed type is "builtins.str" +reveal_type(d + d) # N: Revealed type is "Literal['foofoo']" reveal_type(d + e) # N: Revealed type is "builtins.str" reveal_type(e + d) # N: Revealed type is "builtins.str" -reveal_type(d + 'foo') # N: Revealed type is "builtins.str" -reveal_type('foo' + d) # N: Revealed type is "builtins.str" +reveal_type(d + 'bar') # N: Revealed type is "Literal['foobar']" +reveal_type('bar' + d) # N: Revealed type is "Literal['barfoo']" reveal_type(a.__add__(b)) # N: Revealed type is "builtins.int" reveal_type(b.__add__(a)) # N: Revealed type is "builtins.int" +reveal_type(a.__add__(a)) # N: Revealed type is "builtins.int" a *= b # E: Incompatible types in assignment (expression has type "int", variable has type "Literal[3]") b *= a @@ -2976,3 +2977,124 @@ x: Type[Literal[1]] # E: Type[...] can't contain "Literal[...]" y: Type[Union[Literal[1], Literal[2]]] # E: Type[...] can't contain "Union[Literal[...], Literal[...]]" z: Type[Literal[1, 2]] # E: Type[...] 
can't contain "Union[Literal[...], Literal[...]]" [builtins fixtures/tuple.pyi] + +[case testLiteralAddition] +from typing import Any, Union +from typing_extensions import Literal + +class A: + def __add__(self, other: str) -> str: ... + def __radd__(self, other: str) -> str: ... + +str_a: Literal["a"] +str_b: Literal["b"] +str_union_1: Literal["a", "b"] +str_union_2: Literal["d", "c"] +str_union_mixed_1: Union[Literal["a"], Any] +str_union_mixed_2: Union[Literal["a"], A] +s: str +int_1: Literal[1] +int_2: Literal[2] +int_union_1: Literal[1, 2] +int_union_2: Literal[4, 3] +i: int +bytes_a: Literal[b"a"] +bytes_b: Literal[b"b"] +bytes_union_1: Literal[b"a", b"b"] +bytes_union_2: Literal[b"d", b"c"] +b: bytes + +misc_union: Literal["a", 1] + +reveal_type("a" + "b") # N: Revealed type is "builtins.str" +reveal_type(str_a + str_b) # N: Revealed type is "Literal['ab']" +reveal_type(str_a + "b") # N: Revealed type is "Literal['ab']" +reveal_type("a" + str_b) # N: Revealed type is "Literal['ab']" +reveal_type(str_union_1 + "b") # N: Revealed type is "Union[Literal['ab'], Literal['bb']]" +reveal_type(str_union_1 + str_b) # N: Revealed type is "Union[Literal['ab'], Literal['bb']]" +reveal_type("a" + str_union_1) # N: Revealed type is "Union[Literal['aa'], Literal['ab']]" +reveal_type(str_a + str_union_1) # N: Revealed type is "Union[Literal['aa'], Literal['ab']]" +reveal_type(str_union_1 + str_union_2) # N: Revealed type is "Union[Literal['ac'], Literal['ad'], Literal['bc'], Literal['bd']]" +reveal_type(str_a + s) # N: Revealed type is "builtins.str" +reveal_type(s + str_a) # N: Revealed type is "builtins.str" +reveal_type(str_union_1 + s) # N: Revealed type is "builtins.str" +reveal_type(s + str_union_1) # N: Revealed type is "builtins.str" +reveal_type(str_a + str_union_mixed_1) # N: Revealed type is "builtins.str" +reveal_type(str_union_mixed_1 + str_a) # N: Revealed type is "Union[builtins.str, Any]" +reveal_type(str_a + str_union_mixed_2) # N: Revealed type is 
"builtins.str" +reveal_type(str_union_mixed_2 + str_a) # N: Revealed type is "builtins.str" + +reveal_type(1 + 2) # N: Revealed type is "builtins.int" +reveal_type(int_1 + int_2) # N: Revealed type is "Literal[3]" +reveal_type(int_1 + 1) # N: Revealed type is "Literal[2]" +reveal_type(1 + int_1) # N: Revealed type is "Literal[2]" +reveal_type(int_union_1 + 1) # N: Revealed type is "Union[Literal[2], Literal[3]]" +reveal_type(int_union_1 + int_1) # N: Revealed type is "Union[Literal[2], Literal[3]]" +reveal_type(1 + int_union_1) # N: Revealed type is "Union[Literal[2], Literal[3]]" +reveal_type(int_1 + int_union_1) # N: Revealed type is "Union[Literal[2], Literal[3]]" +reveal_type(int_union_1 + int_union_2) # N: Revealed type is "Union[Literal[4], Literal[5], Literal[6]]" +reveal_type(int_1 + i) # N: Revealed type is "builtins.int" +reveal_type(i + int_1) # N: Revealed type is "builtins.int" +reveal_type(int_union_1 + i) # N: Revealed type is "builtins.int" +reveal_type(i + int_union_1) # N: Revealed type is "builtins.int" + +reveal_type(b"a" + b"b") # N: Revealed type is "builtins.bytes" +reveal_type(bytes_a + bytes_b) # N: Revealed type is "Literal[b'ab']" +reveal_type(bytes_a + b"b") # N: Revealed type is "Literal[b'ab']" +reveal_type(b"a" + bytes_b) # N: Revealed type is "Literal[b'ab']" +reveal_type(bytes_union_1 + b"b") # N: Revealed type is "Union[Literal[b'ab'], Literal[b'bb']]" +reveal_type(bytes_union_1 + bytes_b) # N: Revealed type is "Union[Literal[b'ab'], Literal[b'bb']]" +reveal_type(b"a" + bytes_union_1) # N: Revealed type is "Union[Literal[b'aa'], Literal[b'ab']]" +reveal_type(bytes_a + bytes_union_1) # N: Revealed type is "Union[Literal[b'aa'], Literal[b'ab']]" +reveal_type(bytes_union_1 + bytes_union_2) # N: Revealed type is "Union[Literal[b'ac'], Literal[b'ad'], Literal[b'bc'], Literal[b'bd']]" +reveal_type(bytes_a + b) # N: Revealed type is "builtins.bytes" +reveal_type(b + bytes_a) # N: Revealed type is "builtins.bytes" 
+reveal_type(bytes_union_1 + b) # N: Revealed type is "builtins.bytes" +reveal_type(b + bytes_union_1) # N: Revealed type is "builtins.bytes" + +reveal_type(misc_union + "a") # N: Revealed type is "Union[builtins.str, builtins.int]" \ + # E: Unsupported operand types for + ("Literal[1]" and "str") \ + # N: Left operand is of type "Literal['a', 1]" +reveal_type("a" + misc_union) # E: Unsupported operand types for + ("str" and "Literal[1]") \ + # N: Right operand is of type "Literal['a', 1]" \ + # N: Revealed type is "builtins.str" +[builtins fixtures/primitives.pyi] + +[case testLiteralAdditionInheritance] +class A: + a = "" + +class B(A): + a = "a" + "b" + +class C: + a = "a" + "b" + +reveal_type(A.a) # N: Revealed type is "builtins.str" +reveal_type(B.a) # N: Revealed type is "builtins.str" +reveal_type(C.a) # N: Revealed type is "builtins.str" +[builtins fixtures/primitives.pyi] + +[case testLiteralAdditionTypedDict] +from typing import TypedDict +from typing_extensions import Literal + +class LookupDict(TypedDict): + top_var: str + bottom_var: str + var: str + +def func(d: LookupDict, pos: Literal["top_", "bottom_", ""]) -> str: + return d[pos + "var"] + +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testLiteralAdditionGuardMaxValues] +from typing_extensions import Literal + +HexDigit = Literal["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F"] + +def foo(a: HexDigit, b: HexDigit, c: HexDigit) -> None: + reveal_type(a + b + c) # N: Revealed type is "builtins.str" +[builtins fixtures/primitives.pyi] diff --git a/test-data/unit/fixtures/primitives.pyi b/test-data/unit/fixtures/primitives.pyi index 2f8623c79b9f..de74e3b22473 100644 --- a/test-data/unit/fixtures/primitives.pyi +++ b/test-data/unit/fixtures/primitives.pyi @@ -34,6 +34,7 @@ class str(Sequence[str]): def __getitem__(self, item: int) -> str: pass def format(self, *args: object, **kwargs: object) -> str: pass class bytes(Sequence[int]): + def 
__add__(self, x: bytes) -> bytes: pass def __iter__(self) -> Iterator[int]: pass def __contains__(self, other: object) -> bool: pass def __getitem__(self, item: int) -> int: pass