Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions asyncio_atexit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import asyncio
import atexit
import inspect
import sys
import weakref
Expand Down Expand Up @@ -33,8 +34,15 @@ def __init__(self, loop):
# Hold a regular reference _on the object_, in those cases
self._close_ref = lambda: loop._atexit_orig_close
self.callbacks = []
loop_ref = weakref.ref(loop)
self._atexit_handle = partial(_atexit_close_loop, loop_ref)
atexit.register(self._atexit_handle)

def _unregister(self):
atexit.unregister(self._atexit_handle)

def close(self):
self._unregister()
return self._close_ref()()


Expand Down Expand Up @@ -74,6 +82,10 @@ def unregister(callback, *, loop=None):
except ValueError:
break

if not entry.callbacks:
# no callbacks registered, unregister the atexit close callback as well
entry._unregister()


def _get_entry(loop=None):
"""Get the registry entry for an event loop"""
Expand Down Expand Up @@ -120,3 +132,19 @@ def _asyncio_atexit_close(loop):
loop.run_until_complete(_run_asyncio_atexits(loop, entry.callbacks))
entry.callbacks[:] = []
return entry.close()


def _atexit_close_loop(loop_ref, *args):
"""
atexit callback to call loop.close

Register loop.close with atexit,
so we are more confident that loop.close will actually be called.
"""
loop = loop_ref()
if loop is None:
return
try:
loop.close()
except Exception as e:
print(f"Exception in asyncio event_loop.close: {e}", file=sys.stderr)
19 changes: 19 additions & 0 deletions test_asyncio_atexit.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,22 @@ async def test():
asyncio_run(test())

assert sync_called


def test_atexit_hook(policy):
loop = asyncio.new_event_loop()
sync_called = False

def sync_cb():
nonlocal sync_called
sync_called = True

async def test():
asyncio_atexit.register(sync_cb)

loop.run_until_complete(test())
assert loop in asyncio_atexit._registry
# can't easily test true atexit invocations
entry = asyncio_atexit._get_entry(loop)
entry._atexit_handle(*sys.exc_info())
assert loop._closed