Skip to content

Commit 59ce440

Browse files
committed
Expand heuristic for inferring Parameters from a Callable
1 parent cc16b25 commit 59ce440

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

mypy/checkexpr.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2278,6 +2278,22 @@ def get_arg_infer_passes(
22782278
# run(test, 1, 2)
22792279
# we will use `test` for inference, since it will allow to infer also
22802280
# argument *names* for P <: [x: int, y: int].
2281+
if isinstance(p_actual, UnionType):
2282+
new_items = []
2283+
for item in p_actual.items:
2284+
# narrow the union based on some approximations
2285+
p_item = get_proper_type(item)
2286+
if isinstance(p_item, CallableType) or (
2287+
isinstance(p_item, Instance)
2288+
and find_member("__call__", p_item, p_item, is_operator=True)
2289+
is not None
2290+
):
2291+
new_items.append(p_item)
2292+
if len(new_items) == 2:
2293+
break
2294+
2295+
if len(new_items) == 1:
2296+
p_actual = new_items[0]
22812297
if isinstance(p_actual, Instance):
22822298
call_method = find_member("__call__", p_actual, p_actual, is_operator=True)
22832299
if call_method is not None:

test-data/unit/check-parameter-specification.test

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2560,3 +2560,21 @@ def fn(f: MiddlewareFactory[P]) -> Capture[P]: ...
25602560

25612561
reveal_type(fn(ServerErrorMiddleware)) # N: Revealed type is "__main__.Capture[[handler: Union[builtins.str, None] =, debug: builtins.bool =]]"
25622562
[builtins fixtures/paramspec.pyi]
2563+
2564+
[case testParamSpecInferenceWithAny]
2565+
from typing_extensions import ParamSpec
2566+
from typing import Any, Callable, Union
2567+
2568+
P = ParamSpec("P")
2569+
2570+
def into(f: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None:
2571+
return None
2572+
2573+
class C:
2574+
def f(self, y: bool = False, *, x: int = 42) -> None:
2575+
return None
2576+
2577+
ex: Union[C, Any] = C()
2578+
2579+
into(ex.f, x=-1)
2580+
[builtins fixtures/paramspec.pyi]

0 commit comments

Comments
 (0)