Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ also tracks function/method calls, return values and exceptions raised.
The object returned by ``mocker.spy`` is a ``MagicMock`` object, so all standard checking functions
are available (like ``assert_called_once_with`` or ``call_count`` in the examples above).

In addition, spy objects contain two extra attributes:
In addition, spy objects contain four extra attributes:

* ``spy_return``: contains the last returned value of the spied function.
* ``spy_return_iter``: contains a duplicate of the last returned value of the spied function if the value was an iterator. Uses `tee <https://docs.python.org/3/library/itertools.html#itertools.tee>`__) to duplicate the iterator.
* ``spy_return_iter``: contains a duplicate of the last returned value of the spied function if the value was an iterator and spy was created using ``.spy(..., duplicate_iterators)``. Uses `tee <https://docs.python.org/3/library/itertools.html#itertools.tee>`__) to duplicate the iterator.
* ``spy_return_list``: contains a list of all returned values of the spied function (new in ``3.13``).
* ``spy_exception``: contain the last exception value raised by the spied function/method when
it was last called, or ``None`` if no exception was raised.
Expand Down
7 changes: 5 additions & 2 deletions src/pytest_mock/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,16 @@ def stop(self, mock: unittest.mock.MagicMock) -> None:
"""
self._mock_cache.remove(mock)

def spy(self, obj: object, name: str) -> MockType:
def spy(
self, obj: object, name: str, duplicate_iterators: bool = False
) -> MockType:
"""
Create a spy of method. It will run method normally, but it is now
possible to use `mock` call features with it, like call count.

:param obj: An object.
:param name: A method in object.
:param duplicate_iterators: Whether to keep a copy of the returned iterator in `spy_return_iter`.
:return: Spy object.
"""
method = getattr(obj, name)
Expand All @@ -177,7 +180,7 @@ def wrapper(*args, **kwargs):
spy_obj.spy_exception = e
raise
else:
if isinstance(r, Iterator):
if duplicate_iterators and isinstance(r, Iterator):
r, duplicated_iterator = itertools.tee(r, 2)
spy_obj.spy_return_iter = duplicated_iterator
else:
Expand Down
31 changes: 26 additions & 5 deletions tests/test_pytest_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,13 +540,15 @@ def __call__(self, x):


@pytest.mark.parametrize("iterator", [(i for i in range(3)), iter([0, 1, 2])])
def test_spy_return_iter(mocker: MockerFixture, iterator: Iterator[int]) -> None:
def test_spy_return_iter_duplicates_iterator_when_enabled(
mocker: MockerFixture, iterator: Iterator[int]
) -> None:
class Foo:
def bar(self) -> Iterator[int]:
return iterator

foo = Foo()
spy = mocker.spy(foo, "bar")
spy = mocker.spy(foo, "bar", duplicate_iterators=True)
result = list(foo.bar())

assert result == [0, 1, 2]
Expand All @@ -558,16 +560,35 @@ def bar(self) -> Iterator[int]:
assert isinstance(return_value, Iterator)


@pytest.mark.parametrize("iterator", [(i for i in range(3)), iter([0, 1, 2])])
def test_spy_return_iter_is_not_set_when_disabled(
mocker: MockerFixture, iterator: Iterator[int]
) -> None:
class Foo:
def bar(self) -> Iterator[int]:
return iterator

foo = Foo()
spy = mocker.spy(foo, "bar", duplicate_iterators=False)
result = list(foo.bar())

assert result == [0, 1, 2]
assert spy.spy_return is not None
assert spy.spy_return_iter is None
[return_value] = spy.spy_return_list
assert isinstance(return_value, Iterator)


@pytest.mark.parametrize("iterable", [(0, 1, 2), [0, 1, 2], range(3)])
def test_spy_return_iter_ignore_plain_iterable(
def test_spy_return_iter_ignores_plain_iterable(
mocker: MockerFixture, iterable: Iterable[int]
) -> None:
class Foo:
def bar(self) -> Iterable[int]:
return iterable

foo = Foo()
spy = mocker.spy(foo, "bar")
spy = mocker.spy(foo, "bar", duplicate_iterators=True)
result = foo.bar()

assert result == iterable
Expand All @@ -587,7 +608,7 @@ def bar(self) -> Any:
return self.iterables.pop(0)

foo = Foo()
spy = mocker.spy(foo, "bar")
spy = mocker.spy(foo, "bar", duplicate_iterators=True)
result_iterator = list(foo.bar())

assert result_iterator == [0, 1, 2]
Expand Down
Loading