Skip to content

Commit acb3548

Browse files
committed
Support ParamSpec + functools.partial
1 parent 63995e3 commit acb3548

File tree

4 files changed

+106
-6
lines changed

4 files changed

+106
-6
lines changed

mypy/checkexpr.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2331,9 +2331,21 @@ def check_argument_count(
23312331
# Positional argument when expecting a keyword argument.
23322332
self.msg.too_many_positional_arguments(callee, context)
23332333
ok = False
2334-
elif callee.param_spec() is not None and not formal_to_actual[i]:
2335-
self.msg.too_few_arguments(callee, context, actual_names)
2336-
ok = False
2334+
elif callee.param_spec() is not None:
2335+
if (
2336+
not formal_to_actual[i]
2337+
and not callee.param_spec_parts_bound[kind == ArgKind.ARG_STAR2]
2338+
and callee.special_sig != "partial"
2339+
):
2340+
self.msg.too_few_arguments(callee, context, actual_names)
2341+
ok = False
2342+
elif (
2343+
formal_to_actual[i]
2344+
and kind == ArgKind.ARG_STAR
2345+
and callee.param_spec_parts_bound[0]
2346+
):
2347+
self.msg.too_many_arguments(callee, context)
2348+
ok = False
23372349
return ok
23382350

23392351
def check_for_extra_actual_arguments(

mypy/plugins/functools.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
140140
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
141141
)
142142
for k in fn_type.arg_kinds
143-
]
143+
],
144+
special_sig="partial",
144145
)
145146
if defaulted.line < 0:
146147
# Make up a line number if we don't have one
@@ -208,6 +209,10 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
208209
arg_kinds=partial_kinds,
209210
arg_names=partial_names,
210211
ret_type=ret_type,
212+
param_spec_parts_bound=(
213+
ArgKind.ARG_STAR in actual_arg_kinds,
214+
ArgKind.ARG_STAR2 in actual_arg_kinds,
215+
),
211216
)
212217

213218
ret = ctx.api.named_generic_type("functools.partial", [ret_type])

mypy/types.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1773,8 +1773,8 @@ class CallableType(FunctionLike):
17731773
"implicit", # Was this type implicitly generated instead of explicitly
17741774
# specified by the user?
17751775
"special_sig", # Non-None for signatures that require special handling
1776-
# (currently only value is 'dict' for a signature similar to
1777-
# 'dict')
1776+
# (currently only values are 'dict' for a signature similar to
1777+
# 'dict' and 'partial' for a `functools.partial` evaluation)
17781778
"from_type_type", # Was this callable generated by analyzing Type[...]
17791779
# instantiation?
17801780
"bound_args", # Bound type args, mostly unused but may be useful for
@@ -1787,6 +1787,7 @@ class CallableType(FunctionLike):
17871787
# (this is used for error messages)
17881788
"imprecise_arg_kinds",
17891789
"unpack_kwargs", # Was an Unpack[...] with **kwargs used to define this callable?
1790+
"param_spec_parts_bound", # Hack for functools.partial: allow early binding
17901791
)
17911792

