Skip to content

Commit acfff4b

Browse files
committed
Work in progress (and broken). Resolving target objects for patching. Needs tests and refinement.
1 parent 4489660 commit acfff4b

File tree

2 files changed

+89
-7
lines changed

2 files changed

+89
-7
lines changed

tests/test_patch.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
11
"""
22
Tests the patch decorator/context manager.
33
"""
4+
5+
from umock import patch
6+
7+
8+
def test_patch_decorator():
9+
"""
10+
Tests the patch decorator.
11+
"""
12+
13+
@patch("tests.test_patch.some_function", lambda: 42)
14+
def test():
15+
from tests.test_patch import some_function
16+
17+
assert some_function() == 42
18+
19+
test()

umock.py

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
SOFTWARE.
2626
"""
2727

28+
__all__ = ["Mock", "AsyncMock", "patch"]
29+
2830
#: Attributes of the Mock class that should be handled as "normal" attributes
2931
#: rather than treated as mocked attributes.
3032
_RESERVED_MOCK_ATTRIBUTES = ("side_effect", "return_value")
@@ -298,23 +300,40 @@ class AsyncMock(Mock):
298300

299301
class patch:
300302
"""
301-
patch() acts as a function decorator, class decorator or a context manager.
302-
Inside the body of the function or with statement, the target is patched
303-
with a new object. When the function/with statement exits the patch is
304-
undone.
303+
patch() acts as a function decorator or a context manager. Inside the body
304+
of the function or with statement, the target is patched with a new object.
305+
When the function/with statement exits the patch is undone.
305306
"""
306307

307-
def __init__(self, target, new):
308+
def __init__(self, target, new=None):
308309
"""
309310
Create a new patch object that will replace the target with new.
311+
312+
If the target is a string, it should be in the form
313+
"module.submodule.attribute" or "module.submodule:Class.attribute".
314+
315+
If no new object is provided, a new Mock object is created.
310316
"""
311317
self.target = target
312-
self.new = new
318+
self.new = new or Mock()
319+
320+
def __call__(self, func, *args, **kwargs):
321+
"""
322+
Decorate a function with the patch object.
323+
"""
324+
325+
def wrapper(*args, **kwargs):
326+
with self(self.target, self.new):
327+
return func(*args, **kwargs)
328+
329+
return wrapper
313330

314-
def __enter__(self):
331+
def __enter__(self, target, new):
315332
"""
316333
Replace the target with new.
317334
"""
335+
self.target = resolve_target(self.target)
336+
self.new = new
318337
self._old = getattr(self.target, self.new.__name__, None)
319338
setattr(self.target, self.new.__name__, self.new)
320339
return self.new
@@ -325,3 +344,50 @@ def __exit__(self, exc_type, exc_value, traceback):
325344
"""
326345
setattr(self.target, self.new.__name__, self._old)
327346
return False
347+
348+
349+
def resolve_target(target):
350+
"""
351+
Return the target object. If the target is a string, search for the module
352+
and attribute and return the attribute. Otherwise, return the target as is.
353+
354+
The target as a string should be in the form "module.submodule.attribute"
355+
or "module.submodule:Class.attribute". This function imports the module and
356+
returns the attribute.
357+
358+
"Inspired by" pkgutil.resolve_name in the CPython standard library (but
359+
much simpler/naive).
360+
361+
Will raise an ImportError if the target module cannot be resolved or an
362+
AttributeError if the attribute cannot be found.
363+
"""
364+
if not isinstance(target, str):
365+
return target
366+
if ":" in target:
367+
# There is a colon - a one-step import is all that's needed.
368+
module_name, attribute = target.split(":")
369+
module = __import__(module_name)
370+
parts = attribute.split(".")
371+
else:
372+
# No colon - have to iterate to find the package boundary.
373+
parts = target.split(".")
374+
module_name = parts.pop(0)
375+
# The first part of the target must be a module name.
376+
module = __import__(module_name)
377+
while parts:
378+
# Traverse the parts of the target to find the package boundary.
379+
p = parts.pop(0)
380+
new_module_name = f"{module_name}.{p}"
381+
try:
382+
module = __import__(new_module_name)
383+
parts.pop(0)
384+
module_name = new_module_name
385+
except ImportError:
386+
break
387+
# If we get here, module is the module object we're interested in and
388+
# already imported. The parts list contains the remaining parts of the
389+
# target to be traversed within the module (or an empty list).
390+
result = module
391+
for part in parts:
392+
result = getattr(result, part)
393+
return result

0 commit comments

Comments
 (0)