Skip to content

Commit 22bf2fb

Browse files
authored
feat(when): add ContextManager mocking support (#93)
1 parent 806765b commit 22bf2fb

19 files changed

+727
-73
lines changed

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ poetry run mkdocs serve
6464

6565
The library and documentation will be deployed to PyPI and GitHub Pages, respectively, by CI. To trigger the deploy, cut a new version and push it to GitHub.
6666

67-
Deploy adheres to [semantic versioning][], so care should be taken to bump accurately.
67+
Decoy adheres to [semantic versioning][], so care should be taken to bump accurately.
6868

6969
```bash
7070
# checkout the main branch and pull down latest changes

decoy/__init__.py

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,27 @@
11
"""Decoy stubbing and spying library."""
2-
from typing import Any, Callable, Generic, Optional, cast, overload
3-
4-
from . import matchers, errors, warnings
2+
from typing import Any, Callable, Generic, Optional, Union, cast, overload
3+
4+
from . import errors, matchers, warnings
5+
from .context_managers import (
6+
AsyncContextManager,
7+
ContextManager,
8+
GeneratorContextManager,
9+
)
510
from .core import DecoyCore, StubCore
6-
from .types import ClassT, FuncT, ReturnT
11+
from .types import ClassT, ContextValueT, FuncT, ReturnT
712

813
# ensure decoy does not pollute pytest tracebacks
914
__tracebackhide__ = True
1015

1116

1217
class Decoy:
13-
"""Decoy test double state container."""
18+
"""Decoy mock factory and state container."""
1419

1520
def __init__(self) -> None:
16-
"""Initialize the state container for test doubles and stubs.
21+
"""Initialize a new mock factory.
1722
18-
You should initialize a new Decoy instance for every test. See the
23+
You should create a new Decoy instance for every test. If you use
24+
the `decoy` pytest fixture, this is done automatically. See the
1925
[setup guide](../#setup) for more details.
2026
"""
2127
self._core = DecoyCore()
@@ -111,7 +117,8 @@ def when(
111117
ignoring unspecified arguments.
112118
113119
Returns:
114-
A stub to configure using `then_return`, `then_raise`, or `then_do`.
120+
A stub to configure using `then_return`, `then_raise`, `then_do`, or
121+
`then_enter_with`.
115122
116123
Example:
117124
```python
@@ -137,7 +144,7 @@ def verify(
137144
times: Optional[int] = None,
138145
ignore_extra_args: bool = False,
139146
) -> None:
140-
"""Verify a decoy was called using one or more rehearsals.
147+
"""Verify a mock was called using one or more rehearsals.
141148
142149
See [verification usage guide](../usage/verify/) for more details.
143150
@@ -175,11 +182,11 @@ def test_create_something(decoy: Decoy):
175182
)
176183