17921793
def __init__(
@@ -1813,6 +1814,7 @@ def __init__(
18131814
from_concatenate: bool = False,
18141815
imprecise_arg_kinds: bool = False,
18151816
unpack_kwargs: bool = False,
1817+
param_spec_parts_bound: tuple[bool, bool] = (False, False),
18161818
) -> None:
18171819
super().__init__(line, column)
18181820
assert len(arg_types) == len(arg_kinds) == len(arg_names)
@@ -1839,6 +1841,7 @@ def __init__(
18391841
self.from_type_type = from_type_type
18401842
self.from_concatenate = from_concatenate
18411843
self.imprecise_arg_kinds = imprecise_arg_kinds
1844+
self.param_spec_parts_bound = param_spec_parts_bound
18421845
if not bound_args:
18431846
bound_args = ()
18441847
self.bound_args = bound_args
@@ -1885,6 +1888,7 @@ def copy_modified(
18851888
from_concatenate: Bogus[bool] = _dummy,
18861889
imprecise_arg_kinds: Bogus[bool] = _dummy,
18871890
unpack_kwargs: Bogus[bool] = _dummy,
1891+
param_spec_parts_bound: Bogus[tuple[bool, bool]] = _dummy,
18881892
) -> CT:
18891893
modified = CallableType(
18901894
arg_types=arg_types if arg_types is not _dummy else self.arg_types,
@@ -1916,6 +1920,11 @@ def copy_modified(
19161920
else self.imprecise_arg_kinds
19171921
),
19181922
unpack_kwargs=unpack_kwargs if unpack_kwargs is not _dummy else self.unpack_kwargs,
1923+
param_spec_parts_bound=(
1924+
param_spec_parts_bound
1925+
if param_spec_parts_bound is not _dummy
1926+
else self.param_spec_parts_bound
1927+
),
19191928
)
19201929
# Optimization: Only NewTypes are supported as subtypes since
19211930
# the class is effectively final, so we can use a cast safely.

test-data/unit/check-parameter-specification.test

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2315,3 +2315,77 @@ reveal_type(capture(fn)) # N: Revealed type is "Union[builtins.str, builtins.in
23152315
reveal_type(capture(err)) # N: Revealed type is "builtins.int"
23162316

23172317
[builtins fixtures/paramspec.pyi]
2318+
2319+
[case testBindPartial]
2320+
from functools import partial
2321+
from typing_extensions import ParamSpec
2322+
from typing import Callable, TypeVar
2323+
2324+
P = ParamSpec("P")
2325+
T = TypeVar("T")
2326+
2327+
def run(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2328+
func2 = partial(func, **kwargs)
2329+
return func2(*args)
2330+
2331+
def run2(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2332+
func2 = partial(func, *args)
2333+
return func2(**kwargs)
2334+
2335+
def run3(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2336+
func2 = partial(func, *args, **kwargs)
2337+
return func2()
2338+
2339+
def run4(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2340+
func2 = partial(func, *args, **kwargs)
2341+
return func2(**kwargs)
2342+
2343+
def run_bad(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2344+
func2 = partial(func, *args, **kwargs)
2345+
return func2(*args) # E: Too many arguments
2346+
2347+
def run_bad2(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2348+
func2 = partial(func, **kwargs)
2349+
return func2(**kwargs) # E: Too few arguments
2350+
2351+
def run_bad3(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2352+
func2 = partial(func, *args)
2353+
return func2() # E: Too few arguments
2354+
2355+
[builtins fixtures/paramspec.pyi]
2356+
2357+
[case testBindPartialConcatenate]
2358+
from functools import partial
2359+
from typing_extensions import Concatenate, ParamSpec
2360+
from typing import Callable, TypeVar
2361+
2362+
P = ParamSpec("P")
2363+
T = TypeVar("T")
2364+
2365+
def run(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2366+
func2 = partial(func, 1, **kwargs)
2367+
return func2(*args)
2368+
2369+
def run2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2370+
func2 = partial(func, **kwargs)
2371+
return func2(1, *args)
2372+
2373+
def run3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2374+
func2 = partial(func, 1, *args)
2375+
return func2(**kwargs)
2376+
2377+
def run_bad(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2378+
func2 = partial(func, *args) # E: Argument 1 has incompatible type "*P.args"; expected "int"
2379+
return func2(1, **kwargs) # E: Too many arguments
2380+
2381+
def run_bad2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2382+
func2 = partial(func, 1, *args)
2383+
return func2(1, **kwargs) # E: Too many arguments \
2384+
# E: Argument 1 has incompatible type "int"; expected "P.args"
2385+
2386+
def run_bad3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2387+
func2 = partial(func, 1, *args)
2388+
return func2(1, **kwargs) # E: Too many arguments \
2389+
# E: Argument 1 has incompatible type "int"; expected "P.args"
2390+
2391+
[builtins fixtures/paramspec.pyi]

0 commit comments

Comments
 (0)