Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions mypy/stubdoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class FunctionSig(NamedTuple):
name: str
args: list[ArgSig]
ret_type: str | None
type_args: str = "" # TODO implement in stubgenc and remove the default

def is_special_method(self) -> bool:
return bool(
Expand Down Expand Up @@ -141,8 +142,13 @@ def format_sig(
retfield = " -> " + ret_type

prefix = "async " if is_async else ""
sig = "{indent}{prefix}def {name}({args}){ret}:".format(
indent=indent, prefix=prefix, name=self.name, args=", ".join(args), ret=retfield
sig = "{indent}{prefix}def {name}{type_args}({args}){ret}:".format(
indent=indent,
prefix=prefix,
name=self.name,
args=", ".join(args),
ret=retfield,
type_args=self.type_args,
)
if docstring:
suffix = f"\n{indent} {mypy.util.quote_docstring(docstring)}"
Expand Down
20 changes: 18 additions & 2 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
StrExpr,
TempNode,
TupleExpr,
TypeAliasStmt,
TypeInfo,
UnaryExpr,
Var,
Expand Down Expand Up @@ -398,6 +399,9 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
for name in get_assigned_names(o.lvalues):
self.names.add(name)

def visit_type_alias_stmt(self, o: TypeAliasStmt) -> None:
self.names.add(o.name.name)


def find_referenced_names(file: MypyFile) -> set[str]:
finder = ReferenceFinder()
Expand Down Expand Up @@ -507,7 +511,8 @@ def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
def get_default_function_sig(self, func_def: FuncDef, ctx: FunctionContext) -> FunctionSig:
args = self._get_func_args(func_def, ctx)
retname = self._get_func_return(func_def, ctx)
return FunctionSig(func_def.name, args, retname)
type_args = self.format_type_args(func_def)
return FunctionSig(func_def.name, args, retname, type_args)

def _get_func_args(self, o: FuncDef, ctx: FunctionContext) -> list[ArgSig]:
args: list[ArgSig] = []
Expand Down Expand Up @@ -765,7 +770,8 @@ def visit_class_def(self, o: ClassDef) -> None:
self.import_tracker.add_import("abc")
self.import_tracker.require_name("abc")
bases = f"({', '.join(base_types)})" if base_types else ""
self.add(f"{self._indent}class {o.name}{bases}:\n")
type_args = self.format_type_args(o)
self.add(f"{self._indent}class {o.name}{type_args}{bases}:\n")
self.indent()
if self._include_docstrings and o.docstring:
docstring = mypy.util.quote_docstring(o.docstring)
Expand Down Expand Up @@ -1101,6 +1107,16 @@ def process_typealias(self, lvalue: NameExpr, rvalue: Expression) -> None:
self.record_name(lvalue.name)
self._vars[-1].append(lvalue.name)

def visit_type_alias_stmt(self, o: TypeAliasStmt) -> None:
"""Type aliases defined with the `type` keyword (PEP 695)."""
p = AliasPrinter(self)
name = o.name.name
rvalue = o.value.expr()
type_args = self.format_type_args(o)
self.add(f"{self._indent}type {name}{type_args} = {rvalue.accept(p)}\n")
self.record_name(name)
self._vars[-1].append(name)

def visit_if_stmt(self, o: IfStmt) -> None:
# Ignore if __name__ == '__main__'.
expr = o.expr[0]
Expand Down
26 changes: 26 additions & 0 deletions mypy/stubutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import mypy.options
from mypy.modulefinder import ModuleNotFoundReason
from mypy.moduleinspect import InspectError, ModuleInspect
from mypy.nodes import PARAM_SPEC_KIND, TYPE_VAR_TUPLE_KIND, ClassDef, FuncDef, TypeAliasStmt
from mypy.stubdoc import ArgSig, FunctionSig
from mypy.types import AnyType, NoneType, Type, TypeList, TypeStrVisitor, UnboundType, UnionType

Expand Down Expand Up @@ -777,6 +778,31 @@ def format_func_def(
)
return lines

def format_type_args(self, o: TypeAliasStmt | FuncDef | ClassDef) -> str:
if not o.type_args:
return ""
p = AnnotationPrinter(self)
type_args_list: list[str] = []
for type_arg in o.type_args:
if type_arg.kind == PARAM_SPEC_KIND:
prefix = "**"
elif type_arg.kind == TYPE_VAR_TUPLE_KIND:
prefix = "*"
else:
prefix = ""
if type_arg.upper_bound:
bound_or_values = f": {type_arg.upper_bound.accept(p)}"
elif type_arg.values:
bound_or_values = f": ({', '.join(v.accept(p) for v in type_arg.values)})"
else:
bound_or_values = ""
if type_arg.default:
default = f" = {type_arg.default.accept(p)}"
else:
default = ""
type_args_list.append(f"{prefix}{type_arg.name}{bound_or_values}{default}")
return "[" + ", ".join(type_args_list) + "]"

def print_annotation(
self,
t: Type,
Expand Down
76 changes: 76 additions & 0 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -4415,3 +4415,79 @@ class Test(Whatever, a=1, b='b', c=True, d=1.5, e=None, f=1j, g=b'123'): ...
class Test(Whatever, keyword=SomeName * 2, attr=SomeName.attr): ...
[out]
class Test(Whatever, keyword=SomeName * 2, attr=SomeName.attr): ...

[case testPEP695GenericClass]
# flags: --python-version=3.12

class C[T]: ...
class C1[T1](int): ...
class C2[T2: int]: ...
class C3[T3: str | bytes]: ...
class C4[T4: (str, bytes)]: ...

[out]
class C[T]: ...
class C1[T1](int): ...
class C2[T2: int]: ...
class C3[T3: str | bytes]: ...
class C4[T4: (str, bytes)]: ...

[case testPEP695GenericFunction]
# flags: --python-version=3.12

def f1[T1](): ...
def f2[T2: int](): ...
def f3[T3: str | bytes](): ...
def f4[T4: (str, bytes)](): ...

[out]
def f1[T1]() -> None: ...
def f2[T2: int]() -> None: ...
def f3[T3: str | bytes]() -> None: ...
def f4[T4: (str, bytes)]() -> None: ...

[case testPEP695TypeAlias]
# flags: --python-version=3.12

type Alias = int | str
type Alias1[T1] = list[T1] | set[T1]
type Alias2[T2: int] = list[T2] | set[T2]
type Alias3[T3: str | bytes] = list[T3] | set[T3]
type Alias4[T4: (str, bytes)] = list[T4] | set[T4]

[out]
type Alias = int | str
type Alias1[T1] = list[T1] | set[T1]
type Alias2[T2: int] = list[T2] | set[T2]
type Alias3[T3: str | bytes] = list[T3] | set[T3]
type Alias4[T4: (str, bytes)] = list[T4] | set[T4]

[case testPEP695Syntax_semanal]
# flags: --python-version=3.12

class C[T]: ...
def f[S](): ...
type A[R] = list[R]

[out]
class C[T]: ...

def f[S]() -> None: ...
type A[R] = list[R]

[case testPEP696Syntax]
# flags: --python-version=3.13

type Alias1[T1 = int] = list[T1] | set[T1]
type Alias2[T2: int | float = int] = list[T2] | set[T2]
class C3[T3 = int]: ...
class C4[T4: int | float = int](list[T4]): ...
def f5[T5 = int](): ...

[out]
type Alias1[T1 = int] = list[T1] | set[T1]
type Alias2[T2: int | float = int] = list[T2] | set[T2]
class C3[T3 = int]: ...
class C4[T4: int | float = int](list[T4]): ...

def f5[T5 = int]() -> None: ...
Loading