Skip to content

Commit 512a722

Browse files
committed
Support ParamSpec + functools.partial
1 parent 1903402 commit 512a722

File tree

4 files changed

+105
-5
lines changed

4 files changed

+105
-5
lines changed

mypy/checkexpr.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2366,9 +2366,21 @@ def check_argument_count(
23662366
# Positional argument when expecting a keyword argument.
23672367
self.msg.too_many_positional_arguments(callee, context)
23682368
ok = False
2369-
elif callee.param_spec() is not None and not formal_to_actual[i]:
2370-
self.msg.too_few_arguments(callee, context, actual_names)
2371-
ok = False
2369+
elif callee.param_spec() is not None:
2370+
if (
2371+
not formal_to_actual[i]
2372+
and not callee.param_spec_parts_bound[kind == ArgKind.ARG_STAR2]
2373+
and callee.special_sig != "partial"
2374+
):
2375+
self.msg.too_few_arguments(callee, context, actual_names)
2376+
ok = False
2377+
elif (
2378+
formal_to_actual[i]
2379+
and kind == ArgKind.ARG_STAR
2380+
and callee.param_spec_parts_bound[0]
2381+
):
2382+
self.msg.too_many_arguments(callee, context)
2383+
ok = False
23722384
return ok
23732385

23742386
def check_for_extra_actual_arguments(

mypy/plugins/functools.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
161161
for k in fn_type.arg_kinds
162162
],
163163
ret_type=ret_type,
164+
special_sig="partial",
164165
)
165166
if defaulted.line < 0:
166167
# Make up a line number if we don't have one
@@ -267,6 +268,10 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
267268
arg_kinds=partial_kinds,
268269
arg_names=partial_names,
269270
ret_type=ret_type,
271+
param_spec_parts_bound=(
272+
ArgKind.ARG_STAR in actual_arg_kinds,
273+
ArgKind.ARG_STAR2 in actual_arg_kinds,
274+
),
270275
)
271276

272277
ret = ctx.api.named_generic_type(PARTIAL, [ret_type])

mypy/types.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1811,8 +1811,8 @@ class CallableType(FunctionLike):
18111811
"implicit", # Was this type implicitly generated instead of explicitly
18121812
# specified by the user?
18131813
"special_sig", # Non-None for signatures that require special handling
1814-
# (currently only value is 'dict' for a signature similar to
1815-
# 'dict')
1814+
# (currently only values are 'dict' for a signature similar to
1815+
# 'dict' and 'partial' for a `functools.partial` evaluation)
18161816
"from_type_type", # Was this callable generated by analyzing Type[...]
18171817
# instantiation?
18181818
"bound_args", # Bound type args, mostly unused but may be useful for
@@ -1825,6 +1825,7 @@ class CallableType(FunctionLike):
18251825
# (this is used for error messages)
18261826
"imprecise_arg_kinds",
18271827
"unpack_kwargs", # Was an Unpack[...] with **kwargs used to define this callable?
1828+
"param_spec_parts_bound", # Hack for functools.partial: allow early binding
18281829
)
18291830

