Skip to content

Commit 6d92f95

Browse files
committed
feat: proper narrowing for TypedDict keys and values
1 parent db67fac commit 6d92f95

File tree

4 files changed

+49
-10
lines changed

4 files changed

+49
-10
lines changed

mypy/checkmember.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1360,7 +1360,29 @@ def analyze_enum_class_attribute_access(
13601360
def analyze_typeddict_access(
13611361
name: str, typ: TypedDictType, mx: MemberContext, override_info: TypeInfo | None
13621362
) -> Type:
1363-
if name == "__setitem__":
1363+
if name == "keys":
1364+
# Return KeysView[union of Literal key types]
1365+
keys_view_info = mx.chk.named_type("typing.KeysView").type
1366+
return CallableType(
1367+
arg_types=[],
1368+
arg_kinds=[],
1369+
arg_names=[],
1370+
ret_type=Instance(keys_view_info, [typ.key_type]),
1371+
fallback=mx.chk.named_type("builtins.function"),
1372+
name=name,
1373+
)
1374+
elif name == "values":
1375+
# Return ValuesView[union of value types]
1376+
values_view_info = mx.chk.named_type("typing.ValuesView").type
1377+
return CallableType(
1378+
arg_types=[],
1379+
arg_kinds=[],
1380+
arg_names=[],
1381+
ret_type=Instance(values_view_info, [typ.value_type]),
1382+
fallback=mx.chk.named_type("builtins.function"),
1383+
name=name,
1384+
)
1385+
elif name == "__setitem__":
13641386
if isinstance(mx.context, IndexExpr):
13651387
# Since we can get this during `a['key'] = ...`
13661388
# it is safe to assume that the context is `IndexExpr`.

mypy/types.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2584,6 +2584,18 @@ def __init__(
25842584
self.extra_items_from = []
25852585
self.to_be_mutated = False
25862586

2587+
@property
2588+
def key_type(self) -> Type:
2589+
"""Return a Union of Literal types for all keys."""
2590+
return UnionType.make_union(
2591+
[LiteralType(key, self.fallback) for key in self.items.keys()]
2592+
)
2593+
2594+
@property
2595+
def value_type(self) -> Type:
2596+
"""Return a Union of all value types (deduplicated)."""
2597+
return UnionType.make_union(list({get_proper_type(typ) for typ in self.items.values()}))
2598+
25872599
def accept(self, visitor: TypeVisitor[T]) -> T:
25882600
return visitor.visit_typeddict_type(self)
25892601

test-data/unit/check-typeddict.test

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ from typing import TypedDict
55
Point = TypedDict('Point', {'x': int, 'y': int})
66
p = Point(x=42, y=1337)
77
reveal_type(p) # N: Revealed type is "TypedDict('__main__.Point', {'x': builtins.int, 'y': builtins.int})"
8-
# Use values() to check fallback value type.
9-
reveal_type(p.values()) # N: Revealed type is "typing.Iterable[builtins.object]"
8+
reveal_type(p.values()) # N: Revealed type is "typing.ValuesView[builtins.int]"
9+
reveal_type(p.keys()) # N: Revealed type is "typing.KeysView[Union[Literal['x'], Literal['y']]]"
1010
[builtins fixtures/dict.pyi]
1111
[typing fixtures/typing-typeddict.pyi]
1212
[targets __main__]
@@ -16,8 +16,8 @@ from typing import TypedDict
1616
Point = TypedDict('Point', {'x': int, 'y': int})
1717
p = Point(dict(x=42, y=1337))
1818
reveal_type(p) # N: Revealed type is "TypedDict('__main__.Point', {'x': builtins.int, 'y': builtins.int})"
19-
# Use values() to check fallback value type.
20-
reveal_type(p.values()) # N: Revealed type is "typing.Iterable[builtins.object]"
19+
reveal_type(p.values()) # N: Revealed type is "typing.ValuesView[builtins.int]"
20+
reveal_type(p.keys()) # N: Revealed type is "typing.KeysView[Union[Literal['x'], Literal['y']]]"
2121
[builtins fixtures/dict.pyi]
2222
[typing fixtures/typing-typeddict.pyi]
2323

@@ -26,8 +26,8 @@ from typing import TypedDict
2626
Point = TypedDict('Point', {'x': int, 'y': int})
2727
p = Point({'x': 42, 'y': 1337})
2828
reveal_type(p) # N: Revealed type is "TypedDict('__main__.Point', {'x': builtins.int, 'y': builtins.int})"
29-
# Use values() to check fallback value type.
30-
reveal_type(p.values()) # N: Revealed type is "typing.Iterable[builtins.object]"
29+
reveal_type(p.values()) # N: Revealed type is "typing.ValuesView[builtins.int]"
30+
reveal_type(p.keys()) # N: Revealed type is "typing.KeysView[Union[Literal['x'], Literal['y']]]"
3131
[builtins fixtures/dict.pyi]
3232
[typing fixtures/typing-typeddict.pyi]
3333

@@ -36,7 +36,8 @@ from typing import TypedDict, TypeVar, Union
3636
EmptyDict = TypedDict('EmptyDict', {})
3737
p = EmptyDict()
3838
reveal_type(p) # N: Revealed type is "TypedDict('__main__.EmptyDict', {})"
39-
reveal_type(p.values()) # N: Revealed type is "typing.Iterable[builtins.object]"
39+
reveal_type(p.values()) # N: Revealed type is "typing.ValuesView[Never]"
40+
reveal_type(p.keys()) # N: Revealed type is "typing.KeysView[Never]"
4041
[builtins fixtures/dict.pyi]
4142
[typing fixtures/typing-typeddict.pyi]
4243

@@ -534,8 +535,8 @@ Point3D = TypedDict('Point3D', {'x': int, 'y': int, 'z': int})
534535
p1 = TaggedPoint(type='2d', x=0, y=0)
535536
p2 = Point3D(x=1, y=1, z=1)
536537
joined_points = [p1, p2][0]
537-
reveal_type(p1.values()) # N: Revealed type is "typing.Iterable[builtins.object]"
538-
reveal_type(p2.values()) # N: Revealed type is "typing.Iterable[builtins.object]"
538+
reveal_type(p1.values()) # N: Revealed type is "typing.ValuesView[Union[builtins.str, builtins.int]]"
539+
reveal_type(p2.values()) # N: Revealed type is "typing.ValuesView[builtins.int]"
539540
reveal_type(joined_points) # N: Revealed type is "TypedDict({'x': builtins.int, 'y': builtins.int})"
540541
[builtins fixtures/dict.pyi]
541542
[typing fixtures/typing-typeddict.pyi]

test-data/unit/fixtures/typing-typeddict.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,7 @@ class _TypedDict(Mapping[str, object]):
7979
def __delitem__(self, k: NoReturn) -> None: ...
8080

8181
class _SpecialForm: pass
82+
83+
class KeysView(Iterable[T]): pass
84+
85+
class ValuesView(Iterable[V]): pass

0 commit comments

Comments
 (0)