25
25
SOFTWARE.
26
26
"""
27
27
28
+ __all__ = ["Mock" , "AsyncMock" , "patch" ]
29
+
28
30
#: Attributes of the Mock class that should be handled as "normal" attributes
29
31
#: rather than treated as mocked attributes.
30
32
_RESERVED_MOCK_ATTRIBUTES = ("side_effect" , "return_value" )
@@ -298,23 +300,40 @@ class AsyncMock(Mock):
298
300
299
301
class patch :
300
302
"""
301
- patch() acts as a function decorator, class decorator or a context manager.
302
- Inside the body of the function or with statement, the target is patched
303
- with a new object. When the function/with statement exits the patch is
304
- undone.
303
+ patch() acts as a function decorator or a context manager. Inside the body
304
+ of the function or with statement, the target is patched with a new object.
305
+ When the function/with statement exits the patch is undone.
305
306
"""
306
307
307
- def __init__ (self , target , new ):
308
+ def __init__ (self , target , new = None ):
308
309
"""
309
310
Create a new patch object that will replace the target with new.
311
+
312
+ If the target is a string, it should be in the form
313
+ "module.submodule.attribute" or "module.submodule:Class.attribute".
314
+
315
+ If no new object is provided, a new Mock object is created.
310
316
"""
311
317
self .target = target
312
- self .new = new
318
+ self .new = new or Mock ()
319
+
320
+ def __call__ (self , func , * args , ** kwargs ):
321
+ """
322
+ Decorate a function with the patch object.
323
+ """
324
+
325
+ def wrapper (* args , ** kwargs ):
326
+ with self (self .target , self .new ):
327
+ return func (* args , ** kwargs )
328
+
329
+ return wrapper
313
330
314
- def __enter__ (self ):
331
+ def __enter__ (self , target , new ):
315
332
"""
316
333
Replace the target with new.
317
334
"""
335
+ self .target = resolve_target (self .target )
336
+ self .new = new
318
337
self ._old = getattr (self .target , self .new .__name__ , None )
319
338
setattr (self .target , self .new .__name__ , self .new )
320
339
return self .new
@@ -325,3 +344,50 @@ def __exit__(self, exc_type, exc_value, traceback):
325
344
"""
326
345
setattr (self .target , self .new .__name__ , self ._old )
327
346
return False
347
+
348
+
349
+ def resolve_target (target ):
350
+ """
351
+ Return the target object. If the target is a string, search for the module
352
+ and attribute and return the attribute. Otherwise, return the target as is.
353
+
354
+ The target as a string should be in the form "module.submodule.attribute"
355
+ or "module.submodule:Class.attribute". This function imports the module and
356
+ returns the attribute.
357
+
358
+ "Inspired by" pkgutil.resolve_name in the CPython standard library (but
359
+ much simpler/naive).
360
+
361
+ Will raise an ImportError if the target module cannot be resolved or an
362
+ AttributeError if the attribute cannot be found.
363
+ """
364
+ if not isinstance (target , str ):
365
+ return target
366
+ if ":" in target :
367
+ # There is a colon - a one-step import is all that's needed.
368
+ module_name , attribute = target .split (":" )
369
+ module = __import__ (module_name )
370
+ parts = attribute .split ("." )
371
+ else :
372
+ # No colon - have to iterate to find the package boundary.
373
+ parts = target .split ("." )
374
+ module_name = parts .pop (0 )
375
+ # The first part of the target must be a module name.
376
+ module = __import__ (module_name )
377
+ while parts :
378
+ # Traverse the parts of the target to find the package boundary.
379
+ p = parts .pop (0 )
380
+ new_module_name = f"{ module_name } .{ p } "
381
+ try :
382
+ module = __import__ (new_module_name )
383
+ parts .pop (0 )
384
+ module_name = new_module_name
385
+ except ImportError :
386
+ break
387
+ # If we get here, module is the module object we're interested in and
388
+ # already imported. The parts list contains the remaining parts of the
389
+ # target to be traversed within the module (or an empty list).
390
+ result = module
391
+ for part in parts :
392
+ result = getattr (result , part )
393
+ return result
0 commit comments