Skip to content

Commit 76ea9a2

Browse files
committed
Add support for literal addition
1 parent b16c192 commit 76ea9a2

File tree

5 files changed

+150
-11
lines changed

5 files changed

+150
-11
lines changed

mypy/checkexpr.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3445,6 +3445,9 @@ def visit_op_expr(self, e: OpExpr) -> Type:
34453445
items=proper_left_type.items + [UnpackType(mapped)]
34463446
)
34473447

3448+
if e.op == "+" and (result := self.literal_expression_addition(e, left_type)):
3449+
return result
3450+
34483451
use_reverse: UseReverse = USE_REVERSE_DEFAULT
34493452
if e.op == "|":
34503453
if is_named_instance(proper_left_type, "builtins.dict"):
@@ -3505,6 +3508,57 @@ def visit_op_expr(self, e: OpExpr) -> Type:
35053508
else:
35063509
raise RuntimeError(f"Unknown operator {e.op}")
35073510

3511+
def literal_value_from_expr(
3512+
self, expr: Expression, typ: Type | None = None
3513+
) -> tuple[list[str | int], str] | None:
3514+
if isinstance(expr, StrExpr):
3515+
return [expr.value], "builtins.str"
3516+
if isinstance(expr, IntExpr):
3517+
return [expr.value], "builtins.int"
3518+
if isinstance(expr, BytesExpr):
3519+
return [expr.value], "builtins.bytes"
3520+
3521+
typ = typ or self.accept(expr)
3522+
ptype = get_proper_type(typ)
3523+
3524+
if isinstance(ptype, LiteralType) and not isinstance(ptype.value, (bool, float)):
3525+
return [ptype.value], ptype.fallback.type.fullname
3526+
3527+
if isinstance(ptype, UnionType):
3528+
fallback: str | None = None
3529+
values: list[str | int] = []
3530+
for item in ptype.items:
3531+
pitem = get_proper_type(item)
3532+
if not isinstance(pitem, LiteralType) or isinstance(pitem.value, (float, bool)):
3533+
break
3534+
if fallback is None:
3535+
fallback = pitem.fallback.type.fullname
3536+
if fallback != pitem.fallback.type.fullname:
3537+
break
3538+
values.append(pitem.value)
3539+
else:
3540+
assert fallback is not None
3541+
return values, fallback
3542+
return None
3543+
3544+
def literal_expression_addition(self, e: OpExpr, left_type: Type) -> Type | None:
3545+
"""Check if literal values can be combined with addition."""
3546+
assert e.op == "+"
3547+
if not (lvalue := self.literal_value_from_expr(e.left, left_type)):
3548+
return None
3549+
if not (rvalue := self.literal_value_from_expr(e.right)) or lvalue[1] != rvalue[1]:
3550+
return None
3551+
3552+
values: list[int | str] = sorted(
3553+
{
3554+
val[0] + val[1] # type: ignore[operator]
3555+
for val in itertools.product(lvalue[0], rvalue[0])
3556+
}
3557+
)
3558+
if len(values) == 1:
3559+
return LiteralType(values[0], self.named_type(lvalue[1]))
3560+
return UnionType([LiteralType(val, self.named_type(lvalue[1])) for val in values])
3561+
35083562
def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
35093563
"""Type check a comparison expression.
35103564

test-data/unit/check-literal.test

Lines changed: 92 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1416,19 +1416,19 @@ c: Literal[4]
14161416
d: Literal['foo']
14171417
e: str
14181418

1419-
reveal_type(a + a) # N: Revealed type is "builtins.int"
1419+
reveal_type(a + a) # N: Revealed type is "Literal[6]"
14201420
reveal_type(a + b) # N: Revealed type is "builtins.int"
14211421
reveal_type(b + a) # N: Revealed type is "builtins.int"
1422-
reveal_type(a + 1) # N: Revealed type is "builtins.int"
1423-
reveal_type(1 + a) # N: Revealed type is "builtins.int"
1424-
reveal_type(a + c) # N: Revealed type is "builtins.int"
1425-
reveal_type(c + a) # N: Revealed type is "builtins.int"
1422+
reveal_type(a + 1) # N: Revealed type is "Literal[4]"
1423+
reveal_type(1 + a) # N: Revealed type is "Literal[4]"
1424+
reveal_type(a + c) # N: Revealed type is "Literal[7]"
1425+
reveal_type(c + a) # N: Revealed type is "Literal[7]"
14261426

1427-
reveal_type(d + d) # N: Revealed type is "builtins.str"
1427+
reveal_type(d + d) # N: Revealed type is "Literal['foofoo']"
14281428
reveal_type(d + e) # N: Revealed type is "builtins.str"
14291429
reveal_type(e + d) # N: Revealed type is "builtins.str"
1430-
reveal_type(d + 'foo') # N: Revealed type is "builtins.str"
1431-
reveal_type('foo' + d) # N: Revealed type is "builtins.str"
1430+
reveal_type(d + 'bar') # N: Revealed type is "Literal['foobar']"
1431+
reveal_type('bar' + d) # N: Revealed type is "Literal['barfoo']"
14321432

14331433
reveal_type(a.__add__(b)) # N: Revealed type is "builtins.int"
14341434
reveal_type(b.__add__(a)) # N: Revealed type is "builtins.int"
@@ -2960,3 +2960,87 @@ class C(B[Literal["word"]]):
29602960
reveal_type(C().collection) # N: Revealed type is "builtins.list[Literal['word']]"
29612961
reveal_type(C().word) # N: Revealed type is "Literal['word']"
29622962
[builtins fixtures/tuple.pyi]
2963+
2964+
[case testLiteralAddition]
2965+
from typing import Union
2966+
from typing_extensions import Literal
2967+
2968+
str_a: Literal["a"]
2969+
str_b: Literal["b"]
2970+
str_union_1: Literal["a", "b"]
2971+
str_union_2: Literal["c", "d"]
2972+
s: str
2973+
int_1: Literal[1]
2974+
int_2: Literal[2]
2975+
int_union_1: Literal[1, 2]
2976+
int_union_2: Literal[3, 4]
2977+
i: int
2978+
bytes_a: Literal[b"a"]
2979+
bytes_b: Literal[b"b"]
2980+
bytes_union_1: Literal[b"a", b"b"]
2981+
bytes_union_2: Literal[b"c", b"d"]
2982+
b: bytes
2983+
2984+
misc_union: Literal["a", 1]
2985+
2986+
reveal_type(str_a + str_b) # N: Revealed type is "Literal['ab']"
2987+
reveal_type(str_a + "b") # N: Revealed type is "Literal['ab']"
2988+
reveal_type("a" + str_b) # N: Revealed type is "Literal['ab']"
2989+
reveal_type(str_union_1 + "b") # N: Revealed type is "Union[Literal['ab'], Literal['bb']]"
2990+
reveal_type(str_union_1 + str_b) # N: Revealed type is "Union[Literal['ab'], Literal['bb']]"
2991+
reveal_type("a" + str_union_1) # N: Revealed type is "Union[Literal['aa'], Literal['ab']]"
2992+
reveal_type(str_a + str_union_1) # N: Revealed type is "Union[Literal['aa'], Literal['ab']]"
2993+
reveal_type(str_union_1 + str_union_2) # N: Revealed type is "Union[Literal['ac'], Literal['ad'], Literal['bc'], Literal['bd']]"
2994+
reveal_type(str_a + s) # N: Revealed type is "builtins.str"
2995+
reveal_type(s + str_a) # N: Revealed type is "builtins.str"
2996+
reveal_type(str_union_1 + s) # N: Revealed type is "builtins.str"
2997+
reveal_type(s + str_union_1) # N: Revealed type is "builtins.str"
2998+
2999+
reveal_type(int_1 + int_2) # N: Revealed type is "Literal[3]"
3000+
reveal_type(int_1 + 1) # N: Revealed type is "Literal[2]"
3001+
reveal_type(1 + int_1) # N: Revealed type is "Literal[2]"
3002+
reveal_type(int_union_1 + 1) # N: Revealed type is "Union[Literal[2], Literal[3]]"
3003+
reveal_type(int_union_1 + int_1) # N: Revealed type is "Union[Literal[2], Literal[3]]"
3004+
reveal_type(1 + int_union_1) # N: Revealed type is "Union[Literal[2], Literal[3]]"
3005+
reveal_type(int_1 + int_union_1) # N: Revealed type is "Union[Literal[2], Literal[3]]"
3006+
reveal_type(int_union_1 + int_union_2) # N: Revealed type is "Union[Literal[4], Literal[5], Literal[6]]"
3007+
reveal_type(int_1 + i) # N: Revealed type is "builtins.int"
3008+
reveal_type(i + int_1) # N: Revealed type is "builtins.int"
3009+
reveal_type(int_union_1 + i) # N: Revealed type is "builtins.int"
3010+
reveal_type(i + int_union_1) # N: Revealed type is "builtins.int"
3011+
3012+
reveal_type(bytes_a + bytes_b) # N: Revealed type is "Literal[b'ab']"
3013+
reveal_type(bytes_a + b"b") # N: Revealed type is "Literal[b'ab']"
3014+
reveal_type(b"a" + bytes_b) # N: Revealed type is "Literal[b'ab']"
3015+
reveal_type(bytes_union_1 + b"b") # N: Revealed type is "Union[Literal[b'ab'], Literal[b'bb']]"
3016+
reveal_type(bytes_union_1 + bytes_b) # N: Revealed type is "Union[Literal[b'ab'], Literal[b'bb']]"
3017+
reveal_type(b"a" + bytes_union_1) # N: Revealed type is "Union[Literal[b'aa'], Literal[b'ab']]"
3018+
reveal_type(bytes_a + bytes_union_1) # N: Revealed type is "Union[Literal[b'aa'], Literal[b'ab']]"
3019+
reveal_type(bytes_union_1 + bytes_union_2) # N: Revealed type is "Union[Literal[b'ac'], Literal[b'ad'], Literal[b'bc'], Literal[b'bd']]"
3020+
reveal_type(bytes_a + b) # N: Revealed type is "builtins.bytes"
3021+
reveal_type(b + bytes_a) # N: Revealed type is "builtins.bytes"
3022+
reveal_type(bytes_union_1 + b) # N: Revealed type is "builtins.bytes"
3023+
reveal_type(b + bytes_union_1) # N: Revealed type is "builtins.bytes"
3024+
3025+
reveal_type(misc_union + "a") # N: Revealed type is "Union[builtins.str, builtins.int]" \
3026+
# E: Unsupported operand types for + ("Literal[1]" and "str") \
3027+
# N: Left operand is of type "Literal['a', 1]"
3028+
reveal_type("a" + misc_union) # E: Unsupported operand types for + ("str" and "Literal[1]") \
3029+
# N: Right operand is of type "Literal['a', 1]" \
3030+
# N: Revealed type is "builtins.str"
3031+
[builtins fixtures/primitives.pyi]
3032+
3033+
[case testLiteralAdditionTypedDict]
3034+
from typing import TypedDict
3035+
from typing_extensions import Literal
3036+
3037+
class LookupDict(TypedDict):
3038+
top_var: str
3039+
bottom_var: str
3040+
var: str
3041+
3042+
def func(d: LookupDict, pos: Literal["top_", "bottom_", ""]) -> str:
3043+
return d[pos + "var"]
3044+
3045+
[builtins fixtures/dict.pyi]
3046+
[typing fixtures/typing-typeddict.pyi]

test-data/unit/cmdline.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -912,8 +912,8 @@ test_between(1 + 1)
912912
tabs.py:2: error: Incompatible return value type (got "None", expected "str")
913913
return None
914914
^~~~
915-
tabs.py:4: error: Argument 1 to "test_between" has incompatible type "int";
916-
expected "str"
915+
tabs.py:4: error: Argument 1 to "test_between" has incompatible type
916+
"Literal[2]"; expected "str"
917917
test_between(1 + 1)
918918
^~~~~~~~~~~~
919919

test-data/unit/fixtures/primitives.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class str(Sequence[str]):
3232
def __getitem__(self, item: int) -> str: pass
3333
def format(self, *args: object, **kwargs: object) -> str: pass
3434
class bytes(Sequence[int]):
35+
def __add__(self, x: bytes) -> bytes: pass
3536
def __iter__(self) -> Iterator[int]: pass
3637
def __contains__(self, other: object) -> bool: pass
3738
def __getitem__(self, item: int) -> int: pass

test-data/unit/typexport-basic.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class str: pass
142142
class list: pass
143143
class dict: pass
144144
[out]
145-
OpExpr(3) : builtins.int
145+
OpExpr(3) : Literal[3]
146146
OpExpr(4) : builtins.float
147147
OpExpr(5) : builtins.float
148148
OpExpr(6) : builtins.float

0 commit comments

Comments
 (0)