Skip to content

Commit 63f4975

Browse files
revert reveal_type changes
1 parent cbbbd4d commit 63f4975

File tree

6 files changed

+77
-101
lines changed

6 files changed

+77
-101
lines changed

mypy/checkexpr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1502,7 +1502,7 @@ def check_call_expr_with_callee_type(
15021502
def check_union_call_expr(self, e: CallExpr, object_type: UnionType, member: str) -> Type:
15031503
"""Type check calling a member expression where the base type is a union."""
15041504
res: list[Type] = []
1505-
for typ in object_type.relevant_items():
1505+
for typ in flatten_nested_unions(object_type.relevant_items()):
15061506
# Member access errors are already reported when visiting the member expression.
15071507
with self.msg.filter_errors():
15081508
item = analyze_member_access(

mypy/checkmember.py

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from mypy.meet import is_overlapping_types
1818
from mypy.messages import MessageBuilder
1919
from mypy.nodes import (
20-
ARG_OPT,
2120
ARG_POS,
2221
ARG_STAR,
2322
ARG_STAR2,
@@ -69,7 +68,6 @@
6968
TypedDictType,
7069
TypeOfAny,
7170
TypeType,
72-
TypeVarId,
7371
TypeVarLikeType,
7472
TypeVarTupleType,
7573
TypeVarType,
@@ -1406,76 +1404,6 @@ def analyze_typeddict_access(
14061404
fallback=mx.chk.named_type("builtins.function"),
14071405
name=name,
14081406
)
1409-
elif name == "get":
1410-
# synthesize TypedDict.get() overloads
1411-
str_type = mx.chk.named_type("builtins.str")
1412-
fn_type = mx.chk.named_type("builtins.function")
1413-
object_type = mx.chk.named_type("builtins.object")
1414-
type_info = typ.fallback.type
1415-
# type variable for default value
1416-
tvar = TypeVarType(
1417-
"T",
1418-
f"{type_info.fullname}.get.T",
1419-
id=TypeVarId(-1, namespace=f"{type_info.fullname}.get"),
1420-
values=[],
1421-
upper_bound=object_type,
1422-
default=AnyType(TypeOfAny.from_omitted_generics),
1423-
)
1424-
# generate the overloads
1425-
overloads: list[CallableType] = []
1426-
for key, value_type in typ.items.items():
1427-
key_type = LiteralType(key, fallback=str_type)
1428-
1429-
if key in typ.required_keys:
1430-
# If the key is required, we know it must be present in the TypedDict.
1431-
# def (K, object=...) -> V
1432-
overload = CallableType(
1433-
arg_types=[key_type, object_type],
1434-
arg_kinds=[ARG_POS, ARG_OPT],
1435-
arg_names=[None, None],
1436-
ret_type=value_type,
1437-
fallback=fn_type,
1438-
name=name,
1439-
)
1440-
overloads.append(overload)
1441-
else:
1442-
# The key is not required, but if it is present, we know its type.
1443-
# def (K) -> V | None (implicit default)
1444-
overload = CallableType(
1445-
arg_types=[key_type],
1446-
arg_kinds=[ARG_POS],
1447-
arg_names=[None],
1448-
ret_type=UnionType.make_union([value_type, NoneType()]),
1449-
fallback=fn_type,
1450-
name=name,
1451-
)
1452-
overloads.append(overload)
1453-
1454-
# def [T](K, T) -> V | T (explicit default)
1455-
overload = CallableType(
1456-
variables=[tvar],
1457-
arg_types=[key_type, tvar],
1458-
arg_kinds=[ARG_POS, ARG_POS],
1459-
arg_names=[None, None],
1460-
ret_type=UnionType.make_union([value_type, tvar]),
1461-
fallback=fn_type,
1462-
name=name,
1463-
)
1464-
overloads.append(overload)
1465-
1466-
# finally, add fallback overload when a key is used that is not in the TypedDict
1467-
# TODO: add support for extra items (PEP 728)
1468-
# def (str, object=...) -> object
1469-
fallback_overload = CallableType(
1470-
arg_types=[str_type, object_type],
1471-
arg_kinds=[ARG_POS, ARG_OPT],
1472-
arg_names=[None, None],
1473-
ret_type=object_type,
1474-
fallback=fn_type,
1475-
name=name,
1476-
)
1477-
overloads.append(fallback_overload)
1478-
return Overloaded(overloads)
14791407
return _analyze_member_access(name, typ.fallback, mx, override_info)
14801408

14811409

mypy/plugins/default.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,9 @@ def get_function_signature_hook(
120120
def get_method_signature_hook(
121121
self, fullname: str
122122
) -> Callable[[MethodSigContext], FunctionLike] | None:
123-
# NOTE: signatures for `__setitem__`, `__delitem__` and `get` are
124-
# defined in checkmember.py/analyze_typeddict_access
125-
if fullname in TD_SETDEFAULT_NAMES:
123+
if fullname == "typing.Mapping.get":
124+
return typed_dict_get_signature_callback
125+
elif fullname in TD_SETDEFAULT_NAMES:
126126
return typed_dict_setdefault_signature_callback
127127
elif fullname in TD_POP_NAMES:
128128
return typed_dict_pop_signature_callback
@@ -212,6 +212,46 @@ def get_class_decorator_hook_2(
212212
return None
213213

214214

215+
def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
216+
"""Try to infer a better signature type for TypedDict.get.
217+
218+
This is used to get better type context for the second argument that
219+
depends on a TypedDict value type.
220+
"""
221+
signature = ctx.default_signature
222+
if (
223+
isinstance(ctx.type, TypedDictType)
224+
and len(ctx.args) == 2
225+
and len(ctx.args[0]) == 1
226+
and isinstance(ctx.args[0][0], StrExpr)
227+
and len(signature.arg_types) == 2
228+
and len(signature.variables) == 1
229+
and len(ctx.args[1]) == 1
230+
):
231+
key = ctx.args[0][0].value
232+
value_type = get_proper_type(ctx.type.items.get(key))
233+
ret_type = signature.ret_type
234+
if value_type:
235+
default_arg = ctx.args[1][0]
236+
if (
237+
isinstance(value_type, TypedDictType)
238+
and isinstance(default_arg, DictExpr)
239+
and len(default_arg.items) == 0
240+
):
241+
# Caller has empty dict {} as default for typed dict.
242+
value_type = value_type.copy_modified(required_keys=set())
243+
# Tweak the signature to include the value type as context. It's
244+
# only needed for type inference since there's a union with a type
245+
# variable that accepts everything.
246+
tv = signature.variables[0]
247+
assert isinstance(tv, TypeVarType)
248+
return signature.copy_modified(
249+
arg_types=[signature.arg_types[0], make_simplified_union([value_type, tv])],
250+
ret_type=ret_type,
251+
)
252+
return signature
253+
254+
215255
def typed_dict_get_callback(ctx: MethodContext) -> Type:
216256
"""Infer a precise return type for TypedDict.get with literal first argument."""
217257
if (

test-data/unit/check-typeddict.test

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,12 +1007,12 @@ class D(TypedDict):
10071007
b: NotRequired[str]
10081008

10091009
def test(d: D) -> None:
1010-
reveal_type(d.get) # N: Revealed type is \
1011-
"Overload(\
1012-
def (Literal['a'], builtins.object =) -> builtins.int, \
1013-
def (Literal['b']) -> Union[builtins.str, None], \
1014-
def [T] (Literal['b'], T`-1) -> Union[builtins.str, T`-1], \
1015-
def (builtins.str, builtins.object =) -> builtins.object)"
1010+
reveal_type(d.get) # N: Revealed type is "Overload(def (k: builtins.str) -> builtins.object, def (builtins.str, builtins.object) -> builtins.object, def [V] (builtins.str, V`4) -> builtins.object)"
1011+
1012+
1013+
1014+
1015+
10161016

10171017
[builtins fixtures/dict.pyi]
10181018
[typing fixtures/typing-typeddict.pyi]
@@ -1155,21 +1155,29 @@ reveal_type(d.get('x', a)) # N: Revealed type is "Union[builtins.list[builtins.i
11551155
from typing import TypedDict
11561156
D = TypedDict('D', {'x': int, 'y': str})
11571157
d: D
1158-
d.get() # E: All overload variants of "get" require at least one argument \
1158+
d.get() # E: All overload variants of "get" of "Mapping" require at least one argument \
11591159
# N: Possible overload variants: \
1160-
# N: def get(Literal['x'], object = ..., /) -> int \
1161-
# N: def get(Literal['y'], object = ..., /) -> str \
1162-
# N: def get(str, object = ..., /) -> object
1163-
d.get('x', 1, 2) # E: No overload variant of "get" matches argument types "str", "int", "int" \
1160+
# N: def get(self, k: str) -> object \
1161+
# N: def get(self, str, object, /) -> object \
1162+
# N: def [V] get(self, str, V, /) -> object
1163+
d.get('x', 1, 2) # E: No overload variant of "get" of "Mapping" matches argument types "str", "int", "int" \
11641164
# N: Possible overload variants: \
1165-
# N: def get(Literal['x'], object = ..., /) -> int \
1166-
# N: def get(Literal['y'], object = ..., /) -> str \
1167-
# N: def get(str, object = ..., /) -> object
1165+
# N: def get(self, k: str) -> object \
1166+
# N: def get(self, str, object, /) -> object \
1167+
# N: def [V] get(self, str, Union[int, V], /) -> object
11681168
x = d.get('z')
11691169
reveal_type(x) # N: Revealed type is "builtins.object"
11701170
s = ''
11711171
y = d.get(s)
11721172
reveal_type(y) # N: Revealed type is "builtins.object"
1173+
1174+
1175+
1176+
1177+
1178+
1179+
1180+
11731181
[builtins fixtures/dict.pyi]
11741182
[typing fixtures/typing-typeddict.pyi]
11751183

test-data/unit/fixtures/typing-typeddict.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ class Mapping(Iterable[T], Generic[T, T_co], metaclass=ABCMeta):
5656
@overload
5757
def get(self, k: T) -> Optional[T_co]: pass
5858
@overload
59-
def get(self, k: T, default: Union[T_co, V]) -> Union[T_co, V]: pass
59+
def get(self, k: T, default: T_co, /) -> Optional[T_co]: pass # type: ignore[misc]
60+
@overload
61+
def get(self, k: T, default: V, /) -> Union[T_co, V]: pass
6062
def values(self) -> Iterable[T_co]: pass # Approximate return type
6163
def __len__(self) -> int: ...
6264
def __contains__(self, arg: object) -> int: pass

test-data/unit/pythoneval.test

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,22 +1056,20 @@ def test_not_total(d: D_not_total) -> None:
10561056
_testTypedDictGet.py:8: note: Revealed type is "builtins.int"
10571057
_testTypedDictGet.py:9: note: Revealed type is "builtins.str"
10581058
_testTypedDictGet.py:10: note: Revealed type is "builtins.object"
1059-
_testTypedDictGet.py:11: error: All overload variants of "get" require at least one argument
1059+
_testTypedDictGet.py:11: error: All overload variants of "get" of "Mapping" require at least one argument
10601060
_testTypedDictGet.py:11: note: Possible overload variants:
1061-
_testTypedDictGet.py:11: note: def get(Literal['x'], object = ..., /) -> int
1062-
_testTypedDictGet.py:11: note: def get(Literal['y'], object = ..., /) -> str
1063-
_testTypedDictGet.py:11: note: def get(str, object = ..., /) -> object
1061+
_testTypedDictGet.py:11: note: def get(self, str, /) -> object
1062+
_testTypedDictGet.py:11: note: def get(self, str, /, default: object) -> object
1063+
_testTypedDictGet.py:11: note: def [_T] get(self, str, /, default: _T) -> object
10641064
_testTypedDictGet.py:13: note: Revealed type is "builtins.object"
10651065
_testTypedDictGet.py:16: note: Revealed type is "Union[builtins.int, None]"
10661066
_testTypedDictGet.py:17: note: Revealed type is "Union[builtins.str, None]"
10671067
_testTypedDictGet.py:18: note: Revealed type is "builtins.object"
1068-
_testTypedDictGet.py:19: error: All overload variants of "get" require at least one argument
1068+
_testTypedDictGet.py:19: error: All overload variants of "get" of "Mapping" require at least one argument
10691069
_testTypedDictGet.py:19: note: Possible overload variants:
1070-
_testTypedDictGet.py:19: note: def get(Literal['x'], /) -> Optional[int]
1071-
_testTypedDictGet.py:19: note: def [T] get(Literal['x'], T, /) -> Union[int, T]
1072-
_testTypedDictGet.py:19: note: def get(Literal['y'], /) -> Optional[str]
1073-
_testTypedDictGet.py:19: note: def [T] get(Literal['y'], T, /) -> Union[str, T]
1074-
_testTypedDictGet.py:19: note: def get(str, object = ..., /) -> object
1070+
_testTypedDictGet.py:19: note: def get(self, str, /) -> object
1071+
_testTypedDictGet.py:19: note: def get(self, str, /, default: object) -> object
1072+
_testTypedDictGet.py:19: note: def [_T] get(self, str, /, default: _T) -> object
10751073
_testTypedDictGet.py:21: note: Revealed type is "builtins.object"
10761074

10771075
[case testTypedDictMappingMethods]

0 commit comments

Comments
 (0)