diff --git a/asyncio_atexit.py b/asyncio_atexit.py index 1615b1e..e0ee13b 100644 --- a/asyncio_atexit.py +++ b/asyncio_atexit.py @@ -3,6 +3,7 @@ """ import asyncio +import atexit import inspect import sys import weakref @@ -33,8 +34,15 @@ def __init__(self, loop): # Hold a regular reference _on the object_, in those cases self._close_ref = lambda: loop._atexit_orig_close self.callbacks = [] + loop_ref = weakref.ref(loop) + self._atexit_handle = partial(_atexit_close_loop, loop_ref) + atexit.register(self._atexit_handle) + + def _unregister(self): + atexit.unregister(self._atexit_handle) def close(self): + self._unregister() return self._close_ref()() @@ -74,6 +82,10 @@ def unregister(callback, *, loop=None): except ValueError: break + if not entry.callbacks: + # no callbacks registered, unregister the atexit close callback as well + entry._unregister() + def _get_entry(loop=None): """Get the registry entry for an event loop""" @@ -120,3 +132,19 @@ def _asyncio_atexit_close(loop): loop.run_until_complete(_run_asyncio_atexits(loop, entry.callbacks)) entry.callbacks[:] = [] return entry.close() + + +def _atexit_close_loop(loop_ref, *args): + """ + atexit callback to call loop.close + + Register loop.close with atexit, + so we are more confident that loop.close will actually be called. + """ + loop = loop_ref() + if loop is None: + return + try: + loop.close() + except Exception as e: + print(f"Exception in asyncio event_loop.close: {e}", file=sys.stderr) diff --git a/test_asyncio_atexit.py b/test_asyncio_atexit.py index cb35def..a3b1544 100644 --- a/test_asyncio_atexit.py +++ b/test_asyncio_atexit.py @@ -92,3 +92,22 @@ async def test(): asyncio_run(test()) assert sync_called + + +def test_atexit_hook(policy): + loop = asyncio.new_event_loop() + sync_called = False + + def sync_cb(): + nonlocal sync_called + sync_called = True + + async def test(): + asyncio_atexit.register(sync_cb) + + loop.run_until_complete(test()) + assert loop in asyncio_atexit._registry + # can't easily test true atexit invocations + entry = asyncio_atexit._get_entry(loop) + entry._atexit_handle(*sys.exc_info()) + assert loop._closed