Skip to content

Commit 27e50e2

Browse files
improved TypedDict.get inference
1 parent feeb3f0 commit 27e50e2

File tree

6 files changed

+315
-101
lines changed

6 files changed

+315
-101
lines changed

mypy/checkmember.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from mypy.meet import is_overlapping_types
1818
from mypy.messages import MessageBuilder
1919
from mypy.nodes import (
20+
ARG_OPT,
2021
ARG_POS,
2122
ARG_STAR,
2223
ARG_STAR2,
@@ -68,6 +69,7 @@
6869
TypedDictType,
6970
TypeOfAny,
7071
TypeType,
72+
TypeVarId,
7173
TypeVarLikeType,
7274
TypeVarTupleType,
7375
TypeVarType,
@@ -1402,6 +1404,87 @@ def analyze_typeddict_access(
14021404
fallback=mx.chk.named_type("builtins.function"),
14031405
name=name,
14041406
)
1407+
elif name == "get":
1408+
# synthesize TypedDict.get() overloads
1409+
str_type = mx.chk.named_type("builtins.str")
1410+
fn_type = mx.chk.named_type("builtins.function")
1411+
object_type = mx.chk.named_type("builtins.object")
1412+
1413+
# type variable for default value
1414+
tvar = TypeVarType(
1415+
"T",
1416+
"T",
1417+
id=TypeVarId(-1),
1418+
values=[],
1419+
upper_bound=object_type,
1420+
default=AnyType(TypeOfAny.from_omitted_generics),
1421+
)
1422+
# generate the overloads
1423+
overloads: list[CallableType] = []
1424+
for key, value_type in typ.items.items():
1425+
key_type = LiteralType(key, fallback=str_type)
1426+
1427+
if key in typ.required_keys:
1428+
# If the key is required, we know it must be present in the TypedDict.
1429+
overload = CallableType(
1430+
arg_types=[key_type, object_type],
1431+
arg_kinds=[ARG_POS, ARG_OPT],
1432+
arg_names=[None, None],
1433+
ret_type=value_type,
1434+
fallback=fn_type,
1435+
name=name,
1436+
)
1437+
overloads.append(overload)
1438+
else:
1439+
# The key is not required, but if it is present, we know its type.
1440+
# def (K) -> V | None
1441+
overload = CallableType(
1442+
arg_types=[key_type],
1443+
arg_kinds=[ARG_POS],
1444+
arg_names=[None],
1445+
ret_type=UnionType.make_union([value_type, NoneType()]),
1446+
fallback=fn_type,
1447+
name=name,
1448+
)
1449+
overloads.append(overload)
1450+
1451+
# We add an extra overload for the case when the given default is a subtype of the value type.
1452+
# This makes sure that the return type is inferred as the value type instead of a union.
1453+
# def (K, V) -> V
1454+
overload = CallableType(
1455+
arg_types=[key_type, value_type],
1456+
arg_kinds=[ARG_POS, ARG_POS],
1457+
arg_names=[None, None],
1458+
ret_type=value_type,
1459+
fallback=fn_type,
1460+
name=name,
1461+
)
1462+
overloads.append(overload)
1463+
1464+
# fallback: def [T](K, T) -> V | T
1465+
overload = CallableType(
1466+
variables=[tvar],
1467+
arg_types=[key_type, tvar],
1468+
arg_kinds=[ARG_POS, ARG_POS],
1469+
arg_names=[None, None],
1470+
ret_type=UnionType.make_union([value_type, tvar]),
1471+
fallback=fn_type,
1472+
name=name,
1473+
)
1474+
overloads.append(overload)
1475+
1476+
# finally, add fallback overload when a key is used that is not in the TypedDict
1477+
# def (str, object=...) -> object
1478+
fallback_overload = CallableType(
1479+
arg_types=[str_type, object_type],
1480+
arg_kinds=[ARG_POS, ARG_OPT],
1481+
arg_names=[None, None],
1482+
ret_type=object_type,
1483+
fallback=fn_type,
1484+
name=name,
1485+
)
1486+
overloads.append(fallback_overload)
1487+
return Overloaded(overloads)
14051488
return _analyze_member_access(name, typ.fallback, mx, override_info)
14061489

14071490

mypy/plugins/default.py

Lines changed: 29 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import mypy.errorcodes as codes
77
from mypy import message_registry
8-
from mypy.nodes import DictExpr, IntExpr, StrExpr, UnaryExpr
8+
from mypy.nodes import IntExpr, StrExpr, UnaryExpr
99
from mypy.plugin import (
1010
AttributeContext,
1111
ClassDefContext,
@@ -75,6 +75,7 @@
7575
TypedDictType,
7676
TypeOfAny,
7777
TypeVarType,
78+
UninhabitedType,
7879
UnionType,
7980
get_proper_type,
8081
get_proper_types,
@@ -120,9 +121,9 @@ def get_function_signature_hook(
120121
def get_method_signature_hook(
121122
self, fullname: str
122123
) -> Callable[[MethodSigContext], FunctionLike] | None:
123-
if fullname == "typing.Mapping.get":
124-
return typed_dict_get_signature_callback
125-
elif fullname in TD_SETDEFAULT_NAMES:
124+
# NOTE: signatures for `__setitem__`, `__delitem__` and `get` are
125+
# defined in checkmember.py/analyze_typeddict_access
126+
if fullname in TD_SETDEFAULT_NAMES:
126127
return typed_dict_setdefault_signature_callback
127128
elif fullname in TD_POP_NAMES:
128129
return typed_dict_pop_signature_callback
@@ -212,46 +213,6 @@ def get_class_decorator_hook_2(
212213
return None
213214

214215

215-
def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
216-
"""Try to infer a better signature type for TypedDict.get.
217-
218-
This is used to get better type context for the second argument that
219-
depends on a TypedDict value type.
220-
"""
221-
signature = ctx.default_signature
222-
if (
223-
isinstance(ctx.type, TypedDictType)
224-
and len(ctx.args) == 2
225-
and len(ctx.args[0]) == 1
226-
and isinstance(ctx.args[0][0], StrExpr)
227-
and len(signature.arg_types) == 2
228-
and len(signature.variables) == 1
229-
and len(ctx.args[1]) == 1
230-
):
231-
key = ctx.args[0][0].value
232-
value_type = get_proper_type(ctx.type.items.get(key))
233-
ret_type = signature.ret_type
234-
if value_type:
235-
default_arg = ctx.args[1][0]
236-
if (
237-
isinstance(value_type, TypedDictType)
238-
and isinstance(default_arg, DictExpr)
239-
and len(default_arg.items) == 0
240-
):
241-
# Caller has empty dict {} as default for typed dict.
242-
value_type = value_type.copy_modified(required_keys=set())
243-
# Tweak the signature to include the value type as context. It's
244-
# only needed for type inference since there's a union with a type
245-
# variable that accepts everything.
246-
tv = signature.variables[0]
247-
assert isinstance(tv, TypeVarType)
248-
return signature.copy_modified(
249-
arg_types=[signature.arg_types[0], make_simplified_union([value_type, tv])],
250-
ret_type=ret_type,
251-
)
252-
return signature
253-
254-
255216
def typed_dict_get_callback(ctx: MethodContext) -> Type:
256217
"""Infer a precise return type for TypedDict.get with literal first argument."""
257218
if (
@@ -263,30 +224,41 @@ def typed_dict_get_callback(ctx: MethodContext) -> Type:
263224
if keys is None:
264225
return ctx.default_return_type
265226

227+
default_type: Type
228+
if len(ctx.arg_types) <= 1 or not ctx.arg_types[1]:
229+
default_type = NoneType()
230+
elif len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
231+
default_type = ctx.arg_types[1][0]
232+
else:
233+
return ctx.default_return_type
234+
266235
output_types: list[Type] = []
267236
for key in keys:
268-
value_type = get_proper_type(ctx.type.items.get(key))
237+
value_type: Type | None = ctx.type.items.get(key)
269238
if value_type is None:
270239
return ctx.default_return_type
271240

272-
if len(ctx.arg_types) == 1:
241+
if key in ctx.type.required_keys:
273242
output_types.append(value_type)
274-
elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
275-
default_arg = ctx.args[1][0]
243+
else:
244+
# HACK to deal with get(key, {})
245+
proper_default = get_proper_type(default_type)
276246
if (
277-
isinstance(default_arg, DictExpr)
278-
and len(default_arg.items) == 0
279-
and isinstance(value_type, TypedDictType)
247+
isinstance(vt := get_proper_type(value_type), TypedDictType)
248+
and isinstance(proper_default, Instance)
249+
and proper_default.type.fullname == "builtins.dict"
250+
and len(proper_default.args) == 2
251+
and isinstance(get_proper_type(proper_default.args[0]), UninhabitedType)
252+
and isinstance(get_proper_type(proper_default.args[1]), UninhabitedType)
280253
):
281-
# Special case '{}' as the default for a typed dict type.
282-
output_types.append(value_type.copy_modified(required_keys=set()))
254+
output_types.append(vt.copy_modified(required_keys=set()))
283255
else:
284256
output_types.append(value_type)
285-
output_types.append(ctx.arg_types[1][0])
286-
287-
if len(ctx.arg_types) == 1:
288-
output_types.append(NoneType())
257+
output_types.append(default_type)
289258

259+
# for nicer reveal_type, put default at the end, if it is present
260+
if default_type in output_types:
261+
output_types = [t for t in output_types if t != default_type] + [default_type]
290262
return make_simplified_union(output_types)
291263
return ctx.default_return_type
292264

test-data/unit/check-literal.test

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1884,7 +1884,7 @@ reveal_type(d[a_key]) # N: Revealed type is "builtins.int"
18841884
reveal_type(d[b_key]) # N: Revealed type is "builtins.str"
18851885
d[c_key] # E: TypedDict "Outer" has no key "c"
18861886

1887-
reveal_type(d.get(a_key, u)) # N: Revealed type is "Union[builtins.int, __main__.Unrelated]"
1887+
reveal_type(d.get(a_key, u)) # N: Revealed type is "builtins.int"
18881888
reveal_type(d.get(b_key, u)) # N: Revealed type is "Union[builtins.str, __main__.Unrelated]"
18891889
reveal_type(d.get(c_key, u)) # N: Revealed type is "builtins.object"
18901890

@@ -1928,7 +1928,7 @@ u: Unrelated
19281928
reveal_type(a[int_key_good]) # N: Revealed type is "builtins.int"
19291929
reveal_type(b[int_key_good]) # N: Revealed type is "builtins.int"
19301930
reveal_type(c[str_key_good]) # N: Revealed type is "builtins.int"
1931-
reveal_type(c.get(str_key_good, u)) # N: Revealed type is "Union[builtins.int, __main__.Unrelated]"
1931+
reveal_type(c.get(str_key_good, u)) # N: Revealed type is "builtins.int"
19321932
reveal_type(c.get(str_key_bad, u)) # N: Revealed type is "builtins.object"
19331933

19341934
a[int_key_bad] # E: Tuple index out of range
@@ -1993,8 +1993,8 @@ optional_keys: Literal["d", "e"]
19931993
bad_keys: Literal["a", "bad"]
19941994

19951995
reveal_type(test[good_keys]) # N: Revealed type is "Union[__main__.A, __main__.B]"
1996-
reveal_type(test.get(good_keys)) # N: Revealed type is "Union[__main__.A, __main__.B, None]"
1997-
reveal_type(test.get(good_keys, 3)) # N: Revealed type is "Union[__main__.A, Literal[3]?, __main__.B]"
1996+
reveal_type(test.get(good_keys)) # N: Revealed type is "Union[__main__.A, __main__.B]"
1997+
reveal_type(test.get(good_keys, 3)) # N: Revealed type is "Union[__main__.A, __main__.B]"
19981998
reveal_type(test.pop(optional_keys)) # N: Revealed type is "Union[__main__.D, __main__.E]"
19991999
reveal_type(test.pop(optional_keys, 3)) # N: Revealed type is "Union[__main__.D, __main__.E, Literal[3]?]"
20002000
reveal_type(test.setdefault(good_keys, AAndB())) # N: Revealed type is "Union[__main__.A, __main__.B]"
@@ -2037,15 +2037,18 @@ class D2(TypedDict):
20372037
d: D
20382038

20392039
x: Union[D1, D2]
2040-
bad_keys: Literal['a', 'b', 'c', 'd']
20412040
good_keys: Literal['b', 'c']
2041+
mixed_keys: Literal['a', 'b', 'c', 'd']
2042+
bad_keys: Literal['e', 'f']
20422043

2043-
x[bad_keys] # E: TypedDict "D1" has no key "d" \
2044+
x[mixed_keys] # E: TypedDict "D1" has no key "d" \
20442045
# E: TypedDict "D2" has no key "a"
20452046

20462047
reveal_type(x[good_keys]) # N: Revealed type is "Union[__main__.B, __main__.C]"
2047-
reveal_type(x.get(good_keys)) # N: Revealed type is "Union[__main__.B, __main__.C, None]"
2048-
reveal_type(x.get(good_keys, 3)) # N: Revealed type is "Union[__main__.B, Literal[3]?, __main__.C]"
2048+
reveal_type(x.get(good_keys)) # N: Revealed type is "Union[__main__.B, __main__.C]"
2049+
reveal_type(x.get(good_keys, 3)) # N: Revealed type is "Union[__main__.B, __main__.C]"
2050+
reveal_type(x.get(mixed_keys)) # N: Revealed type is "builtins.object"
2051+
reveal_type(x.get(mixed_keys, 3)) # N: Revealed type is "builtins.object"
20492052
reveal_type(x.get(bad_keys)) # N: Revealed type is "builtins.object"
20502053
reveal_type(x.get(bad_keys, 3)) # N: Revealed type is "builtins.object"
20512054

test-data/unit/check-recursive-types.test

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,10 +690,11 @@ class TD(TypedDict, total=False):
690690
y: TD
691691

692692
td: TD
693+
reveal_type(td.get("y")) # N: Revealed type is "Union[TypedDict('__main__.TD', {'x'?: builtins.int, 'y'?: ...}), None]"
693694
td["y"] = {"x": 0, "y": {}}
694695
td["y"] = {"x": 0, "y": {"x": 0, "y": 42}} # E: Incompatible types (expression has type "int", TypedDict item "y" has type "TD")
695696

696-
reveal_type(td.get("y")) # N: Revealed type is "Union[TypedDict('__main__.TD', {'x'?: builtins.int, 'y'?: TypedDict('__main__.TD', {'x'?: builtins.int, 'y'?: ...})}), None]"
697+
reveal_type(td.get("y")) # N: Revealed type is "Union[TypedDict('__main__.TD', {'x'?: builtins.int, 'y'?: ...}), None]"
697698
s: str = td.get("y") # E: Incompatible types in assignment (expression has type "Optional[TD]", variable has type "str")
698699

699700
td.update({"x": 0, "y": {"x": 1, "y": {}}})

0 commit comments

Comments
 (0)