Skip to content

Commit 45bab8c

Browse files
committed
When selecting an overload item for constraint template matching, treat any ParamSpec in the template as free
1 parent a0665e1 commit 45bab8c

File tree

3 files changed

+32
-3
lines changed

3 files changed

+32
-3
lines changed

mypy/constraints.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,6 +1434,7 @@ def find_matching_overload_item(overloaded: Overloaded, template: CallableType)
14341434
is_compat=mypy.subtypes.is_subtype,
14351435
is_proper_subtype=False,
14361436
ignore_return=True,
1437+
map_template_paramspec=True,
14371438
):
14381439
return item
14391440
# Fall back to the first item if we can't find a match. This is totally arbitrary --

mypy/subtypes.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1576,6 +1576,7 @@ def is_callable_compatible(
15761576
check_args_covariantly: bool = False,
15771577
allow_partial_overlap: bool = False,
15781578
strict_concatenate: bool = False,
1579+
map_template_paramspec: bool = False,
15791580
) -> bool:
15801581
"""Is the left compatible with the right, using the provided compatibility check?
15811582
@@ -1717,6 +1718,7 @@ def g(x: int) -> int: ...
17171718
ignore_pos_arg_names=ignore_pos_arg_names,
17181719
allow_partial_overlap=allow_partial_overlap,
17191720
strict_concatenate_check=strict_concatenate_check,
1721+
map_template_paramspec=map_template_paramspec,
17201722
)
17211723

17221724

@@ -1753,6 +1755,7 @@ def are_parameters_compatible(
17531755
ignore_pos_arg_names: bool = False,
17541756
allow_partial_overlap: bool = False,
17551757
strict_concatenate_check: bool = False,
1758+
map_template_paramspec: bool = False,
17561759
) -> bool:
17571760
"""Helper function for is_callable_compatible, used for Parameter compatibility"""
17581761
if right.is_ellipsis_args and not is_proper_subtype:
@@ -1781,6 +1784,8 @@ def are_parameters_compatible(
17811784
# a subtype of erased template type.
17821785
trivial_vararg_suffix = True
17831786

1787+
right_is_pspec = map_template_paramspec and right.param_spec() is not None
1788+
17841789
# Match up corresponding arguments and check them for compatibility. In
17851790
# every pair (argL, argR) of corresponding arguments from L and R, argL must
17861791
# be "more general" than argR if L is to be a subtype of R.
@@ -1817,7 +1822,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
18171822
_incompatible(left_star, right_star)
18181823
and not trivial_vararg_suffix
18191824
or _incompatible(left_star2, right_star2)
1820-
):
1825+
) and not right_is_pspec:
18211826
return False
18221827

18231828
# Phase 1b: Check non-star args: for every arg right can accept, left must
@@ -1848,7 +1853,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
18481853
# arguments. Get all further positional args of left, and make sure
18491854
# they're more general than the corresponding member in right.
18501855
# TODO: handle suffix in UnpackType (i.e. *args: *Tuple[Ts, X, Y]).
1851-
if right_star is not None and not trivial_vararg_suffix:
1856+
if right_star is not None and not trivial_vararg_suffix and not right_is_pspec:
18521857
# Synthesize an anonymous formal argument for the right
18531858
right_by_position = right.try_synthesizing_arg_from_vararg(None)
18541859
assert right_by_position is not None
@@ -1875,7 +1880,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
18751880
# Phase 1d: Check kw args. Right has an infinite series of optional named
18761881
# arguments. Get all further named args of left, and make sure
18771882
# they're more general than the corresponding member in right.
1878-
if right_star2 is not None:
1883+
if right_star2 is not None and not right_is_pspec:
18791884
right_names = {name for name in right.arg_names if name is not None}
18801885
left_only_names = set()
18811886
for name, kind in zip(left.arg_names, left.arg_kinds):

test-data/unit/check-parameter-specification.test

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2603,3 +2603,26 @@ def run3(predicate: Callable[Concatenate[int, str, _P], None], *args: _P.args, *
26032603
# E: Argument 1 has incompatible type "*tuple[Union[int, str], ...]"; expected "str" \
26042604
# E: Argument 1 has incompatible type "*tuple[Union[int, str], ...]"; expected "_P.args"
26052605
[builtins fixtures/paramspec.pyi]
2606+
2607+
[case testParamSpecOverloadProtocol]
2608+
from typing import ParamSpec, Protocol, TypeVar, overload
2609+
2610+
_A_contra = TypeVar("_A_contra", bound=str, contravariant=True)
2611+
_P = ParamSpec("_P")
2612+
2613+
class Callback(Protocol[_A_contra, _P]):
2614+
def method(self, a: _A_contra, *args: _P.args, **kwargs: _P.kwargs) -> None: ...
2615+
2616+
class Impl:
2617+
@overload
2618+
def method(self, a: int, b: str) -> None: ...
2619+
@overload
2620+
def method(self, a: str, b: int) -> None: ...
2621+
def method(self, a, b) -> None: ...
2622+
2623+
def accepts_callback(cb: Callback[str, _P], *args: _P.args, **kwargs: _P.kwargs) -> int:
2624+
return 1
2625+
2626+
a = accepts_callback(Impl(), 1)
2627+
reveal_type(a) # N: Revealed type is "builtins.int"
2628+
[builtins fixtures/paramspec.pyi]

0 commit comments

Comments
 (0)