Skip to content

Commit 1d95c83

Browse files
special case TypedDict.get
1 parent a07abb6 commit 1d95c83

File tree

3 files changed

+84
-11
lines changed

3 files changed

+84
-11
lines changed

mypy/checkmember.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
TypedDictType,
6969
TypeOfAny,
7070
TypeType,
71+
TypeVarId,
7172
TypeVarLikeType,
7273
TypeVarTupleType,
7374
TypeVarType,
@@ -1400,6 +1401,67 @@ def analyze_typeddict_access(
14001401
fallback=mx.chk.named_type("builtins.function"),
14011402
name=name,
14021403
)
1404+
elif name == "get":
1405+
# synthesize TypedDict.get() overloads
1406+
t = TypeVarType(
1407+
"T",
1408+
"T",
1409+
id=TypeVarId(-1),
1410+
values=[],
1411+
upper_bound=mx.chk.named_type("builtins.object"),
1412+
default=AnyType(TypeOfAny.from_omitted_generics),
1413+
)
1414+
str_type = mx.chk.named_type("builtins.str")
1415+
fn_type = mx.chk.named_type("builtins.function")
1416+
object_type = mx.chk.named_type("builtins.object")
1417+
1418+
overloads: list[CallableType] = []
1419+
# add two overloads per TypedDictType spec
1420+
for key, val in typ.items.items():
1421+
# first overload: def(Literal[key]) -> val
1422+
no_default = CallableType(
1423+
arg_types=[LiteralType(key, fallback=str_type)],
1424+
arg_kinds=[ARG_POS],
1425+
arg_names=[None],
1426+
ret_type=val,
1427+
fallback=fn_type,
1428+
name=name,
1429+
)
1430+
# second Overload: def [T] (Literal[key], default: T | Val, /) -> T | Val
1431+
with_default = CallableType(
1432+
variables=[t],
1433+
arg_types=[LiteralType(key, fallback=str_type), UnionType.make_union([val, t])],
1434+
arg_kinds=[ARG_POS, ARG_POS],
1435+
arg_names=[None, None],
1436+
ret_type=UnionType.make_union([val, t]),
1437+
fallback=fn_type,
1438+
name=name,
1439+
)
1440+
overloads.append(no_default)
1441+
overloads.append(with_default)
1442+
1443+
# finally, add fallback overloads when a key is used that is not in the TypedDict
1444+
# def (str) -> object
1445+
fallback_no_default = CallableType(
1446+
arg_types=[str_type],
1447+
arg_kinds=[ARG_POS],
1448+
arg_names=[None],
1449+
ret_type=object_type,
1450+
fallback=fn_type,
1451+
name=name,
1452+
)
1453+
# def (str, object) -> object
1454+
fallback_with_default = CallableType(
1455+
arg_types=[str_type, object_type],
1456+
arg_kinds=[ARG_POS, ARG_POS],
1457+
arg_names=[None, None],
1458+
ret_type=object_type,
1459+
fallback=fn_type,
1460+
name=name,
1461+
)
1462+
overloads.append(fallback_no_default)
1463+
overloads.append(fallback_with_default)
1464+
return Overloaded(overloads)
14031465
return _analyze_member_access(name, typ.fallback, mx, override_info)
14041466

14051467

test-data/unit/check-typeddict.test

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,7 +1016,7 @@ class A: pass
10161016
D = TypedDict('D', {'x': List[int], 'y': int})
10171017
d: D
10181018
reveal_type(d.get('x', [])) # N: Revealed type is "builtins.list[builtins.int]"
1019-
d.get('x', ['x']) # E: List item 0 has incompatible type "str"; expected "int"
1019+
reveal_type(d.get('x', ['x'])) # N: Revealed type is "Union[builtins.list[builtins.int], builtins.list[builtins.str]]"
10201020
a = ['']
10211021
reveal_type(d.get('x', a)) # N: Revealed type is "Union[builtins.list[builtins.int], builtins.list[builtins.str]]"
10221022
[builtins fixtures/dict.pyi]
@@ -1026,14 +1026,22 @@ reveal_type(d.get('x', a)) # N: Revealed type is "Union[builtins.list[builtins.i
10261026
from typing import TypedDict
10271027
D = TypedDict('D', {'x': int, 'y': str})
10281028
d: D
1029-
d.get() # E: All overload variants of "get" of "Mapping" require at least one argument \
1029+
d.get() # E: All overload variants of "get" require at least one argument \
10301030
# N: Possible overload variants: \
1031-
# N: def get(self, k: str) -> object \
1032-
# N: def [V] get(self, k: str, default: object) -> object
1033-
d.get('x', 1, 2) # E: No overload variant of "get" of "Mapping" matches argument types "str", "int", "int" \
1031+
# N: def get(Literal['x'], /) -> int \
1032+
# N: def [T] get(Literal['x'], Union[int, T], /) -> Union[int, T] \
1033+
# N: def get(Literal['y'], /) -> str \
1034+
# N: def [T] get(Literal['y'], Union[str, T], /) -> Union[str, T] \
1035+
# N: def get(str, /) -> object \
1036+
# N: def get(str, object, /) -> object
1037+
d.get('x', 1, 2) # E: No overload variant of "get" matches argument types "str", "int", "int" \
10341038
# N: Possible overload variants: \
1035-
# N: def get(self, k: str) -> object \
1036-
# N: def [V] get(self, k: str, default: Union[int, V]) -> object
1039+
# N: def get(Literal['x'], /) -> int \
1040+
# N: def [T] get(Literal['x'], Union[int, T], /) -> Union[int, T] \
1041+
# N: def get(Literal['y'], /) -> str \
1042+
# N: def [T] get(Literal['y'], Union[int, T], /) -> Union[str, T] \
1043+
# N: def get(str, /) -> object \
1044+
# N: def get(str, object, /) -> object
10371045
x = d.get('z')
10381046
reveal_type(x) # N: Revealed type is "builtins.object"
10391047
s = ''

test-data/unit/pythoneval.test

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,11 +1046,14 @@ reveal_type(d.get(s))
10461046
_testTypedDictGet.py:6: note: Revealed type is "Union[builtins.int, None]"
10471047
_testTypedDictGet.py:7: note: Revealed type is "Union[builtins.str, None]"
10481048
_testTypedDictGet.py:8: note: Revealed type is "builtins.object"
1049-
_testTypedDictGet.py:9: error: All overload variants of "get" of "Mapping" require at least one argument
1049+
_testTypedDictGet.py:9: error: All overload variants of "get" require at least one argument
10501050
_testTypedDictGet.py:9: note: Possible overload variants:
1051-
_testTypedDictGet.py:9: note: def get(self, str, /) -> object
1052-
_testTypedDictGet.py:9: note: def get(self, str, /, default: object) -> object
1053-
_testTypedDictGet.py:9: note: def [_T] get(self, str, /, default: _T) -> object
1051+
_testTypedDictGet.py:9: note: def get(Literal['x'], /) -> int
1052+
_testTypedDictGet.py:9: note: def [T] get(Literal['x'], Union[int, T], /) -> Union[int, T]
1053+
_testTypedDictGet.py:9: note: def get(Literal['y'], /) -> str
1054+
_testTypedDictGet.py:9: note: def [T] get(Literal['y'], Union[str, T], /) -> Union[str, T]
1055+
_testTypedDictGet.py:9: note: def get(str, /) -> object
1056+
_testTypedDictGet.py:9: note: def get(str, object, /) -> object
10541057
_testTypedDictGet.py:11: note: Revealed type is "builtins.object"
10551058

10561059
[case testTypedDictMappingMethods]

0 commit comments

Comments
 (0)