Skip to content

Commit 6a73b14

Browse files
authored
fix(spy): ensure class signature uses __call__ (#120)
1 parent 9f804e6 commit 6a73b14

File tree

3 files changed

+40
-9
lines changed

3 files changed

+40
-9
lines changed

decoy/spec.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,10 @@ def get_full_name(self) -> str:
7171

7272
def get_signature(self) -> Optional[inspect.Signature]:
7373
"""Get the Spec's signature, if Spec represents a callable."""
74+
source = self._get_source()
75+
7476
try:
75-
return inspect.signature(self._source) # type: ignore[arg-type]
77+
return inspect.signature(source)
7678
except TypeError:
7779
return None
7880

@@ -82,18 +84,12 @@ def get_class_type(self) -> Optional[Type[Any]]:
8284

8385
def get_is_async(self) -> bool:
8486
"""Get whether the Spec represents an async. callable."""
85-
source = self._source
87+
source = self._get_source()
8688

8789
# `iscoroutinefunction` does not work for `partial` on Python < 3.8
8890
if isinstance(source, functools.partial):
8991
source = source.func
9092

91-
# check if spec source is a class with a __call__ method
92-
elif inspect.isclass(source):
93-
call_method = inspect.getattr_static(source, "__call__", None)
94-
if inspect.isfunction(call_method):
95-
source = call_method
96-
9793
return inspect.iscoroutinefunction(source)
9894

9995
def bind_args(self, *args: Any, **kwargs: Any) -> BoundArgs:
@@ -141,6 +137,19 @@ def get_child_spec(self, name: str) -> "Spec":
141137

142138
return Spec(source=child_source, name=child_name, module_name=self._module_name)
143139

140+
def _get_source(self) -> Any:
141+
source = self._source
142+
143+
# check if spec source is a class with a __call__ method
144+
if inspect.isclass(source):
145+
call_method = inspect.getattr_static(source, "__call__", None)
146+
if inspect.isfunction(call_method):
147+
# consume the `self` argument of the method to ensure proper
148+
# signature reporting by wrapping it in a partial
149+
source = functools.partial(call_method, None)
150+
151+
return source
152+
144153

145154
def _get_type_hints(obj: Any) -> Dict[str, Any]:
146155
"""Get type hints for an object, if possible.

tests/common.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,15 @@ async def do_the_thing(self, *, flag: bool) -> None:
5757
class SomeAsyncCallableClass:
5858
"""Async callable class."""
5959

60-
async def __call__(self) -> int:
60+
async def __call__(self, val: int) -> int:
61+
"""Get an integer."""
62+
...
63+
64+
65+
class SomeCallableClass:
66+
"""Async callable class."""
67+
68+
async def __call__(self, val: int) -> int:
6169
"""Get an integer."""
6270
...
6371

tests/test_spec.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
SomeClass,
1010
SomeAsyncClass,
1111
SomeAsyncCallableClass,
12+
SomeCallableClass,
1213
SomeNestedClass,
1314
some_func,
1415
some_async_func,
@@ -164,6 +165,19 @@ class GetSignatureSpec(NamedTuple):
164165
return_annotation=int,
165166
),
166167
),
168+
GetSignatureSpec(
169+
subject=Spec(source=SomeCallableClass, name=None),
170+
expected_signature=inspect.Signature(
171+
parameters=[
172+
inspect.Parameter(
173+
name="val",
174+
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
175+
annotation=int,
176+
)
177+
],
178+
return_annotation=int,
179+
),
180+
),
167181
],
168182
)
169183
def test_get_signature(

0 commit comments

Comments
 (0)