Skip to content

Special case TypedDict.get #19639

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from mypy.meet import is_overlapping_types
from mypy.messages import MessageBuilder
from mypy.nodes import (
ARG_OPT,
ARG_POS,
ARG_STAR,
ARG_STAR2,
Expand Down Expand Up @@ -68,6 +69,7 @@
TypedDictType,
TypeOfAny,
TypeType,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
Expand Down Expand Up @@ -1400,6 +1402,79 @@ def analyze_typeddict_access(
fallback=mx.chk.named_type("builtins.function"),
name=name,
)
elif name == "get":
# synthesize TypedDict.get() overloads
str_type = mx.chk.named_type("builtins.str")
fn_type = mx.chk.named_type("builtins.function")
object_type = mx.chk.named_type("builtins.object")
t = TypeVarType(
"T",
"T",
id=TypeVarId(-1),
values=[],
upper_bound=object_type,
default=AnyType(TypeOfAny.from_omitted_generics),
)
overloads: list[CallableType] = []
for key, val in typ.items.items():
if key in typ.required_keys:
# If the key is required, we know it must be present in the TypedDict.
overload = CallableType(
arg_types=[LiteralType(key, fallback=str_type), object_type],
arg_kinds=[ARG_POS, ARG_OPT],
arg_names=[None, None],
ret_type=val,
fallback=fn_type,
name=name,
)
overloads.append(overload)
else:
# The key is not required, so we add the overloads:
# def (Literal[Key]) -> Val | None
# def (Literal[Key], default: Val) -> Val
# def [T] (Literal[Key], default: T = ..., /) -> Val | T
# TODO: simplify the last two overloads to just one
overload = CallableType(
arg_types=[LiteralType(key, fallback=str_type)],
arg_kinds=[ARG_POS],
arg_names=[None],
ret_type=UnionType.make_union([val, NoneType()]),
fallback=fn_type,
name=name,
)
overloads.append(overload)
overload = CallableType(
arg_types=[LiteralType(key, fallback=str_type), val],
arg_kinds=[ARG_POS, ARG_POS],
arg_names=[None, None],
ret_type=val,
fallback=fn_type,
name=name,
)
overloads.append(overload)
overload = CallableType(
variables=[t],
arg_types=[LiteralType(key, fallback=str_type), t],
arg_kinds=[ARG_POS, ARG_OPT],
arg_names=[None, None],
ret_type=UnionType.make_union([val, t]),
fallback=fn_type,
name=name,
)
overloads.append(overload)

# finally, add fallback overload when a key is used that is not in the TypedDict
# def (str, object=...) -> object
fallback_overload = CallableType(
arg_types=[str_type, object_type],
arg_kinds=[ARG_POS, ARG_OPT],
arg_names=[None, None],
ret_type=object_type,
fallback=fn_type,
name=name,
)
overloads.append(fallback_overload)
return Overloaded(overloads)
return _analyze_member_access(name, typ.fallback, mx, override_info)


Expand Down
90 changes: 26 additions & 64 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import mypy.errorcodes as codes
from mypy import message_registry
from mypy.nodes import DictExpr, IntExpr, StrExpr, UnaryExpr
from mypy.nodes import IntExpr, StrExpr, UnaryExpr
from mypy.plugin import (
AttributeContext,
ClassDefContext,
Expand Down Expand Up @@ -70,6 +70,7 @@
Instance,
LiteralType,
NoneType,
ProperType,
TupleType,
Type,
TypedDictType,
Expand Down Expand Up @@ -120,9 +121,9 @@ def get_function_signature_hook(
def get_method_signature_hook(
self, fullname: str
) -> Callable[[MethodSigContext], FunctionLike] | None:
if fullname == "typing.Mapping.get":
return typed_dict_get_signature_callback
elif fullname in TD_SETDEFAULT_NAMES:
# NOTE: signatures for `__setitem__`, `__delitem__` and `get` are
# checkmember.py/analyze_typeddict_access
if fullname in TD_SETDEFAULT_NAMES:
return typed_dict_setdefault_signature_callback
elif fullname in TD_POP_NAMES:
return typed_dict_pop_signature_callback
Expand Down Expand Up @@ -212,81 +213,42 @@ def get_class_decorator_hook_2(
return None


def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
"""Try to infer a better signature type for TypedDict.get.

This is used to get better type context for the second argument that
depends on a TypedDict value type.
"""
signature = ctx.default_signature
if (
isinstance(ctx.type, TypedDictType)
and len(ctx.args) == 2
and len(ctx.args[0]) == 1
and isinstance(ctx.args[0][0], StrExpr)
and len(signature.arg_types) == 2
and len(signature.variables) == 1
and len(ctx.args[1]) == 1
):
key = ctx.args[0][0].value
value_type = get_proper_type(ctx.type.items.get(key))
ret_type = signature.ret_type
if value_type:
default_arg = ctx.args[1][0]
if (
isinstance(value_type, TypedDictType)
and isinstance(default_arg, DictExpr)
and len(default_arg.items) == 0
):
# Caller has empty dict {} as default for typed dict.
value_type = value_type.copy_modified(required_keys=set())
# Tweak the signature to include the value type as context. It's
# only needed for type inference since there's a union with a type
# variable that accepts everything.
tv = signature.variables[0]
assert isinstance(tv, TypeVarType)
return signature.copy_modified(
arg_types=[signature.arg_types[0], make_simplified_union([value_type, tv])],
ret_type=ret_type,
)
return signature


def typed_dict_get_callback(ctx: MethodContext) -> Type:
"""Infer a precise return type for TypedDict.get with literal first argument."""
"""Infer a precise return type for TypedDict.get with union of literal first argument."""
if (
isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) >= 1
and len(ctx.arg_types[0]) == 1
and isinstance(get_proper_type(ctx.arg_types[0][0]), UnionType)
):
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
if keys is None:
return ctx.default_return_type

output_types: list[Type] = []
default_arg: Type

if len(ctx.arg_types) <= 1 or not ctx.arg_types[1]:
default_arg = NoneType()
elif len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
default_arg = ctx.arg_types[1][0]
else:
return ctx.default_return_type

for key in keys:
value_type = get_proper_type(ctx.type.items.get(key))
if value_type is None:
value: ProperType | None = get_proper_type(ctx.type.items.get(key))
if value is None:
return ctx.default_return_type

if len(ctx.arg_types) == 1:
output_types.append(value_type)
elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
default_arg = ctx.args[1][0]
if (
isinstance(default_arg, DictExpr)
and len(default_arg.items) == 0
and isinstance(value_type, TypedDictType)
):
# Special case '{}' as the default for a typed dict type.
output_types.append(value_type.copy_modified(required_keys=set()))
else:
output_types.append(value_type)
output_types.append(ctx.arg_types[1][0])

if len(ctx.arg_types) == 1:
output_types.append(NoneType())
if key in ctx.type.required_keys:
output_types.append(value)
else:
output_types.append(value)
output_types.append(default_arg)

# for nicer reveal_type, put default at the end, if it is present
if default_arg in output_types:
output_types = [t for t in output_types if t != default_arg] + [default_arg]
return make_simplified_union(output_types)
return ctx.default_return_type

Expand Down
19 changes: 11 additions & 8 deletions test-data/unit/check-literal.test
Original file line number Diff line number Diff line change
Expand Up @@ -1884,7 +1884,7 @@ reveal_type(d[a_key]) # N: Revealed type is "builtins.int"
reveal_type(d[b_key]) # N: Revealed type is "builtins.str"
d[c_key] # E: TypedDict "Outer" has no key "c"

reveal_type(d.get(a_key, u)) # N: Revealed type is "Union[builtins.int, __main__.Unrelated]"
reveal_type(d.get(a_key, u)) # N: Revealed type is "builtins.int"
reveal_type(d.get(b_key, u)) # N: Revealed type is "Union[builtins.str, __main__.Unrelated]"
reveal_type(d.get(c_key, u)) # N: Revealed type is "builtins.object"

Expand Down Expand Up @@ -1928,7 +1928,7 @@ u: Unrelated
reveal_type(a[int_key_good]) # N: Revealed type is "builtins.int"
reveal_type(b[int_key_good]) # N: Revealed type is "builtins.int"
reveal_type(c[str_key_good]) # N: Revealed type is "builtins.int"
reveal_type(c.get(str_key_good, u)) # N: Revealed type is "Union[builtins.int, __main__.Unrelated]"
reveal_type(c.get(str_key_good, u)) # N: Revealed type is "builtins.int"
reveal_type(c.get(str_key_bad, u)) # N: Revealed type is "builtins.object"

a[int_key_bad] # E: Tuple index out of range
Expand Down Expand Up @@ -1993,8 +1993,8 @@ optional_keys: Literal["d", "e"]
bad_keys: Literal["a", "bad"]

reveal_type(test[good_keys]) # N: Revealed type is "Union[__main__.A, __main__.B]"
reveal_type(test.get(good_keys)) # N: Revealed type is "Union[__main__.A, __main__.B, None]"
reveal_type(test.get(good_keys, 3)) # N: Revealed type is "Union[__main__.A, Literal[3]?, __main__.B]"
reveal_type(test.get(good_keys)) # N: Revealed type is "Union[__main__.A, __main__.B]"
reveal_type(test.get(good_keys, 3)) # N: Revealed type is "Union[__main__.A, __main__.B]"
reveal_type(test.pop(optional_keys)) # N: Revealed type is "Union[__main__.D, __main__.E]"
reveal_type(test.pop(optional_keys, 3)) # N: Revealed type is "Union[__main__.D, __main__.E, Literal[3]?]"
reveal_type(test.setdefault(good_keys, AAndB())) # N: Revealed type is "Union[__main__.A, __main__.B]"
Expand Down Expand Up @@ -2037,15 +2037,18 @@ class D2(TypedDict):
d: D

x: Union[D1, D2]
bad_keys: Literal['a', 'b', 'c', 'd']
good_keys: Literal['b', 'c']
mixed_keys: Literal['a', 'b', 'c', 'd']
bad_keys: Literal['e', 'f']

x[bad_keys] # E: TypedDict "D1" has no key "d" \
x[mixed_keys] # E: TypedDict "D1" has no key "d" \
# E: TypedDict "D2" has no key "a"

reveal_type(x[good_keys]) # N: Revealed type is "Union[__main__.B, __main__.C]"
reveal_type(x.get(good_keys)) # N: Revealed type is "Union[__main__.B, __main__.C, None]"
reveal_type(x.get(good_keys, 3)) # N: Revealed type is "Union[__main__.B, Literal[3]?, __main__.C]"
reveal_type(x.get(good_keys)) # N: Revealed type is "Union[__main__.B, __main__.C]"
reveal_type(x.get(good_keys, 3)) # N: Revealed type is "Union[__main__.B, __main__.C]"
reveal_type(x.get(mixed_keys)) # N: Revealed type is "builtins.object"
reveal_type(x.get(mixed_keys, 3)) # N: Revealed type is "builtins.object"
reveal_type(x.get(bad_keys)) # N: Revealed type is "builtins.object"
reveal_type(x.get(bad_keys, 3)) # N: Revealed type is "builtins.object"

Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-recursive-types.test
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ td: TD
td["y"] = {"x": 0, "y": {}}
td["y"] = {"x": 0, "y": {"x": 0, "y": 42}} # E: Incompatible types (expression has type "int", TypedDict item "y" has type "TD")

reveal_type(td.get("y")) # N: Revealed type is "Union[TypedDict('__main__.TD', {'x'?: builtins.int, 'y'?: TypedDict('__main__.TD', {'x'?: builtins.int, 'y'?: ...})}), None]"
reveal_type(td.get("y")) # N: Revealed type is "Union[TypedDict('__main__.TD', {'x'?: builtins.int, 'y'?: ...}), None]"
s: str = td.get("y") # E: Incompatible types in assignment (expression has type "Optional[TD]", variable has type "str")

td.update({"x": 0, "y": {"x": 1, "y": {}}})
Expand Down
Loading
Loading