177184
def reset(self) -> None:
178-
"""Reset all decoy state.
185+
"""Reset all mock state.
179186
180187
This method should be called after every test to ensure spies and stubs
181-
don't leak between tests. The Decoy fixture provided by the pytest plugin
182-
will do this automatically.
188+
don't leak between tests. The `decoy` fixture provided by the pytest plugin
189+
will call `reset` automatically.
183190
184191
The `reset` method may also trigger warnings if Decoy detects any questionable
185192
mock usage. See [decoy.warnings][] for more details.
@@ -228,5 +235,47 @@ def then_do(self, action: Callable[..., ReturnT]) -> None:
228235
"""
229236
self._core.then_do(action)
230237

238+
@overload
239+
def then_enter_with(
240+
self: "Stub[ContextManager[ContextValueT]]",
241+
value: ContextValueT,
242+
) -> None:
243+
...
244+
245+
@overload
246+
def then_enter_with(
247+
self: "Stub[AsyncContextManager[ContextValueT]]",
248+
value: ContextValueT,
249+
) -> None:
250+
...
251+
252+
@overload
253+
def then_enter_with(
254+
self: "Stub[GeneratorContextManager[ContextValueT]]",
255+
value: ContextValueT,
256+
) -> None:
257+
...
258+
259+
def then_enter_with(
260+
self: Union[
261+
"Stub[GeneratorContextManager[ContextValueT]]",
262+
"Stub[ContextManager[ContextValueT]]",
263+
"Stub[AsyncContextManager[ContextValueT]]",
264+
],
265+
value: ContextValueT,
266+
) -> None:
267+
"""Configure the stub to return a value wrapped in a context manager.
268+
269+
The wrapping context manager is compatible with both the synchronous and
270+
asynchronous context manager interfaces.
271+
272+
See the [context manager usage guide](../advanced/context-managers/)
273+
for more details.
274+
275+
Arguments:
276+
value: A return value to wrap in a ContextManager.
277+
"""
278+
self._core.then_enter_with(value)
279+
231280

232281
__all__ = ["Decoy", "Stub", "matchers", "warnings", "errors"]

decoy/call_handler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
from typing import Any
33

44
from .call_stack import CallStack
5-
from .stub_store import StubStore
5+
from .context_managers import ContextWrapper
66
from .spy_calls import SpyCall
7+
from .stub_store import StubStore
78

89

910
class CallHandler:
@@ -25,4 +26,7 @@ def handle(self, call: SpyCall) -> Any:
2526
if behavior.action:
2627
return behavior.action(*call.args, **call.kwargs)
2728

29+
if behavior.context_value:
30+
return ContextWrapper(behavior.context_value)
31+
2832
return behavior.return_value

decoy/context_managers.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""Wrappers around contextlib types and fallbacks."""
2+
import contextlib
3+
from typing import Any, AsyncContextManager, ContextManager, Generic, TypeVar
4+
5+
GeneratorContextManager = contextlib._GeneratorContextManager
6+
7+
_EnterT = TypeVar("_EnterT")
8+
9+
10+
class ContextWrapper(
11+
ContextManager[_EnterT],
12+
AsyncContextManager[_EnterT],
13+
Generic[_EnterT],
14+
):
15+
"""A simple, do-nothing ContextManager that wraps a given value.
16+
17+
Adapted from `contextlib.nullcontext` to ensure support across
18+
all Python versions.
19+
"""
20+
21+
def __init__(self, enter_result: _EnterT) -> None:
22+
self._enter_result = enter_result
23+
24+
def __enter__(self) -> _EnterT:
25+
"""Return the wrapped value."""
26+
return self._enter_result
27+
28+
def __exit__(self, *args: Any, **kwargs: Any) -> Any:
29+
"""No-op on exit."""
30+
pass
31+
32+
async def __aenter__(self) -> _EnterT:
33+
"""Return the wrapped value."""
34+
return self._enter_result
35+
36+
async def __aexit__(self, *args: Any, **kwargs: Any) -> Any:
37+
"""No-op on exit."""
38+
pass
39+
40+
41+
__all__ = [
42+
"AsyncContextManager",
43+
"GeneratorContextManager",
44+
"ContextManager",
45+
"ContextWrapper",
46+
]

decoy/core.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
"""Decoy implementation logic."""
22
from typing import Any, Callable, Optional
33

4-
from .spy import SpyConfig, SpyFactory, create_spy as default_create_spy
5-
from .spy_calls import WhenRehearsal
6-
from .call_stack import CallStack
7-
from .stub_store import StubStore, StubBehavior
84
from .call_handler import CallHandler
5+
from .call_stack import CallStack
6+
from .spy import SpyConfig, SpyFactory
7+
from .spy import create_spy as default_create_spy
8+
from .spy_calls import WhenRehearsal
9+
from .stub_store import StubBehavior, StubStore
10+
from .types import ContextValueT, ReturnT
911
from .verifier import Verifier
1012
from .warning_checker import WarningChecker
11-
from .types import ReturnT
1213

1314
# ensure decoy.core does not pollute Pytest tracebacks
1415
__tracebackhide__ = True
@@ -115,3 +116,10 @@ def then_do(self, action: Callable[..., ReturnT]) -> None:
115116
rehearsal=self._rehearsal,
116117
behavior=StubBehavior(action=action),
117118
)
119+
120+
def then_enter_with(self, value: ContextValueT) -> None:
121+
"""Set the stub to return a ContextManager wrapped value."""
122+
self._stub_store.add(
123+
rehearsal=self._rehearsal,
124+
behavior=StubBehavior(context_value=value),
125+
)

decoy/pytest_plugin.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
fixture without modifying any other pytest behavior. Its usage is optional
55
but highly recommended.
66
"""
7-
import pytest
87
from typing import Iterable
8+
9+
import pytest
10+
911
from decoy import Decoy
1012

1113

12-
@pytest.fixture
14+
@pytest.fixture()
1315
def decoy() -> Iterable[Decoy]:
1416
"""Get a [decoy.Decoy][] container and tear it down after the test.
1517

decoy/spy.py

Lines changed: 66 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,18 @@
66
from inspect import getattr_static, isclass, iscoroutinefunction, isfunction, signature
77
from functools import partial
88
from warnings import warn
9-
from typing import get_type_hints, Any, Callable, Dict, NamedTuple, Optional
9+
from types import TracebackType
10+
from typing import (
11+
cast,
12+
get_type_hints,
13+
Any,
14+
Callable,
15+
ContextManager,
16+
Dict,
17+
NamedTuple,
18+
Optional,
19+
Type,
20+
)
1021

1122
from .spy_calls import SpyCall
1223
from .warnings import IncorrectCallWarning
@@ -37,7 +48,7 @@ def _get_type_hints(obj: Any) -> Dict[str, Any]:
3748
return {}
3849

3950

