Skip to content

Commit 51b8a12

Browse files
Add run_sync and ensure_async functions (#315)
1 parent f7e1f00 commit 51b8a12

File tree

2 files changed

+131
-0
lines changed

2 files changed

+131
-0
lines changed

jupyter_core/tests/test_async.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Tests for async helper functions"""
2+
3+
# Copyright (c) Jupyter Development Team.
4+
# Distributed under the terms of the Modified BSD License.
5+
6+
import asyncio
7+
8+
from jupyter_core.utils import ensure_async, run_sync
9+
10+
11+
async def afunc():
12+
return "afunc"
13+
14+
15+
def func():
16+
return "func"
17+
18+
19+
sync_afunc = run_sync(afunc)
20+
21+
22+
def test_ensure_async():
23+
async def main():
24+
assert await ensure_async(afunc()) == "afunc"
25+
assert await ensure_async(func()) == "func"
26+
27+
asyncio.run(main())
28+
29+
30+
def test_run_sync():
31+
async def main():
32+
assert sync_afunc() == "afunc"
33+
34+
asyncio.run(main())

jupyter_core/utils/__init__.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
# Copyright (c) Jupyter Development Team.
22
# Distributed under the terms of the Modified BSD License.
33

4+
import asyncio
5+
import atexit
46
import errno
57
import inspect
68
import os
79
import sys
10+
import threading
811
import warnings
912
from pathlib import Path
13+
from typing import Any, Awaitable, Callable, Optional, TypeVar, Union
1014

1115

1216
def ensure_dir_exists(path, mode=0o777):
@@ -81,3 +85,96 @@ def deprecation(message, internal="jupyter_core/"):
8185

8286
# The call to .warn adds one frame, so bump the stacklevel up by one
8387
warnings.warn(message, DeprecationWarning, stacklevel=stacklevel + 1)
88+
89+
90+
T = TypeVar("T")
91+
92+
93+
class _TaskRunner:
94+
"""A task runner that runs an asyncio event loop on a background thread."""
95+
96+
def __init__(self):
97+
self.__io_loop: Optional[asyncio.AbstractEventLoop] = None
98+
self.__runner_thread: Optional[threading.Thread] = None
99+
self.__lock = threading.Lock()
100+
atexit.register(self._close)
101+
102+
def _close(self):
103+
if self.__io_loop:
104+
self.__io_loop.stop()
105+
106+
def _runner(self):
107+
loop = self.__io_loop
108+
assert loop is not None
109+
try:
110+
loop.run_forever()
111+
finally:
112+
loop.close()
113+
114+
def run(self, coro):
115+
"""Synchronously run a coroutine on a background thread."""
116+
with self.__lock:
117+
name = f"{threading.current_thread().name} - runner"
118+
if self.__io_loop is None:
119+
self.__io_loop = asyncio.new_event_loop()
120+
self.__runner_thread = threading.Thread(target=self._runner, daemon=True, name=name)
121+
self.__runner_thread.start()
122+
fut = asyncio.run_coroutine_threadsafe(coro, self.__io_loop)
123+
return fut.result(None)
124+
125+
126+
_runner_map = {}
127+
_loop_map = {}
128+
129+
130+
def run_sync(coro: Callable[..., Awaitable[T]]) -> Callable[..., T]:
131+
"""Runs a coroutine and blocks until it has executed.
132+
133+
Parameters
134+
----------
135+
coro : coroutine
136+
The coroutine to be executed.
137+
Returns
138+
-------
139+
result :
140+
Whatever the coroutine returns.
141+
"""
142+
143+
def wrapped(*args, **kwargs):
144+
name = threading.current_thread().name
145+
inner = coro(*args, **kwargs)
146+
try:
147+
# If a loop is currently running in this thread,
148+
# use a task runner.
149+
asyncio.get_running_loop()
150+
if name not in _runner_map:
151+
_runner_map[name] = _TaskRunner()
152+
return _runner_map[name].run(inner)
153+
except RuntimeError:
154+
pass
155+
156+
# Run the loop for this thread.
157+
if name not in _loop_map:
158+
_loop_map[name] = asyncio.new_event_loop()
159+
loop = _loop_map[name]
160+
return loop.run_until_complete(inner)
161+
162+
wrapped.__doc__ = coro.__doc__
163+
return wrapped
164+
165+
166+
async def ensure_async(obj: Union[Awaitable[Any], Any]) -> Any:
167+
"""Convert a non-awaitable object to a coroutine if needed,
168+
and await it if it was not already awaited.
169+
"""
170+
if inspect.isawaitable(obj):
171+
try:
172+
result = await obj
173+
except RuntimeError as e:
174+
if str(e) == "cannot reuse already awaited coroutine":
175+
# obj is already the coroutine's result
176+
return obj
177+
raise
178+
return result
179+
# obj doesn't need to be awaited
180+
return obj

0 commit comments

Comments
 (0)