|
20 | 20 | TypeVar,
|
21 | 21 | )
|
22 | 22 |
|
23 |
| -from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared |
| 23 | +from monarch._rust_bindings.monarch_hyperactor.pytokio import ( |
| 24 | + is_tokio_thread, |
| 25 | + PythonTask, |
| 26 | + Shared, |
| 27 | +) |
24 | 28 |
|
25 | 29 | from typing_extensions import deprecated, Self
|
26 | 30 |
|
@@ -79,7 +83,11 @@ class _Asyncio(NamedTuple):
|
79 | 83 | fut: asyncio.Future
|
80 | 84 |
|
81 | 85 |
|
82 |
| -_Status = _Unawaited | _Complete | _Exception | _Asyncio |
| 86 | +class _Tokio(NamedTuple): |
| 87 | + shared: Shared |
| 88 | + |
| 89 | + |
| 90 | +_Status = _Unawaited | _Complete | _Exception | _Asyncio | _Tokio |
83 | 91 |
|
84 | 92 |
|
85 | 93 | class Future(Generic[R]):
|
@@ -108,31 +116,60 @@ def get(self, timeout: Optional[float] = None) -> R:
|
108 | 116 | return cast("R", value)
|
109 | 117 | case _Exception(exe=exe):
|
110 | 118 | raise exe
|
| 119 | + case _Tokio(_): |
| 120 | + raise ValueError( |
| 121 | + "already converted into a pytokio.Shared object, use 'await' from a PythonTask coroutine to get the value." |
| 122 | + ) |
111 | 123 | case _:
|
112 | 124 | raise RuntimeError("unknown status")
|
113 | 125 |
|
114 | 126 | def __await__(self) -> Generator[Any, Any, R]:
|
115 |
| - match self._status: |
116 |
| - case _Unawaited(coro=coro): |
117 |
| - loop = asyncio.get_running_loop() |
118 |
| - fut = loop.create_future() |
119 |
| - self._status = _Asyncio(fut) |
120 |
| - |
121 |
| - async def mark_complete(): |
122 |
| - try: |
123 |
| - func, value = fut.set_result, await coro |
124 |
| - except Exception as e: |
125 |
| - func, value = fut.set_exception, e |
126 |
| - loop.call_soon_threadsafe(func, value) |
127 |
| - |
128 |
| - PythonTask.from_coroutine(mark_complete()).spawn() |
129 |
| - return fut.__await__() |
130 |
| - case _Asyncio(fut=fut): |
131 |
| - return fut.__await__() |
132 |
| - case _: |
133 |
| - raise ValueError( |
134 |
| - "already converted into a synchronous future, use 'get' to get the value." |
135 |
| - ) |
| 127 | + if asyncio._get_running_loop() is not None: |
| 128 | + match self._status: |
| 129 | + case _Unawaited(coro=coro): |
| 130 | + loop = asyncio.get_running_loop() |
| 131 | + fut = loop.create_future() |
| 132 | + self._status = _Asyncio(fut) |
| 133 | + |
| 134 | + async def mark_complete(): |
| 135 | + try: |
| 136 | + func, value = fut.set_result, await coro |
| 137 | + except Exception as e: |
| 138 | + func, value = fut.set_exception, e |
| 139 | + loop.call_soon_threadsafe(func, value) |
| 140 | + |
| 141 | + PythonTask.from_coroutine(mark_complete()).spawn() |
| 142 | + return fut.__await__() |
| 143 | + case _Asyncio(fut=fut): |
| 144 | + return fut.__await__() |
| 145 | + case _Tokio(_): |
| 146 | + raise ValueError( |
| 147 | + "already converted into a tokio future, but being awaited from the asyncio loop." |
| 148 | + ) |
| 149 | + case _: |
| 150 | + raise ValueError( |
| 151 | + "already converted into a synchronous future, use 'get' to get the value." |
| 152 | + ) |
| 153 | + elif is_tokio_thread(): |
| 154 | + match self._status: |
| 155 | + case _Unawaited(coro=coro): |
| 156 | + shared = coro.spawn() |
| 157 | + self._status = _Tokio(shared) |
| 158 | + return shared.__await__() |
| 159 | + case _Tokio(shared=shared): |
| 160 | + return shared.__await__() |
| 161 | + case _Asyncio(_): |
| 162 | + raise ValueError( |
| 163 | + "already converted into asyncio future, but being awaited from the tokio loop." |
| 164 | + ) |
| 165 | + case _: |
| 166 | + raise ValueError( |
| 167 | + "already converted into a synchronous future, use 'get' to get the value." |
| 168 | + ) |
| 169 | + else: |
| 170 | + raise ValueError( |
| 171 | + "__await__ with no active event loop (either asyncio or tokio)" |
| 172 | + ) |
136 | 173 |
|
137 | 174 | # compatibility with old tensor engine Future objects
|
138 | 175 | # hopefully we do not need done(), add_callback because
|
|
0 commit comments