18301831
def __init__(
@@ -1851,6 +1852,7 @@ def __init__(
18511852
from_concatenate: bool = False,
18521853
imprecise_arg_kinds: bool = False,
18531854
unpack_kwargs: bool = False,
1855+
param_spec_parts_bound: tuple[bool, bool] = (False, False),
18541856
) -> None:
18551857
super().__init__(line, column)
18561858
assert len(arg_types) == len(arg_kinds) == len(arg_names)
@@ -1877,6 +1879,7 @@ def __init__(
18771879
self.from_type_type = from_type_type
18781880
self.from_concatenate = from_concatenate
18791881
self.imprecise_arg_kinds = imprecise_arg_kinds
1882+
self.param_spec_parts_bound = param_spec_parts_bound
18801883
if not bound_args:
18811884
bound_args = ()
18821885
self.bound_args = bound_args
@@ -1923,6 +1926,7 @@ def copy_modified(
19231926
from_concatenate: Bogus[bool] = _dummy,
19241927
imprecise_arg_kinds: Bogus[bool] = _dummy,
19251928
unpack_kwargs: Bogus[bool] = _dummy,
1929+
param_spec_parts_bound: Bogus[tuple[bool, bool]] = _dummy,
19261930
) -> CT:
19271931
modified = CallableType(
19281932
arg_types=arg_types if arg_types is not _dummy else self.arg_types,
@@ -1954,6 +1958,11 @@ def copy_modified(
19541958
else self.imprecise_arg_kinds
19551959
),
19561960
unpack_kwargs=unpack_kwargs if unpack_kwargs is not _dummy else self.unpack_kwargs,
1961+
param_spec_parts_bound=(
1962+
param_spec_parts_bound
1963+
if param_spec_parts_bound is not _dummy
1964+
else self.param_spec_parts_bound
1965+
),
19571966
)
19581967
# Optimization: Only NewTypes are supported as subtypes since
19591968
# the class is effectively final, so we can use a cast safely.

test-data/unit/check-parameter-specification.test

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2294,3 +2294,77 @@ reveal_type(capture(fn)) # N: Revealed type is "Union[builtins.str, builtins.in
22942294
reveal_type(capture(err)) # N: Revealed type is "builtins.int"
22952295

22962296
[builtins fixtures/paramspec.pyi]
2297+
2298+
[case testBindPartial]
2299+
from functools import partial
2300+
from typing_extensions import ParamSpec
2301+
from typing import Callable, TypeVar
2302+
2303+
P = ParamSpec("P")
2304+
T = TypeVar("T")
2305+
2306+
def run(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2307+
func2 = partial(func, **kwargs)
2308+
return func2(*args)
2309+
2310+
def run2(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2311+
func2 = partial(func, *args)
2312+
return func2(**kwargs)
2313+
2314+
def run3(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2315+
func2 = partial(func, *args, **kwargs)
2316+
return func2()
2317+
2318+
def run4(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2319+
func2 = partial(func, *args, **kwargs)
2320+
return func2(**kwargs)
2321+
2322+
def run_bad(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2323+
func2 = partial(func, *args, **kwargs)
2324+
return func2(*args) # E: Too many arguments
2325+
2326+
def run_bad2(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2327+
func2 = partial(func, **kwargs)
2328+
return func2(**kwargs) # E: Too few arguments
2329+
2330+
def run_bad3(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2331+
func2 = partial(func, *args)
2332+
return func2() # E: Too few arguments
2333+
2334+
[builtins fixtures/paramspec.pyi]
2335+
2336+
[case testBindPartialConcatenate]
2337+
from functools import partial
2338+
from typing_extensions import Concatenate, ParamSpec
2339+
from typing import Callable, TypeVar
2340+
2341+
P = ParamSpec("P")
2342+
T = TypeVar("T")
2343+
2344+
def run(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2345+
func2 = partial(func, 1, **kwargs)
2346+
return func2(*args)
2347+
2348+
def run2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2349+
func2 = partial(func, **kwargs)
2350+
return func2(1, *args)
2351+
2352+
def run3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2353+
func2 = partial(func, 1, *args)
2354+
return func2(**kwargs)
2355+
2356+
def run_bad(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2357+
func2 = partial(func, *args) # E: Argument 1 has incompatible type "*P.args"; expected "int"
2358+
return func2(1, **kwargs) # E: Too many arguments
2359+
2360+
def run_bad2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2361+
func2 = partial(func, 1, *args)
2362+
return func2(1, **kwargs) # E: Too many arguments \
2363+
# E: Argument 1 has incompatible type "int"; expected "P.args"
2364+
2365+
def run_bad3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2366+
func2 = partial(func, 1, *args)
2367+
return func2(1, **kwargs) # E: Too many arguments \
2368+
# E: Argument 1 has incompatible type "int"; expected "P.args"
2369+
2370+
[builtins fixtures/paramspec.pyi]

0 commit comments

Comments
 (0)