Skip to content

Commit f51b699

Browse files
authored
stubtest: handle overloads with mixed pos-only params (#18287)
Fixes #17023 Stubtest should only mangle positional-only parameter names if they're positional-only in all branches of the overload. The signatures get really ugly and wrong otherwise. I'm not sure if I did the new `test_overload_signature` in the best way. I couldn't figure out a way to get covert a string into a `nodes.OverloadedFuncDef` object with any of the techniques in existing tests in `teststubtest.py`. Maybe the new test case is sufficient, but I wanted to test the signature generation directly.
1 parent ac4cacb commit f51b699

File tree

2 files changed

+65
-8
lines changed

2 files changed

+65
-8
lines changed

mypy/stubtest.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -954,22 +954,36 @@ def from_overloadedfuncdef(stub: nodes.OverloadedFuncDef) -> Signature[nodes.Arg
954954
# For most dunder methods, just assume all args are positional-only
955955
assume_positional_only = is_dunder(stub.name, exclude_special=True)
956956

957-
all_args: dict[str, list[tuple[nodes.Argument, int]]] = {}
957+
is_arg_pos_only: defaultdict[str, set[bool]] = defaultdict(set)
958958
for func in map(_resolve_funcitem_from_decorator, stub.items):
959959
assert func is not None, "Failed to resolve decorated overload"
960960
args = maybe_strip_cls(stub.name, func.arguments)
961961
for index, arg in enumerate(args):
962-
# For positional-only args, we allow overloads to have different names for the same
963-
# argument. To accomplish this, we just make up a fake index-based name.
964-
name = (
965-
f"__{index}"
966-
if arg.variable.name.startswith("__")
962+
if (
963+
arg.variable.name.startswith("__")
967964
or arg.pos_only
968965
or assume_positional_only
969966
or arg.variable.name.strip("_") == "self"
970967
or (index == 0 and arg.variable.name.strip("_") == "cls")
971-
else arg.variable.name
972-
)
968+
):
969+
is_arg_pos_only[arg.variable.name].add(True)
970+
else:
971+
is_arg_pos_only[arg.variable.name].add(False)
972+
973+
all_args: dict[str, list[tuple[nodes.Argument, int]]] = {}
974+
for func in map(_resolve_funcitem_from_decorator, stub.items):
975+
assert func is not None, "Failed to resolve decorated overload"
976+
args = maybe_strip_cls(stub.name, func.arguments)
977+
for index, arg in enumerate(args):
978+
# For positional-only args, we allow overloads to have different names for the same
979+
# argument. To accomplish this, we just make up a fake index-based name.
980+
# We can only use the index-based name if the argument is always
981+
# positional only. Sometimes overloads have an arg as positional-only
982+
# in some but not all branches of the overload.
983+
name = arg.variable.name
984+
if is_arg_pos_only[name] == {True}:
985+
name = f"__{index}"
986+
973987
all_args.setdefault(name, []).append((arg, index))
974988

975989
def get_position(arg_name: str) -> int:

mypy/test/teststubtest.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@
1313
from typing import Any, Callable
1414

1515
import mypy.stubtest
16+
from mypy import build, nodes
17+
from mypy.modulefinder import BuildSource
18+
from mypy.options import Options
1619
from mypy.stubtest import parse_options, test_stubs
20+
from mypy.test.config import test_temp_dir
1721
from mypy.test.data import root_dir
1822

1923

@@ -158,6 +162,14 @@ def __invert__(self: _T) -> _T: pass
158162
"""
159163

160164

165+
def build_helper(source: str) -> build.BuildResult:
166+
return build.build(
167+
sources=[BuildSource("main.pyi", None, textwrap.dedent(source))],
168+
options=Options(),
169+
alt_lib_path=test_temp_dir,
170+
)
171+
172+
161173
def run_stubtest_with_stderr(
162174
stub: str, runtime: str, options: list[str], config_file: str | None = None
163175
) -> tuple[str, str]:
@@ -842,6 +854,18 @@ def f2(self, *a) -> int: ...
842854
""",
843855
error=None,
844856
)
857+
yield Case(
858+
stub="""
859+
@overload
860+
def f(a: int) -> int: ...
861+
@overload
862+
def f(a: int, b: str, /) -> str: ...
863+
""",
864+
runtime="""
865+
def f(a, *args): ...
866+
""",
867+
error=None,
868+
)
845869

846870
@collect_cases
847871
def test_property(self) -> Iterator[Case]:
@@ -2790,6 +2814,25 @@ def test_builtin_signature_with_unrepresentable_default(self) -> None:
27902814
== "def (self, sep = ..., bytes_per_sep = ...)"
27912815
)
27922816

2817+
def test_overload_signature(self) -> None:
2818+
# The same argument as both positional-only and pos-or-kw in
2819+
# different overloads previously produced incorrect signatures
2820+
source = """
2821+
from typing import overload
2822+
@overload
2823+
def myfunction(arg: int) -> None: ...
2824+
@overload
2825+
def myfunction(arg: str, /) -> None: ...
2826+
"""
2827+
result = build_helper(source)
2828+
stub = result.files["__main__"].names["myfunction"].node
2829+
assert isinstance(stub, nodes.OverloadedFuncDef)
2830+
sig = mypy.stubtest.Signature.from_overloadedfuncdef(stub)
2831+
if sys.version_info >= (3, 10):
2832+
assert str(sig) == "def (arg: builtins.int | builtins.str)"
2833+
else:
2834+
assert str(sig) == "def (arg: Union[builtins.int, builtins.str])"
2835+
27932836
def test_config_file(self) -> None:
27942837
runtime = "temp = 5\n"
27952838
stub = "from decimal import Decimal\ntemp: Decimal\n"

0 commit comments

Comments
 (0)