|
1 | 1 | # Copyright (c) Jupyter Development Team.
|
2 | 2 | # Distributed under the terms of the Modified BSD License.
|
3 | 3 |
|
| 4 | +import asyncio |
| 5 | +import atexit |
4 | 6 | import errno
|
5 | 7 | import inspect
|
6 | 8 | import os
|
7 | 9 | import sys
|
| 10 | +import threading |
8 | 11 | import warnings
|
9 | 12 | from pathlib import Path
|
| 13 | +from typing import Any, Awaitable, Callable, Optional, TypeVar, Union |
10 | 14 |
|
11 | 15 |
|
12 | 16 | def ensure_dir_exists(path, mode=0o777):
|
@@ -81,3 +85,96 @@ def deprecation(message, internal="jupyter_core/"):
|
81 | 85 |
|
82 | 86 | # The call to .warn adds one frame, so bump the stacklevel up by one
|
83 | 87 | warnings.warn(message, DeprecationWarning, stacklevel=stacklevel + 1)
|
| 88 | + |
| 89 | + |
| 90 | +T = TypeVar("T") |
| 91 | + |
| 92 | + |
| 93 | +class _TaskRunner: |
| 94 | + """A task runner that runs an asyncio event loop on a background thread.""" |
| 95 | + |
| 96 | + def __init__(self): |
| 97 | + self.__io_loop: Optional[asyncio.AbstractEventLoop] = None |
| 98 | + self.__runner_thread: Optional[threading.Thread] = None |
| 99 | + self.__lock = threading.Lock() |
| 100 | + atexit.register(self._close) |
| 101 | + |
| 102 | + def _close(self): |
| 103 | + if self.__io_loop: |
| 104 | + self.__io_loop.stop() |
| 105 | + |
| 106 | + def _runner(self): |
| 107 | + loop = self.__io_loop |
| 108 | + assert loop is not None |
| 109 | + try: |
| 110 | + loop.run_forever() |
| 111 | + finally: |
| 112 | + loop.close() |
| 113 | + |
| 114 | + def run(self, coro): |
| 115 | + """Synchronously run a coroutine on a background thread.""" |
| 116 | + with self.__lock: |
| 117 | + name = f"{threading.current_thread().name} - runner" |
| 118 | + if self.__io_loop is None: |
| 119 | + self.__io_loop = asyncio.new_event_loop() |
| 120 | + self.__runner_thread = threading.Thread(target=self._runner, daemon=True, name=name) |
| 121 | + self.__runner_thread.start() |
| 122 | + fut = asyncio.run_coroutine_threadsafe(coro, self.__io_loop) |
| 123 | + return fut.result(None) |
| 124 | + |
| 125 | + |
| 126 | +_runner_map = {} |
| 127 | +_loop_map = {} |
| 128 | + |
| 129 | + |
| 130 | +def run_sync(coro: Callable[..., Awaitable[T]]) -> Callable[..., T]: |
| 131 | + """Runs a coroutine and blocks until it has executed. |
| 132 | +
|
| 133 | + Parameters |
| 134 | + ---------- |
| 135 | + coro : coroutine |
| 136 | + The coroutine to be executed. |
| 137 | + Returns |
| 138 | + ------- |
| 139 | + result : |
| 140 | + Whatever the coroutine returns. |
| 141 | + """ |
| 142 | + |
| 143 | + def wrapped(*args, **kwargs): |
| 144 | + name = threading.current_thread().name |
| 145 | + inner = coro(*args, **kwargs) |
| 146 | + try: |
| 147 | + # If a loop is currently running in this thread, |
| 148 | + # use a task runner. |
| 149 | + asyncio.get_running_loop() |
| 150 | + if name not in _runner_map: |
| 151 | + _runner_map[name] = _TaskRunner() |
| 152 | + return _runner_map[name].run(inner) |
| 153 | + except RuntimeError: |
| 154 | + pass |
| 155 | + |
| 156 | + # Run the loop for this thread. |
| 157 | + if name not in _loop_map: |
| 158 | + _loop_map[name] = asyncio.new_event_loop() |
| 159 | + loop = _loop_map[name] |
| 160 | + return loop.run_until_complete(inner) |
| 161 | + |
| 162 | + wrapped.__doc__ = coro.__doc__ |
| 163 | + return wrapped |
| 164 | + |
| 165 | + |
| 166 | +async def ensure_async(obj: Union[Awaitable[Any], Any]) -> Any: |
| 167 | + """Convert a non-awaitable object to a coroutine if needed, |
| 168 | + and await it if it was not already awaited. |
| 169 | + """ |
| 170 | + if inspect.isawaitable(obj): |
| 171 | + try: |
| 172 | + result = await obj |
| 173 | + except RuntimeError as e: |
| 174 | + if str(e) == "cannot reuse already awaited coroutine": |
| 175 | + # obj is already the coroutine's result |
| 176 | + return obj |
| 177 | + raise |
| 178 | + return result |
| 179 | + # obj doesn't need to be awaited |
| 180 | + return obj |
0 commit comments