Skip to content

Commit 7d021c2

Browse files
authored
fix(spy): resolve source to origin of GenericAlias (#143)
Fixes #142
1 parent 2096434 commit 7d021c2

File tree

3 files changed

+54
-20
lines changed

3 files changed

+54
-20
lines changed

decoy/spy_core.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ class BoundArgs(NamedTuple):
2727

2828

2929
class SpyCore:
30-
"""Core spy logic for mimicing a given `source` object.
30+
"""Core spy logic for mimicking a given `source` object.
3131
3232
Arguments:
33-
source: The source object the Spy is mimicing.
33+
source: The source object the Spy is mimicking.
3434
name: The spec's name. If `None`, will be derived from `source`.
3535
Will fallback to a default value.
3636
module_name: The spec's module name. If left unspecified,
@@ -47,6 +47,8 @@ def __init__(
4747
module_name: Union[str, _FROM_SOURCE, None] = FROM_SOURCE,
4848
is_async: bool = False,
4949
) -> None:
50+
source = _resolve_source(source)
51+
5052
self._source = source
5153
self._name = _get_name(source) if name is None else name
5254
self._module_name = (
@@ -139,6 +141,13 @@ def create_child_core(self, name: str, is_async: bool) -> "SpyCore":
139141
)
140142

141143

144+
def _resolve_source(source: Any) -> Any:
145+
"""Resolve the source object, unwrapping any generic aliases."""
146+
origin = inspect.getattr_static(source, "__origin__", None)
147+
148+
return origin if origin is not None else source
149+
150+
142151
def _get_name(source: Any) -> str:
143152
"""Get the name of a source object."""
144153
source_name = getattr(source, "__name__", None) if source is not None else None

tests/fixtures.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,31 @@
11
"""Common test fixtures."""
22
from functools import lru_cache
3-
from typing import Any
3+
from typing import Any, Generic, TypeVar
44

55

66
class SomeClass:
77
"""Testing class."""
88

99
def foo(self, val: str) -> str:
1010
"""Get the foo string."""
11-
...
1211

1312
def bar(self, a: int, b: float, c: str) -> bool:
1413
"""Get the bar bool based on a few inputs."""
15-
...
1614

1715
@staticmethod
1816
def fizzbuzz(hello: str) -> int:
1917
"""Fizz some buzzes."""
20-
...
2118

2219
def do_the_thing(self, *, flag: bool) -> None:
2320
"""Perform a side-effect without a return value."""
24-
...
2521

2622
@property
2723
def primitive_property(self) -> str:
2824
"""Get a primitive computed property."""
29-
...
3025

3126
@lru_cache(maxsize=None)
3227
def some_wrapped_method(self, val: str) -> str:
3328
"""Get a thing through a wrapped method."""
34-
...
3529

3630

3731
class SomeNestedClass:
@@ -41,62 +35,65 @@ class SomeNestedClass:
4135

4236
def foo(self, val: str) -> str:
4337
"""Get the foo string."""
44-
...
4538

4639
@property
4740
def child(self) -> SomeClass:
4841
"""Get the child instance."""
49-
...
5042

5143

5244
class SomeAsyncClass:
5345
"""Async testing class."""
5446

5547
async def foo(self, val: str) -> str:
5648
"""Get the foo string."""
57-
...
5849

5950
async def bar(self, a: int, b: float, c: str) -> bool:
6051
"""Get the bar bool based on a few inputs."""
61-
...
6252

6353
async def do_the_thing(self, *, flag: bool) -> None:
6454
"""Perform a side-effect without a return value."""
65-
...
6655

6756

6857
class SomeAsyncCallableClass:
6958
"""Async callable class."""
7059

7160
async def __call__(self, val: int) -> int:
7261
"""Get an integer."""
73-
...
7462

7563

7664
class SomeCallableClass:
7765
"""Async callable class."""
7866

7967
async def __call__(self, val: int) -> int:
8068
"""Get an integer."""
81-
...
8269

8370

8471
def noop(*args: Any, **kwargs: Any) -> Any:
8572
"""No-op."""
86-
...
8773

8874

8975
def some_func(val: str) -> str:
9076
"""Test function."""
91-
...
9277

9378

9479
async def some_async_func(val: str) -> str:
9580
"""Async test function."""
96-
...
9781

9882

9983
@lru_cache(maxsize=None)
10084
def some_wrapped_func(val: str) -> str:
10185
"""Wrapped test function."""
102-
...
86+
87+
88+
GenericT = TypeVar("GenericT")
89+
90+
91+
class GenericClass(Generic[GenericT]):
92+
"""A generic class definition."""
93+
94+
def hello(self, val: GenericT) -> None:
95+
"""Say hello."""
96+
97+
98+
ConcreteAlias = GenericClass[str]
99+
"""An alias with a generic type specified"""

tests/test_spy_core.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
SomeAsyncCallableClass,
1212
SomeCallableClass,
1313
SomeNestedClass,
14+
GenericClass,
15+
GenericT,
16+
ConcreteAlias,
1417
some_func,
1518
some_async_func,
1619
some_wrapped_func,
@@ -78,6 +81,16 @@ class GetNameSpec(NamedTuple):
7881
expected_name="SomeNestedClass.child.foo",
7982
expected_full_name="tests.fixtures.SomeNestedClass.child.foo",
8083
),
84+
GetNameSpec(
85+
subject=SpyCore(source=GenericClass[int], name=None),
86+
expected_name="GenericClass",
87+
expected_full_name="tests.fixtures.GenericClass",
88+
),
89+
GetNameSpec(
90+
subject=SpyCore(source=ConcreteAlias, name=None),
91+
expected_name="GenericClass",
92+
expected_full_name="tests.fixtures.GenericClass",
93+
),
8194
],
8295
)
8396
def test_get_name(
@@ -226,6 +239,21 @@ class GetSignatureSpec(NamedTuple):
226239
return_annotation=str,
227240
),
228241
),
242+
GetSignatureSpec(
243+
subject=SpyCore(source=ConcreteAlias, name=None).create_child_core(
244+
"hello", is_async=False
245+
),
246+
expected_signature=inspect.Signature(
247+
parameters=[
248+
inspect.Parameter(
249+
name="val",
250+
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
251+
annotation=GenericT,
252+
)
253+
],
254+
return_annotation=None,
255+
),
256+
),
229257
],
230258
)
231259
def test_get_signature(

0 commit comments

Comments
 (0)