Skip to content

Commit 3b2297f

Browse files
committed
Reuse params preprocessing logic for generic functions
1 parent 5cdb753 commit 3b2297f

File tree

2 files changed

+109
-29
lines changed

2 files changed

+109
-29
lines changed

mypy/checkexpr.py

Lines changed: 47 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1716,33 +1716,9 @@ def check_callable_call(
17161716
callee = callee.copy_modified(ret_type=fresh_ret_type)
17171717

17181718
if callee.is_generic():
1719-
need_refresh = any(
1720-
isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables
1719+
callee, formal_to_actual = self.adjust_generic_callable_params_mapping(
1720+
callee, args, arg_kinds, arg_names, formal_to_actual, context
17211721
)
1722-
callee = freshen_function_type_vars(callee)
1723-
callee = self.infer_function_type_arguments_using_context(callee, context)
1724-
if need_refresh:
1725-
# Argument kinds etc. may have changed due to
1726-
# ParamSpec or TypeVarTuple variables being replaced with an arbitrary
1727-
# number of arguments; recalculate actual-to-formal map
1728-
formal_to_actual = map_actuals_to_formals(
1729-
arg_kinds,
1730-
arg_names,
1731-
callee.arg_kinds,
1732-
callee.arg_names,
1733-
lambda i: self.accept(args[i]),
1734-
)
1735-
callee = self.infer_function_type_arguments(
1736-
callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context
1737-
)
1738-
if need_refresh:
1739-
formal_to_actual = map_actuals_to_formals(
1740-
arg_kinds,
1741-
arg_names,
1742-
callee.arg_kinds,
1743-
callee.arg_names,
1744-
lambda i: self.accept(args[i]),
1745-
)
17461722

17471723
param_spec = callee.param_spec()
17481724
if (
@@ -2633,7 +2609,7 @@ def check_overload_call(
26332609
arg_types = self.infer_arg_types_in_empty_context(args)
26342610
# Step 1: Filter call targets to remove ones where the argument counts don't match
26352611
plausible_targets = self.plausible_overload_call_targets(
2636-
arg_types, arg_kinds, arg_names, callee
2612+
args, arg_types, arg_kinds, arg_names, callee, context
26372613
)
26382614

26392615
# Step 2: If the arguments contain a union, we try performing union math first,
@@ -2751,12 +2727,52 @@ def check_overload_call(
27512727
self.chk.fail(message_registry.TOO_MANY_UNION_COMBINATIONS, context)
27522728
return result
27532729

2730+
def adjust_generic_callable_params_mapping(
2731+
self,
2732+
callee: CallableType,
2733+
args: list[Expression],
2734+
arg_kinds: list[ArgKind],
2735+
arg_names: Sequence[str | None] | None,
2736+
formal_to_actual: list[list[int]],
2737+
context: Context,
2738+
) -> tuple[CallableType, list[list[int]]]:
2739+
need_refresh = any(
2740+
isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables
2741+
)
2742+
callee = freshen_function_type_vars(callee)
2743+
callee = self.infer_function_type_arguments_using_context(callee, context)
2744+
if need_refresh:
2745+
# Argument kinds etc. may have changed due to
2746+
# ParamSpec or TypeVarTuple variables being replaced with an arbitrary
2747+
# number of arguments; recalculate actual-to-formal map
2748+
formal_to_actual = map_actuals_to_formals(
2749+
arg_kinds,
2750+
arg_names,
2751+
callee.arg_kinds,
2752+
callee.arg_names,
2753+
lambda i: self.accept(args[i]),
2754+
)
2755+
callee = self.infer_function_type_arguments(
2756+
callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context
2757+
)
2758+
if need_refresh:
2759+
formal_to_actual = map_actuals_to_formals(
2760+
arg_kinds,
2761+
arg_names,
2762+
callee.arg_kinds,
2763+
callee.arg_names,
2764+
lambda i: self.accept(args[i]),
2765+
)
2766+
return callee, formal_to_actual
2767+
27542768
def plausible_overload_call_targets(
27552769
self,
2770+
args: list[Expression],
27562771
arg_types: list[Type],
27572772
arg_kinds: list[ArgKind],
27582773
arg_names: Sequence[str | None] | None,
27592774
overload: Overloaded,
2775+
context: Context,
27602776
) -> list[CallableType]:
27612777
"""Returns all overload call targets that having matching argument counts.
27622778
@@ -2790,6 +2806,10 @@ def has_shape(typ: Type) -> bool:
27902806
formal_to_actual = map_actuals_to_formals(
27912807
arg_kinds, arg_names, typ.arg_kinds, typ.arg_names, lambda i: arg_types[i]
27922808
)
2809+
if typ.is_generic():
2810+
typ, formal_to_actual = self.adjust_generic_callable_params_mapping(
2811+
typ, args, arg_kinds, arg_names, formal_to_actual, context
2812+
)
27932813

27942814
with self.msg.filter_errors():
27952815
if self.check_argument_count(

test-data/unit/check-parameter-specification.test

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2211,12 +2211,22 @@ from typing import Callable
22112211

22122212
_P = ParamSpec("_P")
22132213

2214-
def run(predicate: Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs) -> None:
2214+
def run(predicate: Callable[_P, None], *args: _P.args, **kwargs: _P.kwargs) -> None: # N: "run" defined here
22152215
predicate() # E: Too few arguments
22162216
predicate(*args) # E: Too few arguments
22172217
predicate(**kwargs) # E: Too few arguments
22182218
predicate(*args, **kwargs)
22192219

2220+
def fn() -> None: ...
2221+
def fn_args(x: int) -> None: ...
2222+
def fn_posonly(x: int, /) -> None: ...
2223+
2224+
run(fn)
2225+
run(fn_args, 1)
2226+
run(fn_args, x=1)
2227+
run(fn_posonly, 1)
2228+
run(fn_posonly, x=1) # E: Unexpected keyword argument "x" for "run"
2229+
22202230
[builtins fixtures/paramspec.pyi]
22212231

22222232
[case testRunParamSpecConcatenateInsufficientArgs]
@@ -2225,7 +2235,7 @@ from typing import Callable
22252235

22262236
_P = ParamSpec("_P")
22272237

2228-
def run(predicate: Callable[Concatenate[int, _P], str], *args: _P.args, **kwargs: _P.kwargs) -> None:
2238+
def run(predicate: Callable[Concatenate[int, _P], None], *args: _P.args, **kwargs: _P.kwargs) -> None: # N: "run" defined here
22292239
predicate() # E: Too few arguments
22302240
predicate(1) # E: Too few arguments
22312241
predicate(1, *args) # E: Too few arguments
@@ -2234,6 +2244,22 @@ def run(predicate: Callable[Concatenate[int, _P], str], *args: _P.args, **kwargs
22342244
predicate(*args, **kwargs) # E: Argument 1 has incompatible type "*_P.args"; expected "int"
22352245
predicate(1, *args, **kwargs)
22362246

2247+
def fn() -> None: ...
2248+
def fn_args(x: int, y: str) -> None: ...
2249+
def fn_posonly(x: int, /) -> None: ...
2250+
def fn_posonly_args(x: int, /, y: str) -> None: ...
2251+
2252+
run(fn) # E: Argument 1 to "run" has incompatible type "Callable[[], None]"; expected "Callable[[int], None]"
2253+
run(fn_args, 1, 'a') # E: Too many arguments for "run" \
2254+
# E: Argument 2 to "run" has incompatible type "int"; expected "str"
2255+
run(fn_args, y='a')
2256+
run(fn_args, 'a')
2257+
run(fn_posonly)
2258+
run(fn_posonly, x=1) # E: Unexpected keyword argument "x" for "run"
2259+
run(fn_posonly_args) # E: Missing positional argument "y" in call to "run"
2260+
run(fn_posonly_args, 'a')
2261+
run(fn_posonly_args, y='a')
2262+
22372263
[builtins fixtures/paramspec.pyi]
22382264

22392265
[case testRunParamSpecConcatenateInsufficientArgsInDecorator]
@@ -2255,3 +2281,37 @@ def decorator(fn: Callable[Concatenate[str, P], None]) -> Callable[P, None]:
22552281
def foo(s: str, s2: str) -> None: ...
22562282

22572283
[builtins fixtures/paramspec.pyi]
2284+
2285+
[case testRunParamSpecOverload]
2286+
from typing_extensions import ParamSpec, Concatenate
2287+
from typing import Callable, overload, NoReturn, TypeVar, Union
2288+
2289+
P = ParamSpec("P")
2290+
T = TypeVar("T")
2291+
2292+
@overload
2293+
def capture(
2294+
sync_fn: Callable[P, NoReturn],
2295+
*args: P.args,
2296+
**kwargs: P.kwargs,
2297+
) -> int: ...
2298+
@overload
2299+
def capture(
2300+
sync_fn: Callable[P, T],
2301+
*args: P.args,
2302+
**kwargs: P.kwargs,
2303+
) -> Union[T, int]: ...
2304+
def capture(
2305+
sync_fn: Callable[P, T],
2306+
*args: P.args,
2307+
**kwargs: P.kwargs,
2308+
) -> Union[T, int]:
2309+
return sync_fn(*args, **kwargs)
2310+
2311+
def fn() -> str: return ''
2312+
def err() -> NoReturn: ...
2313+
2314+
reveal_type(capture(fn)) # N: Revealed type is "Union[builtins.str, builtins.int]"
2315+
reveal_type(capture(err)) # N: Revealed type is "builtins.int"
2316+
2317+
[builtins fixtures/paramspec.pyi]

0 commit comments

Comments
 (0)