Skip to content

Commit 55ef599

Browse files
committed
support Callable / callable Protocols in suggest decorator unwarpping
1 parent a3aac71 commit 55ef599

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

mypy/suggestions.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,18 @@ def is_implicit_any(typ: Type) -> bool:
230230
return isinstance(typ, AnyType) and not is_explicit_any(typ)
231231

232232

233+
def _arg_accepts_function(typ: ProperType) -> bool:
234+
return (
235+
# TypeVar / Callable
236+
isinstance(typ, (TypeVarType, CallableType))
237+
or
238+
# Protocol with __call__
239+
isinstance(typ, Instance)
240+
and typ.type.is_protocol
241+
and isinstance(typ.type.get_method("__call__"), FuncDef)
242+
)
243+
244+
233245
class SuggestionEngine:
234246
"""Engine for finding call sites and suggesting signatures."""
235247

@@ -659,7 +671,7 @@ def extract_from_decorator(self, node: Decorator) -> FuncDef | None:
659671
for ct in typ.items:
660672
if not (
661673
len(ct.arg_types) == 1
662-
and isinstance(ct.arg_types[0], TypeVarType)
674+
and _arg_accepts_function(get_proper_type(ct.arg_types[0]))
663675
and ct.arg_types[0] == ct.ret_type
664676
):
665677
return None

test-data/unit/fine-grained-suggest.test

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,38 @@ foo3('hello hello')
651651
(str) -> str
652652
==
653653

654+
[case testSuggestInferFuncDecorator6]
655+
# suggest: foo.f
656+
[file foo.py]
657+
from __future__ import annotations
658+
659+
from typing import Callable, Protocol, TypeVar
660+
from typing_extensions import ParamSpec
661+
662+
P = ParamSpec('P')
663+
R = TypeVar('R')
664+
R_co = TypeVar('R_co', covariant=True)
665+
666+
class Proto(Protocol[P, R_co]):
667+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co: ...
668+
669+
def dec1(f: Callable[P, R]) -> Callable[P, R]: ...
670+
def dec2(f: Callable[..., R]) -> Callable[..., R]: ...
671+
def dec3(f: Proto[P, R_co]) -> Proto[P, R_co]: ...
672+
673+
@dec1
674+
@dec2
675+
@dec3
676+
def f(x):
677+
return x
678+
679+
f('hi')
680+
681+
[builtins fixtures/isinstancelist.pyi]
682+
[out]
683+
(str) -> str
684+
==
685+
654686
[case testSuggestFlexAny1]
655687
# suggest: --flex-any=0.4 m.foo
656688
# suggest: --flex-any=0.7 m.foo

0 commit comments

Comments
 (0)