Skip to content

Commit 694951b

Browse files
zdevitofacebook-github-bot
authored andcommitted
Polymorphic Future await? (#757)
Summary: Pull Request resolved: #757 Previously if on the tokio event loop we forced our internal uses to directly await the PythonTask. This leads to a few places where we end up with a Future but want to await it from a coro also on the tokio event loop. We have the option of making `__await__` work in either case: on a asyncio loop, create a python future, on a tokio loop spawn/await. If we want consumers to create Future objects themselves we probably want this. If we want tokio event loop things to be an implementation detail of monarch, we probably do not want this. ghstack-source-id: 300750788 Reviewed By: allenwang28 Differential Revision: D79596925 fbshipit-source-id: 6ccde7e9762fce85b83e169a6157be1164e7ecff
1 parent 6e3851c commit 694951b

File tree

4 files changed

+79
-25
lines changed

4 files changed

+79
-25
lines changed

monarch_hyperactor/src/pytokio.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,20 @@ impl PyShared {
282282
}
283283
}
284284

285+
#[pyfunction]
286+
fn is_tokio_thread() -> bool {
287+
tokio::runtime::Handle::try_current().is_ok()
288+
}
289+
285290
pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
286291
hyperactor_mod.add_class::<PyPythonTask>()?;
287292
hyperactor_mod.add_class::<PyShared>()?;
293+
let f = wrap_pyfunction!(is_tokio_thread, hyperactor_mod)?;
294+
f.setattr(
295+
"__module__",
296+
"monarch._rust_bindings.monarch_hyperactor.pytokio",
297+
)?;
298+
hyperactor_mod.add_function(f)?;
299+
288300
Ok(())
289301
}

python/monarch/_rust_bindings/monarch_hyperactor/pytokio.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,9 @@ class Shared(Generic[T]):
7373
Create a one-use Task that awaits on this if you want to use other PythonTask apis like with_timeout.
7474
"""
7575
...
76+
77+
def is_tokio_thread() -> bool:
78+
"""
79+
Returns true if the current thread is a tokio worker thread (and block_on will fail).
80+
"""
81+
...

python/monarch/_src/actor/future.py

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
TypeVar,
2121
)
2222

23-
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
23+
from monarch._rust_bindings.monarch_hyperactor.pytokio import (
24+
is_tokio_thread,
25+
PythonTask,
26+
Shared,
27+
)
2428

2529
from typing_extensions import deprecated, Self
2630

@@ -79,7 +83,11 @@ class _Asyncio(NamedTuple):
7983
fut: asyncio.Future
8084

8185

82-
_Status = _Unawaited | _Complete | _Exception | _Asyncio
86+
class _Tokio(NamedTuple):
87+
shared: Shared
88+
89+
90+
_Status = _Unawaited | _Complete | _Exception | _Asyncio | _Tokio
8391

8492

8593
class Future(Generic[R]):
@@ -108,31 +116,60 @@ def get(self, timeout: Optional[float] = None) -> R:
108116
return cast("R", value)
109117
case _Exception(exe=exe):
110118
raise exe
119+
case _Tokio(_):
120+
raise ValueError(
121+
"already converted into a pytokio.Shared object, use 'await' from a PythonTask coroutine to get the value."
122+
)
111123
case _:
112124
raise RuntimeError("unknown status")
113125

114126
def __await__(self) -> Generator[Any, Any, R]:
115-
match self._status:
116-
case _Unawaited(coro=coro):
117-
loop = asyncio.get_running_loop()
118-
fut = loop.create_future()
119-
self._status = _Asyncio(fut)
120-
121-
async def mark_complete():
122-
try:
123-
func, value = fut.set_result, await coro
124-
except Exception as e:
125-
func, value = fut.set_exception, e
126-
loop.call_soon_threadsafe(func, value)
127-
128-
PythonTask.from_coroutine(mark_complete()).spawn()
129-
return fut.__await__()
130-
case _Asyncio(fut=fut):
131-
return fut.__await__()
132-
case _:
133-
raise ValueError(
134-
"already converted into a synchronous future, use 'get' to get the value."
135-
)
127+
if asyncio._get_running_loop() is not None:
128+
match self._status:
129+
case _Unawaited(coro=coro):
130+
loop = asyncio.get_running_loop()
131+
fut = loop.create_future()
132+
self._status = _Asyncio(fut)
133+
134+
async def mark_complete():
135+
try:
136+
func, value = fut.set_result, await coro
137+
except Exception as e:
138+
func, value = fut.set_exception, e
139+
loop.call_soon_threadsafe(func, value)
140+
141+
PythonTask.from_coroutine(mark_complete()).spawn()
142+
return fut.__await__()
143+
case _Asyncio(fut=fut):
144+
return fut.__await__()
145+
case _Tokio(_):
146+
raise ValueError(
147+
"already converted into a tokio future, but being awaited from the asyncio loop."
148+
)
149+
case _:
150+
raise ValueError(
151+
"already converted into a synchronous future, use 'get' to get the value."
152+
)
153+
elif is_tokio_thread():
154+
match self._status:
155+
case _Unawaited(coro=coro):
156+
shared = coro.spawn()
157+
self._status = _Tokio(shared)
158+
return shared.__await__()
159+
case _Tokio(shared=shared):
160+
return shared.__await__()
161+
case _Asyncio(_):
162+
raise ValueError(
163+
"already converted into asyncio future, but being awaited from the tokio loop."
164+
)
165+
case _:
166+
raise ValueError(
167+
"already converted into a synchronous future, use 'get' to get the value."
168+
)
169+
else:
170+
raise ValueError(
171+
"__await__ with no active event loop (either asyncio or tokio)"
172+
)
136173

137174
# compatibility with old tensor engine Future objects
138175
# hopefully we do not need done(), add_callback because

python/monarch/_src/actor/proc_mesh.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,7 @@ async def _init_manager_actors_coro(
197197
setup_actor = await self._spawn_nonblocking_on(
198198
proc_mesh, "setup", SetupActor, setup
199199
)
200-
# pyre-ignore
201-
await setup_actor.setup.call()._status.coro
200+
await setup_actor.setup.call()
202201

203202
return proc_mesh
204203

0 commit comments

Comments
 (0)