Skip to content

Commit 19af306

Browse files
committed
Simplify and fix Mock/AsyncMock return_value handling. Add tests.
1 parent b71f036 commit 19af306

File tree

3 files changed

+23
-25
lines changed

3 files changed

+23
-25
lines changed

tests/test_asyncmock.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,3 +419,5 @@ async def test_get_attribute_of_unknown_attribute_returns_mock():
419419
m = AsyncMock()
420420
assert isinstance(m.foo, AsyncMock), "Returned object is not an AsyncMock."
421421
assert id(m.foo) == id(m.foo), "Returned object is not the same."
422+
assert await m.foo() is await m.foo(), "Returned object is not the same."
423+
assert isinstance(await m.foo(), AsyncMock), "Returned object is not a Mock."

tests/test_mock.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,11 +382,13 @@ def assert_call_returns_same_unique_mock():
382382
assert id(m()) == id(mock_result), "Returned mock object is not the same."
383383

384384

385-
def test_get_attribute_of_unknown_attribute_returns_mock():
385+
def test_get_attr_of_unknown_attribute_returns_mock():
386386
"""
387387
Accessing an unknown attribute on the mock object should return another
388388
mock object. Each attribute access should return the same unique mock.
389389
"""
390390
m = Mock()
391391
assert isinstance(m.foo, Mock), "Returned object is not a Mock."
392392
assert id(m.foo) == id(m.foo), "Returned object is not the same."
393+
assert m.foo() is m.foo(), "Returned object is not the same."
394+
assert isinstance(m.foo(), Mock), "Returned object is not a Mock."

umock.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -162,17 +162,16 @@ class or instance) that acts as the specification for the mock
162162
for name in self._spec:
163163
# Create a new mock object for each attribute in the spec.
164164
setattr(self, name, Mock())
165-
if return_value:
166-
self.return_value = return_value
165+
self.return_value = return_value
167166
if side_effect:
168167
if type(side_effect) in (str, list, tuple, set, dict):
169168
# If side_effect is an iterable then make it an iterator.
170169
self.side_effect = iter(side_effect)
171170
else:
172171
self.side_effect = side_effect
173-
# The _mock_value is used to ensure the same mock object is returned if
174-
# no return_value or side_effect is specified.
175-
self._mock_value = None
172+
# The return_value is used to ensure the same result is always returned
173+
# when calling the mock object and if no side_effect is specified.
174+
self.return_value = return_value
176175
self.reset_mock()
177176
for key, value in kwargs.items():
178177
setattr(self, key, value)
@@ -314,13 +313,11 @@ def __call__(self, *args, **kwargs):
314313
elif callable(self.side_effect):
315314
return self.side_effect(*args, **kwargs)
316315
raise TypeError("The mock object has an invalid side_effect.")
317-
if hasattr(self, "return_value"):
318-
return self.return_value
319-
else:
320-
# Return a mock object (ensuring it's the same one each time).
321-
if not self._mock_value:
322-
self._mock_value = Mock()
323-
return self._mock_value
316+
# Return the return_value or a mock object (ensuring it's the same one
317+
# each time).
318+
if self.return_value is None:
319+
self.return_value = Mock()
320+
return self.return_value
324321

325322
def __getattr__(self, name):
326323
"""
@@ -423,17 +420,16 @@ class or instance) that acts as the specification for the mock
423420
for name in self._spec:
424421
# Create a new mock object for each attribute in the spec.
425422
setattr(self, name, Mock())
426-
if return_value:
427-
self.return_value = return_value
423+
428424
if side_effect:
429425
if type(side_effect) in (str, list, tuple, set, dict):
430426
# If side_effect is an iterable then make it an iterator.
431427
self.side_effect = iter(side_effect)
432428
else:
433429
self.side_effect = side_effect
434-
# The _mock_value is used to ensure the same mock object is returned if
435-
# no return_value or side_effect is specified.
436-
self._mock_value = None
430+
# The return_value is used to ensure the same result is always returned
431+
# when calling the mock object and if no side_effect is specified.
432+
self.return_value = return_value
437433
self.reset_mock()
438434
for key, value in kwargs.items():
439435
setattr(self, key, value)
@@ -578,13 +574,11 @@ async def __call__(self, *args, **kwargs):
578574
elif callable(self.side_effect):
579575
return self.side_effect(*args, **kwargs)
580576
raise TypeError("The mock object has an invalid side_effect.")
581-
if hasattr(self, "return_value"):
582-
return self.return_value
583-
else:
584-
# Return a mock object (ensuring it's the same one each time).
585-
if not self._mock_value:
586-
self._mock_value = AsyncMock()
587-
return self._mock_value
577+
# Return the return_value or a mock object (ensuring it's the same one
578+
# each time).
579+
if self.return_value is None:
580+
self.return_value = AsyncMock()
581+
return self.return_value
588582

589583
def __getattr__(self, name):
590584
"""

0 commit comments

Comments
 (0)