Skip to content

Commit fe0aa92

Browse files
committed
stubtest: get better signatures for __init__ of C classes
When an __init__ method has the generic C-class signature, check the underlying class for a better signature.
1 parent 82de0d8 commit fe0aa92

File tree

1 file changed

+197
-0
lines changed

1 file changed

+197
-0
lines changed

mypy/stubtest.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77
from __future__ import annotations
88

99
import argparse
10+
import ast
1011
import collections.abc
1112
import copy
1213
import enum
1314
import functools
1415
import importlib
1516
import importlib.machinery
1617
import inspect
18+
import itertools
1719
import os
1820
import pkgutil
1921
import re
@@ -1526,7 +1528,202 @@ def is_read_only_property(runtime: object) -> bool:
15261528
return isinstance(runtime, property) and runtime.fset is None
15271529

15281530

1531+
def _signature_fromstr(
1532+
cls: type[inspect.Signature], obj: Any, s: str, skip_bound_arg: bool = True
1533+
) -> inspect.Signature:
1534+
"""Private helper to parse content of '__text_signature__'
1535+
and return a Signature based on it.
1536+
1537+
This is a copy of inspect._signature_fromstr from 3.13, which we need
1538+
for python/cpython#115270, an important fix for working with
1539+
built-in instance methods.
1540+
"""
1541+
Parameter = cls._parameter_cls # type: ignore[attr-defined]
1542+
1543+
if sys.version_info >= (3, 12):
1544+
clean_signature, self_parameter = inspect._signature_strip_non_python_syntax(s) # type: ignore[attr-defined]
1545+
else:
1546+
clean_signature, self_parameter, last_positional_only = inspect._signature_strip_non_python_syntax(s) # type: ignore[attr-defined]
1547+
1548+
program = "def foo" + clean_signature + ": pass"
1549+
1550+
try:
1551+
module_ast = ast.parse(program)
1552+
except SyntaxError:
1553+
module_ast = None
1554+
1555+
if not isinstance(module_ast, ast.Module):
1556+
raise ValueError("{!r} builtin has invalid signature".format(obj))
1557+
1558+
f = module_ast.body[0]
1559+
assert isinstance(f, ast.FunctionDef)
1560+
1561+
parameters = []
1562+
empty = Parameter.empty
1563+
1564+
module = None
1565+
module_dict: dict[str, Any] = {}
1566+
1567+
module_name = getattr(obj, "__module__", None)
1568+
if not module_name:
1569+
objclass = getattr(obj, "__objclass__", None)
1570+
module_name = getattr(objclass, "__module__", None)
1571+
1572+
if module_name:
1573+
module = sys.modules.get(module_name, None)
1574+
if module:
1575+
module_dict = module.__dict__
1576+
sys_module_dict = sys.modules.copy()
1577+
1578+
def parse_name(node: ast.arg) -> str:
1579+
assert isinstance(node, ast.arg)
1580+
if node.annotation is not None:
1581+
raise ValueError("Annotations are not currently supported")
1582+
return node.arg
1583+
1584+
def wrap_value(s: str) -> ast.Constant:
1585+
try:
1586+
value = eval(s, module_dict)
1587+
except NameError:
1588+
try:
1589+
value = eval(s, sys_module_dict)
1590+
except NameError:
1591+
raise ValueError
1592+
1593+
if isinstance(value, (str, int, float, bytes, bool, type(None))):
1594+
return ast.Constant(value)
1595+
raise ValueError
1596+
1597+
class RewriteSymbolics(ast.NodeTransformer):
1598+
def visit_Attribute(self, node: ast.Attribute) -> Any:
1599+
a = []
1600+
n: ast.expr = node
1601+
while isinstance(n, ast.Attribute):
1602+
a.append(n.attr)
1603+
n = n.value
1604+
if not isinstance(n, ast.Name):
1605+
raise ValueError
1606+
a.append(n.id)
1607+
value = ".".join(reversed(a))
1608+
return wrap_value(value)
1609+
1610+
def visit_Name(self, node: ast.Name) -> Any:
1611+
if not isinstance(node.ctx, ast.Load):
1612+
raise ValueError()
1613+
return wrap_value(node.id)
1614+
1615+
def visit_BinOp(self, node: ast.BinOp) -> Any:
1616+
# Support constant folding of a couple simple binary operations
1617+
# commonly used to define default values in text signatures
1618+
left = self.visit(node.left)
1619+
right = self.visit(node.right)
1620+
if not isinstance(left, ast.Constant) or not isinstance(right, ast.Constant):
1621+
raise ValueError
1622+
if isinstance(node.op, ast.Add):
1623+
return ast.Constant(left.value + right.value)
1624+
elif isinstance(node.op, ast.Sub):
1625+
return ast.Constant(left.value - right.value)
1626+
elif isinstance(node.op, ast.BitOr):
1627+
return ast.Constant(left.value | right.value)
1628+
raise ValueError
1629+
1630+
def p(name_node: ast.arg, default_node: Any, default: Any = empty) -> None:
1631+
name = parse_name(name_node)
1632+
if default_node and default_node is not inspect._empty:
1633+
try:
1634+
default_node = RewriteSymbolics().visit(default_node)
1635+
default = ast.literal_eval(default_node)
1636+
except ValueError:
1637+
raise ValueError("{!r} builtin has invalid signature".format(obj)) from None
1638+
parameters.append(Parameter(name, kind, default=default, annotation=empty))
1639+
1640+
# non-keyword-only parameters
1641+
if sys.version_info >= (3, 12):
1642+
total_non_kw_args = len(f.args.posonlyargs) + len(f.args.args)
1643+
required_non_kw_args = total_non_kw_args - len(f.args.defaults)
1644+
defaults = itertools.chain(itertools.repeat(None, required_non_kw_args), f.args.defaults)
1645+
1646+
kind = Parameter.POSITIONAL_ONLY
1647+
for name, default in zip(f.args.posonlyargs, defaults):
1648+
p(name, default)
1649+
1650+
kind = Parameter.POSITIONAL_OR_KEYWORD
1651+
for name, default in zip(f.args.args, defaults):
1652+
p(name, default)
1653+
1654+
else:
1655+
args = reversed(f.args.args)
1656+
defaults = reversed(f.args.defaults)
1657+
iter = itertools.zip_longest(args, defaults, fillvalue=None)
1658+
if last_positional_only is not None:
1659+
kind = Parameter.POSITIONAL_ONLY
1660+
else:
1661+
kind = Parameter.POSITIONAL_OR_KEYWORD
1662+
for i, (name, default) in enumerate(reversed(list(iter))):
1663+
p(name, default)
1664+
if i == last_positional_only:
1665+
kind = Parameter.POSITIONAL_OR_KEYWORD
1666+
1667+
# *args
1668+
if f.args.vararg:
1669+
kind = Parameter.VAR_POSITIONAL
1670+
p(f.args.vararg, empty)
1671+
1672+
# keyword-only arguments
1673+
kind = Parameter.KEYWORD_ONLY
1674+
for name, default in zip(f.args.kwonlyargs, f.args.kw_defaults):
1675+
p(name, default)
1676+
1677+
# **kwargs
1678+
if f.args.kwarg:
1679+
kind = Parameter.VAR_KEYWORD
1680+
p(f.args.kwarg, empty)
1681+
1682+
if self_parameter is not None:
1683+
# Possibly strip the bound argument:
1684+
# - We *always* strip first bound argument if
1685+
# it is a module.
1686+
# - We don't strip first bound argument if
1687+
# skip_bound_arg is False.
1688+
assert parameters
1689+
_self = getattr(obj, "__self__", None)
1690+
self_isbound = _self is not None
1691+
self_ismodule = inspect.ismodule(_self)
1692+
if self_isbound and (self_ismodule or skip_bound_arg):
1693+
parameters.pop(0)
1694+
else:
1695+
# for builtins, self parameter is always positional-only!
1696+
p = parameters[0].replace(kind=Parameter.POSITIONAL_ONLY)
1697+
parameters[0] = p
1698+
1699+
return cls(parameters, return_annotation=cls.empty)
1700+
1701+
15291702
def safe_inspect_signature(runtime: Any) -> inspect.Signature | None:
1703+
if (
1704+
hasattr(runtime, "__name__")
1705+
and runtime.__name__ == "__init__"
1706+
and hasattr(runtime, "__text_signature__")
1707+
and runtime.__text_signature__ == "($self, /, *args, **kwargs)"
1708+
and hasattr(runtime, "__objclass__")
1709+
and runtime.__objclass__ is not object
1710+
and hasattr(runtime.__objclass__, "__text_signature__")
1711+
and runtime.__objclass__.__text_signature__ is not None
1712+
):
1713+
# This is an __init__ method with the generic C-class signature.
1714+
# In this case, the underlying class usually has a better signature,
1715+
# which we can convert into an __init__ signature by adding $self
1716+
# at the start. If we hit an error, failover to the normal
1717+
# path without trying to recover.
1718+
if "/" in runtime.__objclass__.__text_signature__:
1719+
new_sig = f"($self, {runtime.__objclass__.__text_signature__[1:]}"
1720+
else:
1721+
new_sig = f"($self, /, {runtime.__objclass__.__text_signature__[1:]}"
1722+
try:
1723+
return _signature_fromstr(inspect.Signature, runtime, new_sig)
1724+
except Exception:
1725+
pass
1726+
15301727
try:
15311728
try:
15321729
return inspect.signature(runtime)

0 commit comments

Comments
 (0)