Skip to content

Commit 3025b00

Browse files
improved TypedDict.get inference
1 parent ae1ba04 commit 3025b00

File tree

5 files changed

+206
-108
lines changed

5 files changed

+206
-108
lines changed

mypy/checkmember.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,57 +1409,65 @@ def analyze_typeddict_access(
14091409
str_type = mx.chk.named_type("builtins.str")
14101410
fn_type = mx.chk.named_type("builtins.function")
14111411
object_type = mx.chk.named_type("builtins.object")
1412-
t = TypeVarType(
1412+
1413+
# type variable for default value
1414+
tvar = TypeVarType(
14131415
"T",
14141416
"T",
14151417
id=TypeVarId(-1),
14161418
values=[],
14171419
upper_bound=object_type,
14181420
default=AnyType(TypeOfAny.from_omitted_generics),
14191421
)
1422+
# generate the overloads
14201423
overloads: list[CallableType] = []
1421-
for key, val in typ.items.items():
1424+
for key, value_type in typ.items.items():
1425+
key_type = LiteralType(key, fallback=str_type)
1426+
14221427
if key in typ.required_keys:
14231428
# If the key is required, we know it must be present in the TypedDict.
14241429
overload = CallableType(
1425-
arg_types=[LiteralType(key, fallback=str_type), object_type],
1430+
arg_types=[key_type, object_type],
14261431
arg_kinds=[ARG_POS, ARG_OPT],
14271432
arg_names=[None, None],
1428-
ret_type=val,
1433+
ret_type=value_type,
14291434
fallback=fn_type,
14301435
name=name,
14311436
)
14321437
overloads.append(overload)
14331438
else:
1434-
# The key is not required, so we add the overloads:
1435-
# def (Literal[Key]) -> Val | None
1436-
# def (Literal[Key], default: Val) -> Val
1437-
# def [T] (Literal[Key], default: T = ..., /) -> Val | T
1438-
# TODO: simplify the last two overloads to just one
1439+
# The key is not required, but if it is present, we know its type.
1440+
# def (K) -> V | None
14391441
overload = CallableType(
1440-
arg_types=[LiteralType(key, fallback=str_type)],
1442+
arg_types=[key_type],
14411443
arg_kinds=[ARG_POS],
14421444
arg_names=[None],
1443-
ret_type=UnionType.make_union([val, NoneType()]),
1445+
ret_type=UnionType.make_union([value_type, NoneType()]),
14441446
fallback=fn_type,
14451447
name=name,
14461448
)
14471449
overloads.append(overload)
1450+
1451+
# We add an extra overload for the case when the given default is a subtype of the value type.
1452+
# This makes sure that the return type is inferred as the value type instead of a union.
1453+
# def (K, V) -> V
14481454
overload = CallableType(
1449-
arg_types=[LiteralType(key, fallback=str_type), val],
1455+
arg_types=[key_type, value_type],
14501456
arg_kinds=[ARG_POS, ARG_POS],
14511457
arg_names=[None, None],
1452-
ret_type=val,
1458+
ret_type=value_type,
14531459
fallback=fn_type,
14541460
name=name,
14551461
)
14561462
overloads.append(overload)
1463+
1464+
# fallback: def [T](K, T) -> V | T
14571465
overload = CallableType(
1458-
variables=[t],
1459-
arg_types=[LiteralType(key, fallback=str_type), t],
1460-
arg_kinds=[ARG_POS, ARG_OPT],
1466+
variables=[tvar],
1467+
arg_types=[key_type, tvar],
1468+
arg_kinds=[ARG_POS, ARG_POS],
14611469
arg_names=[None, None],
1462-
ret_type=UnionType.make_union([val, t]),
1470+
ret_type=UnionType.make_union([value_type, tvar]),
14631471
fallback=fn_type,
14641472
name=name,
14651473
)

mypy/plugins/default.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@
7070
Instance,
7171
LiteralType,
7272
NoneType,
73-
ProperType,
7473
TupleType,
7574
Type,
7675
TypedDictType,
7776
TypeOfAny,
7877
TypeVarType,
78+
UninhabitedType,
7979
UnionType,
8080
get_proper_type,
8181
get_proper_types,
@@ -122,7 +122,7 @@ def get_method_signature_hook(
122122
self, fullname: str
123123
) -> Callable[[MethodSigContext], FunctionLike] | None:
124124
# NOTE: signatures for `__setitem__`, `__delitem__` and `get` are
125-
# checkmember.py/analyze_typeddict_access
125+
# defined in checkmember.py/analyze_typeddict_access
126126
if fullname in TD_SETDEFAULT_NAMES:
127127
return typed_dict_setdefault_signature_callback
128128
elif fullname in TD_POP_NAMES:
@@ -214,41 +214,51 @@ def get_class_decorator_hook_2(
214214

215215

216216
def typed_dict_get_callback(ctx: MethodContext) -> Type:
217-
"""Infer a precise return type for TypedDict.get with union of literal first argument."""
217+
"""Infer a precise return type for TypedDict.get with literal first argument."""
218218
if (
219219
isinstance(ctx.type, TypedDictType)
220220
and len(ctx.arg_types) >= 1
221221
and len(ctx.arg_types[0]) == 1
222-
and isinstance(get_proper_type(ctx.arg_types[0][0]), UnionType)
223222
):
224223
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
225224
if keys is None:
226225
return ctx.default_return_type
227226

228-
output_types: list[Type] = []
229-
default_arg: Type
230-
227+
default_type: Type
231228
if len(ctx.arg_types) <= 1 or not ctx.arg_types[1]:
232-
default_arg = NoneType()
229+
default_type = NoneType()
233230
elif len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
234-
default_arg = ctx.arg_types[1][0]
231+
default_type = ctx.arg_types[1][0]
235232
else:
236233
return ctx.default_return_type
237234

235+
output_types: list[Type] = []
238236
for key in keys:
239-
value: ProperType | None = get_proper_type(ctx.type.items.get(key))
240-
if value is None:
237+
value_type: Type | None = ctx.type.items.get(key)
238+
if value_type is None:
241239
return ctx.default_return_type
242240

243241
if key in ctx.type.required_keys:
244-
output_types.append(value)
242+
output_types.append(value_type)
245243
else:
246-
output_types.append(value)
247-
output_types.append(default_arg)
244+
# HACK to deal with get(key, {})
245+
proper_default = get_proper_type(default_type)
246+
if (
247+
isinstance(vt := get_proper_type(value_type), TypedDictType)
248+
and isinstance(proper_default, Instance)
249+
and proper_default.type.fullname == "builtins.dict"
250+
and len(proper_default.args) == 2
251+
and isinstance(get_proper_type(proper_default.args[0]), UninhabitedType)
252+
and isinstance(get_proper_type(proper_default.args[1]), UninhabitedType)
253+
):
254+
output_types.append(vt.copy_modified(required_keys=set()))
255+
else:
256+
output_types.append(value_type)
257+
output_types.append(default_type)
248258

249259
# for nicer reveal_type, put default at the end, if it is present
250-
if default_arg in output_types:
251-
output_types = [t for t in output_types if t != default_arg] + [default_arg]
260+
if default_type in output_types:
261+
output_types = [t for t in output_types if t != default_type] + [default_type]
252262
return make_simplified_union(output_types)
253263
return ctx.default_return_type
254264

test-data/unit/check-recursive-types.test

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,7 @@ class TD(TypedDict, total=False):
690690
y: TD
691691

692692
td: TD
693+
reveal_type(td.get("y")) # N: Revealed type is "Union[TypedDict('__main__.TD', {'x'?: builtins.int, 'y'?: ...}), None]"
693694
td["y"] = {"x": 0, "y": {}}
694695
td["y"] = {"x": 0, "y": {"x": 0, "y": 42}} # E: Incompatible types (expression has type "int", TypedDict item "y" has type "TD")
695696

0 commit comments

Comments
 (0)