Skip to content

Commit 83fa0ba

Browse files
committed
Add initial support for new style TypeVar defaults (PEP 696)
1 parent cedb2fd commit 83fa0ba

File tree

9 files changed

+396
-48
lines changed

9 files changed

+396
-48
lines changed

mypy/checker.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2483,6 +2483,8 @@ def visit_class_def(self, defn: ClassDef) -> None:
24832483
context=defn,
24842484
code=codes.TYPE_VAR,
24852485
)
2486+
if typ.defn.type_vars:
2487+
self.check_typevar_defaults(typ.defn.type_vars)
24862488

24872489
if typ.is_protocol and typ.defn.type_vars:
24882490
self.check_protocol_variance(defn)
@@ -2546,6 +2548,15 @@ def check_init_subclass(self, defn: ClassDef) -> None:
25462548
# all other bases have already been checked.
25472549
break
25482550

2551+
def check_typevar_defaults(self, tvars: list[TypeVarLikeType]) -> None:
2552+
for tv in tvars:
2553+
if not (isinstance(tv, TypeVarType) and tv.has_default()):
2554+
continue
2555+
if not is_subtype(tv.default, tv.upper_bound):
2556+
self.fail("TypeVar default must be a subtype of the bound type", tv)
2557+
if tv.values and not any(tv.default == value for value in tv.values):
2558+
self.fail("TypeVar default must be one of the constraint types", tv)
2559+
25492560
def check_enum(self, defn: ClassDef) -> None:
25502561
assert defn.info.is_enum
25512562
if defn.info.fullname not in ENUM_BASES:
@@ -5365,6 +5376,10 @@ def remove_capture_conflicts(self, type_map: TypeMap, inferred_types: dict[Var,
53655376
del type_map[expr]
53665377

53675378
def visit_type_alias_stmt(self, o: TypeAliasStmt) -> None:
5379+
sym = self.lookup_qualified(o.name.fullname)
5380+
if isinstance(sym.node, TypeAlias):
5381+
self.check_typevar_defaults(sym.node.alias_tvars)
5382+
53685383
with self.msg.filter_errors():
53695384
self.expr_checker.accept(o.value)
53705385

mypy/fastparse.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,17 +1198,15 @@ def translate_type_params(self, type_params: list[Any]) -> list[TypeParam]:
11981198
for p in type_params:
11991199
bound = None
12001200
values: list[Type] = []
1201-
if sys.version_info >= (3, 13) and p.default_value is not None:
1202-
self.fail(
1203-
message_registry.TYPE_PARAM_DEFAULT_NOT_SUPPORTED,
1204-
p.lineno,
1205-
p.col_offset,
1206-
blocker=False,
1207-
)
1201+
default = None
1202+
if sys.version_info >= (3, 13):
1203+
default = TypeConverter(self.errors, line=p.lineno).visit(p.default_value)
12081204
if isinstance(p, ast_ParamSpec): # type: ignore[misc]
1209-
explicit_type_params.append(TypeParam(p.name, PARAM_SPEC_KIND, None, []))
1205+
explicit_type_params.append(TypeParam(p.name, PARAM_SPEC_KIND, None, [], default))
12101206
elif isinstance(p, ast_TypeVarTuple): # type: ignore[misc]
1211-
explicit_type_params.append(TypeParam(p.name, TYPE_VAR_TUPLE_KIND, None, []))
1207+
explicit_type_params.append(
1208+
TypeParam(p.name, TYPE_VAR_TUPLE_KIND, None, [], default)
1209+
)
12121210
else:
12131211
if isinstance(p.bound, ast3.Tuple):
12141212
if len(p.bound.elts) < 2:
@@ -1224,7 +1222,9 @@ def translate_type_params(self, type_params: list[Any]) -> list[TypeParam]:
12241222
elif p.bound is not None:
12251223
self.validate_type_param(p)
12261224
bound = TypeConverter(self.errors, line=p.lineno).visit(p.bound)
1227-
explicit_type_params.append(TypeParam(p.name, TYPE_VAR_KIND, bound, values))
1225+
explicit_type_params.append(
1226+
TypeParam(p.name, TYPE_VAR_KIND, bound, values, default)
1227+
)
12281228
return explicit_type_params
12291229

12301230
# Return(expr? value)

mypy/message_registry.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -362,8 +362,3 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
362362
TYPE_ALIAS_WITH_AWAIT_EXPRESSION: Final = ErrorMessage(
363363
"Await expression cannot be used within a type alias", codes.SYNTAX
364364
)
365-
366-
TYPE_PARAM_DEFAULT_NOT_SUPPORTED: Final = ErrorMessage(
367-
"Type parameter default types not supported when using Python 3.12 type parameter syntax",
368-
codes.MISC,
369-
)

mypy/nodes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,19 +670,21 @@ def set_line(
670670

671671

672672
class TypeParam:
673-
__slots__ = ("name", "kind", "upper_bound", "values")
673+
__slots__ = ("name", "kind", "upper_bound", "values", "default")
674674

675675
def __init__(
676676
self,
677677
name: str,
678678
kind: int,
679679
upper_bound: mypy.types.Type | None,
680680
values: list[mypy.types.Type],
681+
default: mypy.types.Type | None,
681682
) -> None:
682683
self.name = name
683684
self.kind = kind
684685
self.upper_bound = upper_bound
685686
self.values = values
687+
self.default = default
686688

687689

688690
FUNCITEM_FLAGS: Final = FUNCBASE_FLAGS + [

mypy/semanal.py

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1808,7 +1808,26 @@ def analyze_type_param(
18081808
upper_bound = self.named_type("builtins.tuple", [self.object_type()])
18091809
else:
18101810
upper_bound = self.object_type()
1811-
default = AnyType(TypeOfAny.from_omitted_generics)
1811+
if type_param.default:
1812+
default = self.anal_type(
1813+
type_param.default,
1814+
allow_placeholder=True,
1815+
allow_unbound_tvars=True,
1816+
report_invalid_types=False,
1817+
allow_param_spec_literals=type_param.kind == PARAM_SPEC_KIND,
1818+
allow_tuple_literal=type_param.kind == PARAM_SPEC_KIND,
1819+
allow_unpack=type_param.kind == TYPE_VAR_TUPLE_KIND,
1820+
)
1821+
if default is None:
1822+
default = PlaceholderType(None, [], context.line)
1823+
elif type_param.kind == TYPE_VAR_KIND:
1824+
default = self.check_typevar_default(default, type_param.default)
1825+
elif type_param.kind == PARAM_SPEC_KIND:
1826+
default = self.check_paramspec_default(default, type_param.default)
1827+
elif type_param.kind == TYPE_VAR_TUPLE_KIND:
1828+
default = self.check_typevartuple_default(default, type_param.default)
1829+
else:
1830+
default = AnyType(TypeOfAny.from_omitted_generics)
18121831
if type_param.kind == TYPE_VAR_KIND:
18131832
values = []
18141833
if type_param.values:
@@ -4615,6 +4634,40 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool:
46154634
self.add_symbol(name, call.analyzed, s)
46164635
return True
46174636

4637+
def check_typevar_default(self, default: Type, context: Context) -> Type:
4638+
typ = get_proper_type(default)
4639+
if isinstance(typ, AnyType) and typ.is_from_error:
4640+
self.fail(
4641+
message_registry.TYPEVAR_ARG_MUST_BE_TYPE.format("TypeVar", "default"), context
4642+
)
4643+
return default
4644+
4645+
def check_paramspec_default(self, default: Type, context: Context) -> Type:
4646+
typ = get_proper_type(default)
4647+
if isinstance(typ, Parameters):
4648+
for i, arg_type in enumerate(typ.arg_types):
4649+
arg_ptype = get_proper_type(arg_type)
4650+
if isinstance(arg_ptype, AnyType) and arg_ptype.is_from_error:
4651+
self.fail(f"Argument {i} of ParamSpec default must be a type", context)
4652+
elif (
4653+
isinstance(typ, AnyType)
4654+
and typ.is_from_error
4655+
or not isinstance(typ, (AnyType, UnboundType))
4656+
):
4657+
self.fail(
4658+
"The default argument to ParamSpec must be a list expression, ellipsis, or a ParamSpec",
4659+
context,
4660+
)
4661+
default = AnyType(TypeOfAny.from_error)
4662+
return default
4663+
4664+
def check_typevartuple_default(self, default: Type, context: Context) -> Type:
4665+
typ = get_proper_type(default)
4666+
if not isinstance(typ, UnpackType):
4667+
self.fail("The default argument to TypeVarTuple must be an Unpacked tuple", context)
4668+
default = AnyType(TypeOfAny.from_error)
4669+
return default
4670+
46184671
def check_typevarlike_name(self, call: CallExpr, name: str, context: Context) -> bool:
46194672
"""Checks that the name of a TypeVar or ParamSpec matches its variable."""
46204673
name = unmangle(name)
@@ -4822,23 +4875,7 @@ def process_paramspec_declaration(self, s: AssignmentStmt) -> bool:
48224875
report_invalid_typevar_arg=False,
48234876
)
48244877
default = tv_arg or AnyType(TypeOfAny.from_error)
4825-
if isinstance(tv_arg, Parameters):
4826-
for i, arg_type in enumerate(tv_arg.arg_types):
4827-
typ = get_proper_type(arg_type)
4828-
if isinstance(typ, AnyType) and typ.is_from_error:
4829-
self.fail(
4830-
f"Argument {i} of ParamSpec default must be a type", param_value
4831-
)
4832-
elif (
4833-
isinstance(default, AnyType)
4834-
and default.is_from_error
4835-
or not isinstance(default, (AnyType, UnboundType))
4836-
):
4837-
self.fail(
4838-
"The default argument to ParamSpec must be a list expression, ellipsis, or a ParamSpec",
4839-
param_value,
4840-
)
4841-
default = AnyType(TypeOfAny.from_error)
4878+
default = self.check_paramspec_default(default, param_value)
48424879
else:
48434880
# ParamSpec is different from a regular TypeVar:
48444881
# arguments are not semantically valid. But, allowed in runtime.
@@ -4899,12 +4936,7 @@ def process_typevartuple_declaration(self, s: AssignmentStmt) -> bool:
48994936
allow_unpack=True,
49004937
)
49014938
default = tv_arg or AnyType(TypeOfAny.from_error)
4902-
if not isinstance(default, UnpackType):
4903-
self.fail(
4904-
"The default argument to TypeVarTuple must be an Unpacked tuple",
4905-
param_value,
4906-
)
4907-
default = AnyType(TypeOfAny.from_error)
4939+
default = self.check_typevartuple_default(default, param_value)
49084940
else:
49094941
self.fail(f'Unexpected keyword argument "{param_name}" for "TypeVarTuple"', s)
49104942

mypy/strconv.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,8 @@ def type_param(self, p: mypy.nodes.TypeParam) -> list[Any]:
349349
a.append(p.upper_bound)
350350
if p.values:
351351
a.append(("Values", p.values))
352+
if p.default:
353+
a.append(("Default", [p.default]))
352354
return [("TypeParam", a)]
353355

354356
# Expressions

mypy/test/testparse.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ class ParserSuite(DataSuite):
2525
files.remove("parse-python310.test")
2626
if sys.version_info < (3, 12):
2727
files.remove("parse-python312.test")
28+
if sys.version_info < (3, 13):
29+
files.remove("parse-python313.test")
2830

2931
def run_case(self, testcase: DataDrivenTestCase) -> None:
3032
test_parser(testcase)
@@ -43,6 +45,8 @@ def test_parser(testcase: DataDrivenTestCase) -> None:
4345
options.python_version = (3, 10)
4446
elif testcase.file.endswith("python312.test"):
4547
options.python_version = (3, 12)
48+
elif testcase.file.endswith("python313.test"):
49+
options.python_version = (3, 13)
4650
else:
4751
options.python_version = defaults.PYTHON3_VERSION
4852

0 commit comments

Comments
 (0)