40-
class BaseSpy:
51+
class BaseSpy(ContextManager[Any]):
4152
"""Spy object base class.
4253
4354
- Pretends to be another class, if another class is given as a spec
@@ -84,25 +95,35 @@ def __class__(self) -> Any:
8495

8596
return type(self)
8697

87-
def _call(self, *args: Any, **kwargs: Any) -> Any:
88-
spy_id = id(self)
89-
spy_name = (
90-
self._name
91-
if self._name
92-
else f"{type(self).__module__}.{type(self).__qualname__}"
93-
)
98+
def __enter__(self) -> Any:
99+
"""Allow a spy to be used as a context manager."""
100+
enter_spy = self._get_or_create_child_spy("__enter__")
101+
return enter_spy()
94102

95-
if hasattr(self, "__signature__"):
96-
try:
97-
bound_args = self.__signature__.bind(*args, **kwargs)
98-
except TypeError as e:
99-
# stacklevel: 3 ensures warning is linked to call location
100-
warn(IncorrectCallWarning(e), stacklevel=3)
101-
else:
102-
args = bound_args.args
103-
kwargs = bound_args.kwargs
104-
105-
return self._handle_call(SpyCall(spy_id, spy_name, args, kwargs))
103+
def __exit__(
104+
self,
105+
exc_type: Optional[Type[BaseException]],
106+
exc_value: Optional[BaseException],
107+
traceback: Optional[TracebackType],
108+
) -> Optional[bool]:
109+
"""Allow a spy to be used as a context manager."""
110+
exit_spy = self._get_or_create_child_spy("__exit__")
111+
return cast(Optional[bool], exit_spy(exc_type, exc_value, traceback))
112+
113+
async def __aenter__(self) -> Any:
114+
"""Allow a spy to be used as an async context manager."""
115+
enter_spy = self._get_or_create_child_spy("__aenter__")
116+
return await enter_spy()
117+
118+
async def __aexit__(
119+
self,
120+
exc_type: Optional[Type[BaseException]],
121+
exc_value: Optional[BaseException],
122+
traceback: Optional[TracebackType],
123+
) -> Optional[bool]:
124+
"""Allow a spy to be used as a context manager."""
125+
exit_spy = self._get_or_create_child_spy("__aexit__")
126+
return cast(Optional[bool], await exit_spy(exc_type, exc_value, traceback))
106127

107128
def __repr__(self) -> str:
108129
"""Get a helpful string representation of the spy."""
@@ -118,14 +139,15 @@ def __repr__(self) -> str:
118139
return "<Decoy mock>"
119140

120141
def __getattr__(self, name: str) -> Any:
121-
"""Get a property of the spy.
122-
123-
Lazily constructs child spies, basing them on type hints if available.
124-
"""
142+
"""Get a property of the spy, always returning a child spy."""
125143
# do not attempt to mock magic methods
126144
if name.startswith("__") and name.endswith("__"):
127145
return super().__getattribute__(name)
128146

147+
return self._get_or_create_child_spy(name)
148+
149+
def _get_or_create_child_spy(self, name: str) -> Any:
150+
"""Lazily construct a child spy, basing it on type hints if available."""
129151
# return previously constructed (and cached) child spies
130152
if name in self._spy_children:
131153
return self._spy_children[name]
@@ -167,6 +189,26 @@ def __getattr__(self, name: str) -> Any:
167189

168190
return spy
169191

192+
def _call(self, *args: Any, **kwargs: Any) -> Any:
193+
spy_id = id(self)
194+
spy_name = (
195+
self._name
196+
if self._name
197+
else f"{type(self).__module__}.{type(self).__qualname__}"
198+
)
199+
200+
if hasattr(self, "__signature__"):
201+
try:
202+
bound_args = self.__signature__.bind(*args, **kwargs)
203+
except TypeError as e:
204+
# stacklevel: 3 ensures warning is linked to call location
205+
warn(IncorrectCallWarning(e), stacklevel=3)
206+
else:
207+
args = bound_args.args
208+
kwargs = bound_args.kwargs
209+
210+
return self._handle_call(SpyCall(spy_id, spy_name, args, kwargs))
211+
170212

171213
class Spy(BaseSpy):
172214
"""An object that records all calls made to itself and its children."""

decoy/stub_store.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ class StubBehavior(NamedTuple):
88
"""A recorded stub behavior."""
99

1010
return_value: Optional[Any] = None
11+
context_value: Optional[Any] = None
1112
error: Optional[Exception] = None
1213
action: Optional[Callable[..., Any]] = None
1314
once: bool = False

decoy/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@
99

1010
ReturnT = TypeVar("ReturnT")
1111
"""The return type of a given call."""
12+
13+
ContextValueT = TypeVar("ContextValueT")
14+
"""A context manager value returned by a stub."""

0 commit comments

Comments
 (0)