Skip to content

Commit fb01cc4

Browse files
updated plugin implementation
1 parent 4f1f7a9 commit fb01cc4

File tree

6 files changed

+220
-177
lines changed

6 files changed

+220
-177
lines changed

mypy/checkmember.py

Lines changed: 56 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from mypy.meet import is_overlapping_types
1818
from mypy.messages import MessageBuilder
1919
from mypy.nodes import (
20+
ARG_OPT,
2021
ARG_POS,
2122
ARG_STAR,
2223
ARG_STAR2,
@@ -1403,64 +1404,76 @@ def analyze_typeddict_access(
14031404
)
14041405
elif name == "get":
14051406
# synthesize TypedDict.get() overloads
1407+
str_type = mx.chk.named_type("builtins.str")
1408+
fn_type = mx.chk.named_type("builtins.function")
1409+
object_type = mx.chk.named_type("builtins.object")
14061410
t = TypeVarType(
14071411
"T",
14081412
"T",
14091413
id=TypeVarId(-1),
14101414
values=[],
1411-
upper_bound=mx.chk.named_type("builtins.object"),
1412-
default=AnyType(TypeOfAny.from_omitted_generics),
1415+
upper_bound=object_type,
1416+
default=UninhabitedType(),
14131417
)
1414-
str_type = mx.chk.named_type("builtins.str")
1415-
fn_type = mx.chk.named_type("builtins.function")
1416-
object_type = mx.chk.named_type("builtins.object")
1417-
14181418
overloads: list[CallableType] = []
1419-
# add two overloads per TypedDictType spec
14201419
for key, val in typ.items.items():
1421-
# first overload: def(Literal[key]) -> val
1422-
no_default = CallableType(
1423-
arg_types=[LiteralType(key, fallback=str_type)],
1424-
arg_kinds=[ARG_POS],
1425-
arg_names=[None],
1426-
ret_type=val,
1427-
fallback=fn_type,
1428-
name=name,
1429-
)
1430-
# second Overload: def [T] (Literal[key], default: T | Val, /) -> T | Val
1431-
with_default = CallableType(
1432-
variables=[t],
1433-
arg_types=[LiteralType(key, fallback=str_type), UnionType.make_union([val, t])],
1434-
arg_kinds=[ARG_POS, ARG_POS],
1435-
arg_names=[None, None],
1436-
ret_type=UnionType.make_union([val, t]),
1437-
fallback=fn_type,
1438-
name=name,
1439-
)
1440-
overloads.append(no_default)
1441-
overloads.append(with_default)
1420+
if key in typ.required_keys:
1421+
# If the key is required, we know it must be present in the TypedDict.
1422+
overload = CallableType(
1423+
arg_types=[LiteralType(key, fallback=str_type), object_type],
1424+
arg_kinds=[ARG_POS, ARG_OPT],
1425+
arg_names=[None, None],
1426+
ret_type=val,
1427+
fallback=fn_type,
1428+
name=name,
1429+
)
1430+
overloads.append(overload)
1431+
else:
1432+
# The key is not required, so we add the overloads:
1433+
# def (Literal[Key]) -> Val | None
1434+
# def (Literal[Key], default: Val) -> Val
1435+
# def [T] (Literal[Key], default: T = ..., /) -> Val | T
1436+
# TODO: simplify the last two overloads to just one
1437+
overload = CallableType(
1438+
arg_types=[LiteralType(key, fallback=str_type)],
1439+
arg_kinds=[ARG_POS],
1440+
arg_names=[None],
1441+
ret_type=UnionType.make_union([val, NoneType()]),
1442+
fallback=fn_type,
1443+
name=name,
1444+
)
1445+
overloads.append(overload)
1446+
overload = CallableType(
1447+
arg_types=[LiteralType(key, fallback=str_type), val],
1448+
arg_kinds=[ARG_POS, ARG_POS],
1449+
arg_names=[None, None],
1450+
ret_type=val,
1451+
fallback=fn_type,
1452+
name=name,
1453+
)
1454+
overloads.append(overload)
1455+
overload = CallableType(
1456+
variables=[t],
1457+
arg_types=[LiteralType(key, fallback=str_type), t],
1458+
arg_kinds=[ARG_POS, ARG_OPT],
1459+
arg_names=[None, None],
1460+
ret_type=UnionType.make_union([val, t]),
1461+
fallback=fn_type,
1462+
name=name,
1463+
)
1464+
overloads.append(overload)
14421465

1443-
# finally, add fallback overloads when a key is used that is not in the TypedDict
1444-
# def (str) -> object
1445-
fallback_no_default = CallableType(
1446-
arg_types=[str_type],
1447-
arg_kinds=[ARG_POS],
1448-
arg_names=[None],
1449-
ret_type=object_type,
1450-
fallback=fn_type,
1451-
name=name,
1452-
)
1453-
# def (str, object) -> object
1454-
fallback_with_default = CallableType(
1466+
# finally, add fallback overload when a key is used that is not in the TypedDict
1467+
# def (str, object=...) -> object
1468+
fallback_overload = CallableType(
14551469
arg_types=[str_type, object_type],
1456-
arg_kinds=[ARG_POS, ARG_POS],
1470+
arg_kinds=[ARG_POS, ARG_OPT],
14571471
arg_names=[None, None],
14581472
ret_type=object_type,
14591473
fallback=fn_type,
14601474
name=name,
14611475
)
1462-
overloads.append(fallback_no_default)
1463-
overloads.append(fallback_with_default)
1476+
overloads.append(fallback_overload)
14641477
return Overloaded(overloads)
14651478
return _analyze_member_access(name, typ.fallback, mx, override_info)
14661479

mypy/plugins/default.py

Lines changed: 26 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import mypy.errorcodes as codes
77
from mypy import message_registry
8-
from mypy.nodes import DictExpr, IntExpr, StrExpr, UnaryExpr
8+
from mypy.nodes import IntExpr, StrExpr, UnaryExpr
99
from mypy.plugin import (
1010
AttributeContext,
1111
ClassDefContext,
@@ -70,6 +70,7 @@
7070
Instance,
7171
LiteralType,
7272
NoneType,
73+
ProperType,
7374
TupleType,
7475
Type,
7576
TypedDictType,
@@ -120,9 +121,9 @@ def get_function_signature_hook(
120121
def get_method_signature_hook(
121122
self, fullname: str
122123
) -> Callable[[MethodSigContext], FunctionLike] | None:
123-
if fullname == "typing.Mapping.get":
124-
return typed_dict_get_signature_callback
125-
elif fullname in TD_SETDEFAULT_NAMES:
124+
# NOTE: signatures for `__setitem__`, `__delitem__` and `get` are
125+
# checkmember.py/analyze_typeddict_access
126+
if fullname in TD_SETDEFAULT_NAMES:
126127
return typed_dict_setdefault_signature_callback
127128
elif fullname in TD_POP_NAMES:
128129
return typed_dict_pop_signature_callback
@@ -212,81 +213,42 @@ def get_class_decorator_hook_2(
212213
return None
213214

214215

215-
def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
216-
"""Try to infer a better signature type for TypedDict.get.
217-
218-
This is used to get better type context for the second argument that
219-
depends on a TypedDict value type.
220-
"""
221-
signature = ctx.default_signature
222-
if (
223-
isinstance(ctx.type, TypedDictType)
224-
and len(ctx.args) == 2
225-
and len(ctx.args[0]) == 1
226-
and isinstance(ctx.args[0][0], StrExpr)
227-
and len(signature.arg_types) == 2
228-
and len(signature.variables) == 1
229-
and len(ctx.args[1]) == 1
230-
):
231-
key = ctx.args[0][0].value
232-
value_type = get_proper_type(ctx.type.items.get(key))
233-
ret_type = signature.ret_type
234-
if value_type:
235-
default_arg = ctx.args[1][0]
236-
if (
237-
isinstance(value_type, TypedDictType)
238-
and isinstance(default_arg, DictExpr)
239-
and len(default_arg.items) == 0
240-
):
241-
# Caller has empty dict {} as default for typed dict.
242-
value_type = value_type.copy_modified(required_keys=set())
243-
# Tweak the signature to include the value type as context. It's
244-
# only needed for type inference since there's a union with a type
245-
# variable that accepts everything.
246-
tv = signature.variables[0]
247-
assert isinstance(tv, TypeVarType)
248-
return signature.copy_modified(
249-
arg_types=[signature.arg_types[0], make_simplified_union([value_type, tv])],
250-
ret_type=ret_type,
251-
)
252-
return signature
253-
254-
255216
def typed_dict_get_callback(ctx: MethodContext) -> Type:
256-
"""Infer a precise return type for TypedDict.get with literal first argument."""
217+
"""Infer a precise return type for TypedDict.get with union of literal first argument."""
257218
if (
258219
isinstance(ctx.type, TypedDictType)
259220
and len(ctx.arg_types) >= 1
260221
and len(ctx.arg_types[0]) == 1
222+
and isinstance(get_proper_type(ctx.arg_types[0][0]), UnionType)
261223
):
262224
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
263225
if keys is None:
264226
return ctx.default_return_type
265227

266228
output_types: list[Type] = []
229+
default_arg: Type
230+
231+
if len(ctx.arg_types) <= 1 or not ctx.arg_types[1]:
232+
default_arg = NoneType()
233+
elif len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
234+
default_arg = ctx.arg_types[1][0]
235+
else:
236+
return ctx.default_return_type
237+
267238
for key in keys:
268-
value_type = get_proper_type(ctx.type.items.get(key))
269-
if value_type is None:
239+
value: ProperType | None = get_proper_type(ctx.type.items.get(key))
240+
if value is None:
270241
return ctx.default_return_type
271242

272-
if len(ctx.arg_types) == 1:
273-
output_types.append(value_type)
274-
elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
275-
default_arg = ctx.args[1][0]
276-
if (
277-
isinstance(default_arg, DictExpr)
278-
and len(default_arg.items) == 0
279-
and isinstance(value_type, TypedDictType)
280-
):
281-
# Special case '{}' as the default for a typed dict type.
282-
output_types.append(value_type.copy_modified(required_keys=set()))
283-
else:
284-
output_types.append(value_type)
285-
output_types.append(ctx.arg_types[1][0])
286-
287-
if len(ctx.arg_types) == 1:
288-
output_types.append(NoneType())
243+
if key in ctx.type.required_keys:
244+
output_types.append(value)
245+
else:
246+
output_types.append(value)
247+
output_types.append(default_arg)
289248

249+
# for nicer reveal_type, put default at the end, if it is present
250+
if default_arg in output_types:
251+
output_types = [t for t in output_types if t != default_arg] + [default_arg]
290252
return make_simplified_union(output_types)
291253
return ctx.default_return_type
292254

test-data/unit/check-literal.test

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1884,7 +1884,7 @@ reveal_type(d[a_key]) # N: Revealed type is "builtins.int"
18841884
reveal_type(d[b_key]) # N: Revealed type is "builtins.str"
18851885
d[c_key] # E: TypedDict "Outer" has no key "c"
18861886

1887-
reveal_type(d.get(a_key, u)) # N: Revealed type is "Union[builtins.int, __main__.Unrelated]"
1887+
reveal_type(d.get(a_key, u)) # N: Revealed type is "builtins.int"
18881888
reveal_type(d.get(b_key, u)) # N: Revealed type is "Union[builtins.str, __main__.Unrelated]"
18891889
reveal_type(d.get(c_key, u)) # N: Revealed type is "builtins.object"
18901890

@@ -1928,7 +1928,7 @@ u: Unrelated
19281928
reveal_type(a[int_key_good]) # N: Revealed type is "builtins.int"
19291929
reveal_type(b[int_key_good]) # N: Revealed type is "builtins.int"
19301930
reveal_type(c[str_key_good]) # N: Revealed type is "builtins.int"
1931-
reveal_type(c.get(str_key_good, u)) # N: Revealed type is "Union[builtins.int, __main__.Unrelated]"
1931+
reveal_type(c.get(str_key_good, u)) # N: Revealed type is "builtins.int"
19321932
reveal_type(c.get(str_key_bad, u)) # N: Revealed type is "builtins.object"
19331933

19341934
a[int_key_bad] # E: Tuple index out of range
@@ -1993,8 +1993,8 @@ optional_keys: Literal["d", "e"]
19931993
bad_keys: Literal["a", "bad"]
19941994

19951995
reveal_type(test[good_keys]) # N: Revealed type is "Union[__main__.A, __main__.B]"
1996-
reveal_type(test.get(good_keys)) # N: Revealed type is "Union[__main__.A, __main__.B, None]"
1997-
reveal_type(test.get(good_keys, 3)) # N: Revealed type is "Union[__main__.A, Literal[3]?, __main__.B]"
1996+
reveal_type(test.get(good_keys)) # N: Revealed type is "Union[__main__.A, __main__.B]"
1997+
reveal_type(test.get(good_keys, 3)) # N: Revealed type is "Union[__main__.A, __main__.B]"
19981998
reveal_type(test.pop(optional_keys)) # N: Revealed type is "Union[__main__.D, __main__.E]"
19991999
reveal_type(test.pop(optional_keys, 3)) # N: Revealed type is "Union[__main__.D, __main__.E, Literal[3]?]"
20002000
reveal_type(test.setdefault(good_keys, AAndB())) # N: Revealed type is "Union[__main__.A, __main__.B]"
@@ -2037,15 +2037,18 @@ class D2(TypedDict):
20372037
d: D
20382038

20392039
x: Union[D1, D2]
2040-
bad_keys: Literal['a', 'b', 'c', 'd']
20412040
good_keys: Literal['b', 'c']
2041+
mixed_keys: Literal['a', 'b', 'c', 'd']
2042+
bad_keys: Literal['e', 'f']
20422043

2043-
x[bad_keys] # E: TypedDict "D1" has no key "d" \
2044+
x[mixed_keys] # E: TypedDict "D1" has no key "d" \
20442045
# E: TypedDict "D2" has no key "a"
20452046

20462047
reveal_type(x[good_keys]) # N: Revealed type is "Union[__main__.B, __main__.C]"
2047-
reveal_type(x.get(good_keys)) # N: Revealed type is "Union[__main__.B, __main__.C, None]"
2048-
reveal_type(x.get(good_keys, 3)) # N: Revealed type is "Union[__main__.B, Literal[3]?, __main__.C]"
2048+
reveal_type(x.get(good_keys)) # N: Revealed type is "Union[__main__.B, __main__.C]"
2049+
reveal_type(x.get(good_keys, 3)) # N: Revealed type is "Union[__main__.B, __main__.C]"
2050+
reveal_type(x.get(mixed_keys)) # N: Revealed type is "builtins.object"
2051+
reveal_type(x.get(mixed_keys, 3)) # N: Revealed type is "builtins.object"
20492052
reveal_type(x.get(bad_keys)) # N: Revealed type is "builtins.object"
20502053
reveal_type(x.get(bad_keys, 3)) # N: Revealed type is "builtins.object"
20512054

test-data/unit/check-recursive-types.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ td: TD
693693
td["y"] = {"x": 0, "y": {}}
694694
td["y"] = {"x": 0, "y": {"x": 0, "y": 42}} # E: Incompatible types (expression has type "int", TypedDict item "y" has type "TD")
695695

696-
reveal_type(td.get("y")) # N: Revealed type is "Union[TypedDict('__main__.TD', {'x'?: builtins.int, 'y'?: TypedDict('__main__.TD', {'x'?: builtins.int, 'y'?: ...})}), None]"
696+
reveal_type(td.get("y")) # N: Revealed type is "Union[TypedDict('__main__.TD', {'x'?: builtins.int, 'y'?: ...}), None]"
697697
s: str = td.get("y") # E: Incompatible types in assignment (expression has type "Optional[TD]", variable has type "str")
698698

699699
td.update({"x": 0, "y": {"x": 1, "y": {}}})

0 commit comments

Comments
 (0)