1313from typing import Any , Callable
1414
1515import mypy .stubtest
16+ from mypy import build , nodes
17+ from mypy .modulefinder import BuildSource
18+ from mypy .options import Options
1619from mypy .stubtest import parse_options , test_stubs
20+ from mypy .test .config import test_temp_dir
1721from mypy .test .data import root_dir
1822
1923
@@ -158,6 +162,14 @@ def __invert__(self: _T) -> _T: pass
158162"""
159163
160164
165+ def build_helper (source : str ) -> build .BuildResult :
166+ return build .build (
167+ sources = [BuildSource ("main.pyi" , None , textwrap .dedent (source ))],
168+ options = Options (),
169+ alt_lib_path = test_temp_dir ,
170+ )
171+
172+
161173def run_stubtest_with_stderr (
162174 stub : str , runtime : str , options : list [str ], config_file : str | None = None
163175) -> tuple [str , str ]:
@@ -842,6 +854,18 @@ def f2(self, *a) -> int: ...
842854 """ ,
843855 error = None ,
844856 )
857+ yield Case (
858+ stub = """
859+ @overload
860+ def f(a: int) -> int: ...
861+ @overload
862+ def f(a: int, b: str, /) -> str: ...
863+ """ ,
864+ runtime = """
865+ def f(a, *args): ...
866+ """ ,
867+ error = None ,
868+ )
845869
846870 @collect_cases
847871 def test_property (self ) -> Iterator [Case ]:
@@ -1407,14 +1431,9 @@ def spam(x=Flags4(0)): pass
14071431 stub = """
14081432 import sys
14091433 from typing import Final, Literal
1410- from typing_extensions import disjoint_base
1411- if sys.version_info >= (3, 12):
1412- class BytesEnum(bytes, enum.Enum):
1413- a = b'foo'
1414- else:
1415- @disjoint_base
1416- class BytesEnum(bytes, enum.Enum):
1417- a = b'foo'
1434+ class BytesEnum(bytes, enum.Enum):
1435+ a = b'foo'
1436+
14181437 FOO: Literal[BytesEnum.a]
14191438 BAR: Final = BytesEnum.a
14201439 BAZ: BytesEnum
@@ -1430,6 +1449,31 @@ class BytesEnum(bytes, enum.Enum):
14301449 """ ,
14311450 error = None ,
14321451 )
1452+ yield Case (
1453+ stub = """
1454+ class HasSlotsAndNothingElse:
1455+ __slots__ = ("x",)
1456+ x: int
1457+
1458+ class HasInheritedSlots(HasSlotsAndNothingElse):
1459+ pass
1460+
1461+ class HasEmptySlots:
1462+ __slots__ = ()
1463+ """ ,
1464+ runtime = """
1465+ class HasSlotsAndNothingElse:
1466+ __slots__ = ("x",)
1467+ x: int
1468+
1469+ class HasInheritedSlots(HasSlotsAndNothingElse):
1470+ pass
1471+
1472+ class HasEmptySlots:
1473+ __slots__ = ()
1474+ """ ,
1475+ error = None ,
1476+ )
14331477
14341478 @collect_cases
14351479 def test_decorator (self ) -> Iterator [Case ]:
@@ -1673,6 +1717,53 @@ def __next__(self) -> object: ...
16731717 """ ,
16741718 error = None ,
16751719 )
1720+ yield Case (
1721+ runtime = """
1722+ class IsDisjointBaseBecauseItHasSlots:
1723+ __slots__ = ("a",)
1724+ a: int
1725+ """ ,
1726+ stub = """
1727+ from typing_extensions import disjoint_base
1728+
1729+ @disjoint_base
1730+ class IsDisjointBaseBecauseItHasSlots:
1731+ a: int
1732+ """ ,
1733+ error = "test_module.IsDisjointBaseBecauseItHasSlots" ,
1734+ )
1735+ yield Case (
1736+ runtime = """
1737+ class IsFinalSoDisjointBaseIsRedundant: ...
1738+ """ ,
1739+ stub = """
1740+ from typing_extensions import disjoint_base, final
1741+
1742+ @final
1743+ @disjoint_base
1744+ class IsFinalSoDisjointBaseIsRedundant: ...
1745+ """ ,
1746+ error = "test_module.IsFinalSoDisjointBaseIsRedundant" ,
1747+ )
1748+ yield Case (
1749+ runtime = """
1750+ import enum
1751+
1752+ class IsEnumWithMembersSoDisjointBaseIsRedundant(enum.Enum):
1753+ A = 1
1754+ B = 2
1755+ """ ,
1756+ stub = """
1757+ from typing_extensions import disjoint_base
1758+ import enum
1759+
1760+ @disjoint_base
1761+ class IsEnumWithMembersSoDisjointBaseIsRedundant(enum.Enum):
1762+ A = 1
1763+ B = 2
1764+ """ ,
1765+ error = "test_module.IsEnumWithMembersSoDisjointBaseIsRedundant" ,
1766+ )
16761767
16771768 @collect_cases
16781769 def test_has_runtime_final_decorator (self ) -> Iterator [Case ]:
@@ -2545,6 +2636,31 @@ class _X1: ...
25452636 error = None ,
25462637 )
25472638
2639+ @collect_cases
2640+ def test_type_default_protocol (self ) -> Iterator [Case ]:
2641+ yield Case (
2642+ stub = """
2643+ from typing import Protocol
2644+
2645+ class _FormatterClass(Protocol):
2646+ def __call__(self, *, prog: str) -> HelpFormatter: ...
2647+
2648+ class ArgumentParser:
2649+ def __init__(self, formatter_class: _FormatterClass = ...) -> None: ...
2650+
2651+ class HelpFormatter:
2652+ def __init__(self, prog: str, indent_increment: int = 2) -> None: ...
2653+ """ ,
2654+ runtime = """
2655+ class HelpFormatter:
2656+ def __init__(self, prog, indent_increment=2) -> None: ...
2657+
2658+ class ArgumentParser:
2659+ def __init__(self, formatter_class=HelpFormatter): ...
2660+ """ ,
2661+ error = None ,
2662+ )
2663+
25482664
25492665def remove_color_code (s : str ) -> str :
25502666 return re .sub ("\\ x1b.*?m" , "" , s ) # this works!
@@ -2558,8 +2674,8 @@ def test_output(self) -> None:
25582674 options = [],
25592675 )
25602676 expected = (
2561- f'error: { TEST_MODULE_NAME } .bad is inconsistent, stub argument "number" differs '
2562- 'from runtime argument "num"\n '
2677+ f'error: { TEST_MODULE_NAME } .bad is inconsistent, stub parameter "number" differs '
2678+ 'from runtime parameter "num"\n '
25632679 f"Stub: in file { TEST_MODULE_NAME } .pyi:1\n "
25642680 "def (number: builtins.int, text: builtins.str)\n "
25652681 f"Runtime: in file { TEST_MODULE_NAME } .py:1\n def (num, text)\n \n "
@@ -2574,7 +2690,9 @@ def test_output(self) -> None:
25742690 )
25752691 expected = (
25762692 "{}.bad is inconsistent, "
2577- 'stub argument "number" differs from runtime argument "num"\n ' .format (TEST_MODULE_NAME )
2693+ 'stub parameter "number" differs from runtime parameter "num"\n ' .format (
2694+ TEST_MODULE_NAME
2695+ )
25782696 )
25792697 assert output == expected
25802698
@@ -2721,6 +2839,25 @@ def test_builtin_signature_with_unrepresentable_default(self) -> None:
27212839 == "def (self, sep = ..., bytes_per_sep = ...)"
27222840 )
27232841
2842+ def test_overload_signature (self ) -> None :
2843+ # The same argument as both positional-only and pos-or-kw in
2844+ # different overloads previously produced incorrect signatures
2845+ source = """
2846+ from typing import overload
2847+ @overload
2848+ def myfunction(arg: int) -> None: ...
2849+ @overload
2850+ def myfunction(arg: str, /) -> None: ...
2851+ """
2852+ result = build_helper (source )
2853+ stub = result .files ["__main__" ].names ["myfunction" ].node
2854+ assert isinstance (stub , nodes .OverloadedFuncDef )
2855+ sig = mypy .stubtest .Signature .from_overloadedfuncdef (stub )
2856+ if sys .version_info >= (3 , 10 ):
2857+ assert str (sig ) == "def (arg: builtins.int | builtins.str)"
2858+ else :
2859+ assert str (sig ) == "def (arg: Union[builtins.int, builtins.str])"
2860+
27242861 def test_config_file (self ) -> None :
27252862 runtime = "temp = 5\n "
27262863 stub = "from decimal import Decimal\n temp: Decimal\n "
0 commit comments