@@ -47,6 +47,8 @@ def __init__(self, config):
4747 self .MagicMock = mock_module .MagicMock
4848 self .NonCallableMock = mock_module .NonCallableMock
4949 self .PropertyMock = mock_module .PropertyMock
50+ if hasattr (mock_module , "AsyncMock" ):
51+ self .AsyncMock = mock_module .AsyncMock
5052 self .call = mock_module .call
5153 self .ANY = mock_module .ANY
5254 self .DEFAULT = mock_module .DEFAULT
@@ -275,6 +277,41 @@ def wrap_assert_called(*args, **kwargs):
275277 assert_wrapper (_mock_module_originals ["assert_called" ], * args , ** kwargs )
276278
277279
280+ def wrap_assert_not_awaited (* args , ** kwargs ):
281+ __tracebackhide__ = True
282+ assert_wrapper (_mock_module_originals ["assert_not_awaited" ], * args , ** kwargs )
283+
284+
285+ def wrap_assert_awaited_with (* args , ** kwargs ):
286+ __tracebackhide__ = True
287+ assert_wrapper (_mock_module_originals ["assert_awaited_with" ], * args , ** kwargs )
288+
289+
290+ def wrap_assert_awaited_once (* args , ** kwargs ):
291+ __tracebackhide__ = True
292+ assert_wrapper (_mock_module_originals ["assert_awaited_once" ], * args , ** kwargs )
293+
294+
295+ def wrap_assert_awaited_once_with (* args , ** kwargs ):
296+ __tracebackhide__ = True
297+ assert_wrapper (_mock_module_originals ["assert_awaited_once_with" ], * args , ** kwargs )
298+
299+
300+ def wrap_assert_has_awaits (* args , ** kwargs ):
301+ __tracebackhide__ = True
302+ assert_wrapper (_mock_module_originals ["assert_has_awaits" ], * args , ** kwargs )
303+
304+
305+ def wrap_assert_any_await (* args , ** kwargs ):
306+ __tracebackhide__ = True
307+ assert_wrapper (_mock_module_originals ["assert_any_await" ], * args , ** kwargs )
308+
309+
310+ def wrap_assert_awaited (* args , ** kwargs ):
311+ __tracebackhide__ = True
312+ assert_wrapper (_mock_module_originals ["assert_awaited" ], * args , ** kwargs )
313+
314+
278315def wrap_assert_methods (config ):
279316 """
280317 Wrap assert methods of mock module so we can hide their traceback and
@@ -305,6 +342,25 @@ def wrap_assert_methods(config):
305342 patcher .start ()
306343 _mock_module_patches .append (patcher )
307344
345+ async_wrappers = {
346+ "assert_awaited" : wrap_assert_awaited ,
347+ "assert_awaited_once" : wrap_assert_awaited_once ,
348+ "assert_awaited_with" : wrap_assert_awaited_with ,
349+ "assert_awaited_once_with" : wrap_assert_awaited_once_with ,
350+ "assert_any_await" : wrap_assert_any_await ,
351+ "assert_has_awaits" : wrap_assert_has_awaits ,
352+ "assert_not_awaited" : wrap_assert_not_awaited ,
353+ }
354+ for method , wrapper in async_wrappers .items ():
355+ try :
356+ original = getattr (mock_module .AsyncMock , method )
357+ except AttributeError : # pragma: no cover
358+ continue
359+ _mock_module_originals [method ] = original
360+ patcher = mock_module .patch .object (mock_module .AsyncMock , method , wrapper )
361+ patcher .start ()
362+ _mock_module_patches .append (patcher )
363+
308364 if hasattr (config , "add_cleanup" ):
309365 add_cleanup = config .add_cleanup
310366 else :
0 commit comments