Skip to content

Commit 2b630bf

Browse files
committed
Add basic support for recussive TypeVar defaults (PEP 696)
1 parent a61698b commit 2b630bf

File tree

6 files changed

+91
-4
lines changed

6 files changed

+91
-4
lines changed

mypy/applytype.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def apply_generic_arguments(
147147
# TODO: move apply_poly() logic from checkexpr.py here when new inference
148148
# becomes universally used (i.e. in all passes + in unification).
149149
# With this new logic we can actually *add* some new free variables.
150-
remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type]
150+
remaining_tvars = [expand_type(tv, id_to_type) for tv in tvars if tv.id not in id_to_type]
151151

152152
return callable.copy_modified(
153153
ret_type=expand_type(callable.ret_type, id_to_type),

mypy/expandtype.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):
179179

180180
def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
181181
self.variables = variables
182+
self.recursive_tvar_guard: dict[TypeVarId, Type | None] = {}
182183

183184
def visit_unbound_type(self, t: UnboundType) -> Type:
184185
return t
@@ -226,6 +227,14 @@ def visit_type_var(self, t: TypeVarType) -> Type:
226227
# TODO: do we really need to do this?
227228
# If I try to remove this special-casing ~40 tests fail on reveal_type().
228229
return repl.copy_modified(last_known_value=None)
230+
if isinstance(repl, TypeVarType) and repl.has_default():
231+
if (tvar_id := repl.id) in self.recursive_tvar_guard:
232+
return self.recursive_tvar_guard[tvar_id] or repl
233+
self.recursive_tvar_guard[tvar_id] = None
234+
repl = repl.accept(self)
235+
if isinstance(repl, TypeVarType):
236+
repl.default = repl.default.accept(self)
237+
self.recursive_tvar_guard[tvar_id] = repl
229238
return repl
230239

231240
def visit_param_spec(self, t: ParamSpecType) -> Type:

mypy/semanal.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1954,6 +1954,14 @@ class Foo(Bar, Generic[T]): ...
19541954
del base_type_exprs[i]
19551955
tvar_defs: list[TypeVarLikeType] = []
19561956
for name, tvar_expr in declared_tvars:
1957+
if isinstance(tvar_expr.default, UnboundType):
1958+
# Assumption here is that the names cannot be duplicated
1959+
# TODO: - detect out of order and self-referencing typevars
1960+
# - nested default types, e.g. list[T1]
1961+
for fullname, type_var in self.tvar_scope.scope.items():
1962+
type_var_name = fullname.rpartition(".")[2]
1963+
if tvar_expr.default.name == type_var_name:
1964+
tvar_expr.default = type_var
19571965
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
19581966
tvar_defs.append(tvar_def)
19591967
return base_type_exprs, tvar_defs, is_protocol

mypy/tvar_scope.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,26 @@
1515
TypeVarTupleType,
1616
TypeVarType,
1717
)
18+
from mypy.typetraverser import TypeTraverserVisitor
19+
20+
21+
class TypeVarLikeNamespaceSetter(TypeTraverserVisitor):
22+
"""Set namespace for all TypeVarLikeTypes types."""
23+
24+
def __init__(self, namespace: str) -> None:
25+
self.namespace = namespace
26+
27+
def visit_type_var(self, t: TypeVarType) -> None:
28+
t.id.namespace = self.namespace
29+
super().visit_type_var(t)
30+
31+
def visit_param_spec(self, t: ParamSpecType) -> None:
32+
t.id.namespace = self.namespace
33+
return super().visit_param_spec(t)
34+
35+
def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
36+
t.id.namespace = self.namespace
37+
super().visit_type_var_tuple(t)
1838

1939

2040
class TypeVarLikeScope:
@@ -88,6 +108,8 @@ def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeType:
88108
i = self.func_id
89109
# TODO: Consider also using namespaces for functions
90110
namespace = ""
111+
tvar_expr.default.accept(TypeVarLikeNamespaceSetter(namespace))
112+
91113
if isinstance(tvar_expr, TypeVarExpr):
92114
tvar_def: TypeVarLikeType = TypeVarType(
93115
name=name,

mypy/typetraverser.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,16 @@ def visit_type_var(self, t: TypeVarType) -> None:
6161
# Note that type variable values and upper bound aren't treated as
6262
# components, since they are components of the type variable
6363
# definition. We want to traverse everything just once.
64-
pass
64+
t.default.accept(self)
6565

6666
def visit_param_spec(self, t: ParamSpecType) -> None:
67-
pass
67+
t.default.accept(self)
6868

6969
def visit_parameters(self, t: Parameters) -> None:
7070
self.traverse_types(t.arg_types)
7171

7272
def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
73-
pass
73+
t.default.accept(self)
7474

7575
def visit_literal_type(self, t: LiteralType) -> None:
7676
t.fallback.accept(self)

test-data/unit/check-typevar-defaults.test

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,54 @@ class ClassD1(Generic[T1, T2]): ...
377377
# reveal_type(c)
378378
[builtins fixtures/tuple.pyi]
379379

380+
[case testTypeVarDefaultsClassRecursive1]
381+
# flags: --disallow-any-generics
382+
from typing import Generic, TypeVar
383+
384+
T1 = TypeVar("T1", default=str)
385+
T2 = TypeVar("T2", default=T1)
386+
T3 = TypeVar("T3", default=T2)
387+
388+
class ClassD1(Generic[T1, T2]): ...
389+
390+
def func_d1(
391+
a: ClassD1,
392+
b: ClassD1[int],
393+
c: ClassD1[int, float]
394+
) -> None:
395+
reveal_type(a) # N: Revealed type is "__main__.ClassD1[builtins.str, builtins.str]"
396+
reveal_type(b) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.int]"
397+
reveal_type(c) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.float]"
398+
399+
k = ClassD1()
400+
reveal_type(k) # N: Revealed type is "__main__.ClassD1[builtins.str, builtins.str]"
401+
l = ClassD1[int]()
402+
reveal_type(l) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.int]"
403+
m = ClassD1[int, float]()
404+
reveal_type(m) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.float]"
405+
406+
class ClassD2(Generic[T1, T2, T3]): ...
407+
408+
def func_d2(
409+
a: ClassD2,
410+
b: ClassD2[int],
411+
c: ClassD2[int, float],
412+
d: ClassD2[int, float, str],
413+
) -> None:
414+
reveal_type(a) # N: Revealed type is "__main__.ClassD2[builtins.str, builtins.str, builtins.str]"
415+
reveal_type(b) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.int, builtins.int]"
416+
reveal_type(c) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.float]"
417+
reveal_type(d) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.str]"
418+
419+
k = ClassD2()
420+
reveal_type(k) # N: Revealed type is "__main__.ClassD2[builtins.str, builtins.str, builtins.str]"
421+
l = ClassD2[int]()
422+
reveal_type(l) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.int, builtins.int]"
423+
m = ClassD2[int, float]()
424+
reveal_type(m) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.float]"
425+
n = ClassD2[int, float, str]()
426+
reveal_type(n) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.str]"
427+
380428
[case testTypeVarDefaultsTypeAlias1]
381429
# flags: --disallow-any-generics
382430
from typing import Any, Dict, List, Tuple, TypeVar, Union

0 commit comments

Comments
 (0)