Skip to content
Merged
33 changes: 31 additions & 2 deletions Lib/test/test_unittest/testmock/testpatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,35 @@ def test_stop_idempotent(self):
self.assertIsNone(patcher.stop())


def test_exit_idempotent(self):
patcher = patch(foo_name, 'bar', 3)
with patcher:
patcher.stop()


def test_second_start_failure(self):
patcher = patch(foo_name, 'bar', 3)
patcher.start()
try:
self.assertRaises(RuntimeError, patcher.start)
finally:
patcher.stop()


def test_second_enter_failure(self):
patcher = patch(foo_name, 'bar', 3)
with patcher:
self.assertRaises(RuntimeError, patcher.start)


def test_second_start_after_stop(self):
patcher = patch(foo_name, 'bar', 3)
patcher.start()
patcher.stop()
patcher.start()
patcher.stop()


def test_patchobject_start_stop(self):
original = something
patcher = patch.object(PTModule, 'something', 'foo')
Expand Down Expand Up @@ -1098,7 +1127,7 @@ def test_new_callable_patch(self):

self.assertIsNot(m1, m2)
for mock in m1, m2:
self.assertNotCallable(m1)
self.assertNotCallable(mock)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice spot!



def test_new_callable_patch_object(self):
Expand All @@ -1111,7 +1140,7 @@ def test_new_callable_patch_object(self):

self.assertIsNot(m1, m2)
for mock in m1, m2:
self.assertNotCallable(m1)
self.assertNotCallable(mock)


def test_new_callable_keyword_arguments(self):
Expand Down
83 changes: 65 additions & 18 deletions Lib/unittest/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@


import asyncio
from collections import namedtuple
import contextlib
import io
import inspect
Expand Down Expand Up @@ -1320,6 +1321,9 @@ def _check_spec_arg_typos(kwargs_to_check):
)


_PatchContext = namedtuple("_PatchContext", "exit_stack is_local original target")


class _patch(object):

attribute_name = None
Expand Down Expand Up @@ -1360,6 +1364,7 @@ def __init__(
self.autospec = autospec
self.kwargs = kwargs
self.additional_patchers = []
self._context = None


def copy(self):
Expand Down Expand Up @@ -1469,13 +1474,51 @@ def get_original(self):
)
return original, local

@property
def is_started(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, those were writable attributes and now it's no more the case. Could there be some code in the wild assuming so? (for instance pytest which makes quite hacky things, though I don't know if they do hacky things with this specific part of CPython).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess so... I have committed property setters just in case. Will write tests if needed when we figure out what to do with temp_original: do we preserve it and somehow deprecate or anything else

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed them. A bit ugly but anyway these setters exist not for an intended usecase but for backwards compatibility only

return self._context is not None

@property
def is_local(self):
return self._context.is_local

@is_local.setter
def is_local(self, value):
self._context.is_local = value

@property
def original(self):
return self._context.original

@original.setter
def original(self, value):
self._context.original = value

@property
def target(self):
return self._context.target

@target.setter
def target(self, value):
self._context.target = value

@property
def temp_original(self): # backwards compatibility
return self.original

@temp_original.setter
def temp_original(self, value): # backwards compatibility
self.original = value

def __enter__(self):
"""Perform the patch."""
if self.is_started:
raise RuntimeError("Patch is already started")

new, spec, spec_set = self.new, self.spec, self.spec_set
autospec, kwargs = self.autospec, self.kwargs
new_callable = self.new_callable
self.target = self.getter()
target = self.getter()

# normalise False to None
if spec is False:
Expand All @@ -1491,7 +1534,7 @@ def __enter__(self):
spec_set not in (True, None)):
raise TypeError("Can't provide explicit spec_set *and* spec or autospec")

original, local = self.get_original()
original, is_local = self.get_original()

if new is DEFAULT and autospec is None:
inherit = False
Expand Down Expand Up @@ -1579,17 +1622,17 @@ def __enter__(self):
if autospec is True:
autospec = original

if _is_instance_mock(self.target):
if _is_instance_mock(target):
raise InvalidSpecError(
f'Cannot autospec attr {self.attribute!r} as the patch '
f'target has already been mocked out. '
f'[target={self.target!r}, attr={autospec!r}]')
f'[target={target!r}, attr={autospec!r}]')
if _is_instance_mock(autospec):
target_name = getattr(self.target, '__name__', self.target)
target_name = getattr(target, '__name__', target)
raise InvalidSpecError(
f'Cannot autospec attr {self.attribute!r} from target '
f'{target_name!r} as it has already been mocked out. '
f'[target={self.target!r}, attr={autospec!r}]')
f'[target={target!r}, attr={autospec!r}]')

new = create_autospec(autospec, spec_set=spec_set,
_name=self.attribute, **kwargs)
Expand All @@ -1600,17 +1643,21 @@ def __enter__(self):

new_attr = new

self.temp_original = original
self.is_local = local
self._exit_stack = contextlib.ExitStack()
exit_stack = contextlib.ExitStack()
self._context = _PatchContext(
exit_stack=exit_stack,
is_local=is_local,
original=original,
target=self.getter(),
)
try:
setattr(self.target, self.attribute, new_attr)
if self.attribute_name is not None:
extra_args = {}
if self.new is DEFAULT:
extra_args[self.attribute_name] = new
for patching in self.additional_patchers:
arg = self._exit_stack.enter_context(patching)
arg = exit_stack.enter_context(patching)
if patching.new is DEFAULT:
extra_args.update(arg)
return extra_args
Expand All @@ -1622,22 +1669,22 @@ def __enter__(self):

def __exit__(self, *exc_info):
"""Undo the patch."""
if self.is_local and self.temp_original is not DEFAULT:
setattr(self.target, self.attribute, self.temp_original)
if not self.is_started:
return

if self.is_local and self.original is not DEFAULT:
setattr(self.target, self.attribute, self.original)
else:
delattr(self.target, self.attribute)
if not self.create and (not hasattr(self.target, self.attribute) or
self.attribute in ('__doc__', '__module__',
'__defaults__', '__annotations__',
'__kwdefaults__')):
# needed for proxy objects like django settings
setattr(self.target, self.attribute, self.temp_original)
setattr(self.target, self.attribute, self.original)

del self.temp_original
del self.is_local
del self.target
exit_stack = self._exit_stack
del self._exit_stack
exit_stack = self._context.exit_stack
self._context = None
return exit_stack.__exit__(*exc_info)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Limit starting a patcher (from :func:`unittest.mock.patch`,
:func:`unittest.mock.patch.object` or :func:`unittest.mock.patch.dict`) more than
once without stopping it.
Loading