Skip to content

Commit 44f85b5

Browse files
committed
Move casting protocol into rust
Pull Request resolved: #851 This moves the logic for deferring sends into Rust to avoid slowing down actor sends. It increases send throughput by ~30% over the Python approach. It is just a little faster than what came before the Python approach because this also eliminates a layer of Python wrappers for dispatching sends. I suspect there is much more to be gained by removing more layers of Python between the user calling an endpoint and issuing a send. ghstack-source-id: 303184134 Differential Revision: [D80037834](https://our.internmc.facebook.com/intern/diff/D80037834/)
1 parent 364a22c commit 44f85b5

File tree

12 files changed

+516
-583
lines changed

12 files changed

+516
-583
lines changed

monarch_hyperactor/src/actor_mesh.rs

Lines changed: 338 additions & 191 deletions
Large diffs are not rendered by default.

monarch_hyperactor/src/proc_mesh.rs

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use hyperactor_mesh::shared_cell::SharedCellPool;
2929
use hyperactor_mesh::shared_cell::SharedCellRef;
3030
use monarch_types::PickledPyObject;
3131
use ndslice::Shape;
32+
use pyo3::IntoPyObjectExt;
3233
use pyo3::exceptions::PyException;
3334
use pyo3::exceptions::PyRuntimeError;
3435
use pyo3::prelude::*;
@@ -38,10 +39,13 @@ use tokio::sync::Mutex;
3839
use tokio::sync::mpsc;
3940

4041
use crate::actor_mesh::PythonActorMesh;
42+
use crate::actor_mesh::PythonActorMeshImpl;
4143
use crate::alloc::PyAlloc;
4244
use crate::mailbox::PyMailbox;
4345
use crate::pytokio::PyPythonTask;
46+
use crate::pytokio::PyShared;
4447
use crate::pytokio::PythonTask;
48+
use crate::runtime::get_tokio_runtime;
4549
use crate::shape::PyShape;
4650
use crate::supervision::SupervisionError;
4751
use crate::supervision::Unhealthy;
@@ -278,21 +282,63 @@ impl PyProcMesh {
278282
let pickled_type = PickledPyObject::pickle(actor.as_any())?;
279283
let proc_mesh = self.try_inner()?;
280284
let keepalive = self.keepalive.clone();
281-
PyPythonTask::new(async move {
285+
let meshimpl = async move {
282286
ensure_mesh_healthy(&unhealthy_event).await?;
287+
let mailbox = proc_mesh.client().clone();
288+
let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?;
289+
let actor_events = actor_mesh.with_mut(|a| a.events()).await.unwrap().unwrap();
290+
let im = PythonActorMeshImpl::new(
291+
actor_mesh,
292+
PyMailbox { inner: mailbox },
293+
keepalive,
294+
actor_events,
295+
);
296+
Ok(PythonActorMesh::from_impl(im))
297+
};
298+
PyPythonTask::new(meshimpl)
299+
}
283300

301+
#[staticmethod]
302+
fn spawn_async(
303+
proc_mesh: &mut PyShared,
304+
name: String,
305+
actor: Py<PyType>,
306+
emulated: bool,
307+
) -> PyResult<PyObject> {
308+
let task = proc_mesh.task()?.take_task()?;
309+
let meshimpl = async move {
310+
let proc_mesh = task.await?;
311+
let (proc_mesh, pickled_type, unhealthy_event, keepalive) =
312+
Python::with_gil(|py| -> PyResult<_> {
313+
let slf: Bound<PyProcMesh> = proc_mesh.extract(py)?;
314+
let slf = slf.borrow();
315+
let unhealthy_event = Arc::clone(&slf.unhealthy_event);
316+
let pickled_type = PickledPyObject::pickle(actor.bind(py).as_any())?;
317+
let proc_mesh = slf.try_inner()?;
318+
let keepalive = slf.keepalive.clone();
319+
Ok((proc_mesh, pickled_type, unhealthy_event, keepalive))
320+
})?;
321+
ensure_mesh_healthy(&unhealthy_event).await?;
284322
let mailbox = proc_mesh.client().clone();
285323
let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?;
286324
let actor_events = actor_mesh.with_mut(|a| a.events()).await.unwrap().unwrap();
287-
Ok(PythonActorMesh::monitored(
325+
Ok(PythonActorMeshImpl::new(
288326
actor_mesh,
289327
PyMailbox { inner: mailbox },
290328
keepalive,
291329
actor_events,
292330
))
293-
})
331+
};
332+
if emulated {
333+
// we give up on doing mesh spawn async for the emulated old version
334+
// it is too complicated to make both work.
335+
let r = get_tokio_runtime().block_on(meshimpl)?;
336+
Python::with_gil(|py| r.into_py_any(py))
337+
} else {
338+
let r = PythonActorMesh::new(meshimpl);
339+
Python::with_gil(|py| r.into_py_any(py))
340+
}
294341
}
295-
296342
// User can call this to monitor the proc mesh events. This will override
297343
// the default monitor that exits the client on process crash, so user can
298344
// handle the process crash in their own way.

monarch_hyperactor/src/pytokio.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
use std::error::Error;
1010
use std::future::Future;
11+
use std::ops::Deref;
1112
use std::pin::Pin;
1213

1314
use hyperactor::clock::Clock;
@@ -19,6 +20,7 @@ use pyo3::exceptions::PyStopIteration;
1920
use pyo3::exceptions::PyTimeoutError;
2021
use pyo3::exceptions::PyValueError;
2122
use pyo3::prelude::*;
23+
use pyo3::types::PyType;
2224
use tokio::sync::Mutex;
2325
use tokio::sync::watch;
2426

@@ -119,7 +121,7 @@ where
119121
}
120122

121123
impl PyPythonTask {
122-
fn take_task(
124+
pub(crate) fn take_task(
123125
&mut self,
124126
) -> PyResult<Pin<Box<dyn Future<Output = Result<Py<PyAny>, PyErr>> + Send + 'static>>> {
125127
self.inner
@@ -158,7 +160,7 @@ impl PyPythonTask {
158160
signal_safe_block_on(py, task)?
159161
}
160162

161-
fn spawn(&mut self) -> PyResult<PyShared> {
163+
pub(crate) fn spawn(&mut self) -> PyResult<PyShared> {
162164
let (tx, rx) = watch::channel(None);
163165
let task = self.take_task()?;
164166
get_tokio_runtime().spawn(async move {
@@ -266,6 +268,11 @@ impl PyPythonTask {
266268
result.map(|r| (r, index))
267269
})
268270
}
271+
272+
#[classmethod]
273+
fn __class_getitem__(cls: &Bound<'_, PyType>, _arg: PyObject) -> PyObject {
274+
cls.clone().unbind().into()
275+
}
269276
}
270277

271278
#[pyclass(
@@ -277,7 +284,7 @@ pub struct PyShared {
277284
}
278285
#[pymethods]
279286
impl PyShared {
280-
fn task(&mut self) -> PyResult<PyPythonTask> {
287+
pub(crate) fn task(&mut self) -> PyResult<PyPythonTask> {
281288
// watch channels start unchanged, and when a value is sent to them signal
282289
// the receivers `changed` future.
283290
// By cloning the rx before awaiting it,
@@ -306,6 +313,11 @@ impl PyShared {
306313
drop(slf);
307314
signal_safe_block_on(py, task)?
308315
}
316+
317+
#[classmethod]
318+
fn __class_getitem__(cls: &Bound<'_, PyType>, _arg: PyObject) -> PyObject {
319+
cls.clone().unbind().into()
320+
}
309321
}
310322

311323
#[pyfunction]

python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi

Lines changed: 18 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# pyre-strict
88

9-
from typing import AsyncIterator, final, NoReturn
9+
from typing import AsyncIterator, final, NoReturn, Optional, Protocol
1010

1111
from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage
1212
from monarch._rust_bindings.monarch_hyperactor.mailbox import (
@@ -15,127 +15,52 @@ from monarch._rust_bindings.monarch_hyperactor.mailbox import (
1515
PortReceiver,
1616
)
1717
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
18-
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
18+
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
1919
from monarch._rust_bindings.monarch_hyperactor.selection import Selection
2020
from monarch._rust_bindings.monarch_hyperactor.shape import Shape
2121
from typing_extensions import Self
2222

23-
@final
24-
class PythonActorMeshRef:
23+
class ActorMeshProtocol(Protocol):
2524
"""
26-
A reference to a remote actor mesh over which PythonMessages can be sent.
25+
Protocol defining the common interface for actor mesh, mesh ref and _ActorMeshRefImpl.
2726
"""
2827

2928
def cast(
30-
self, mailbox: Mailbox, selection: Selection, message: PythonMessage
31-
) -> None:
32-
"""Cast a message to the selected actors in the mesh."""
33-
...
34-
35-
def slice(self, **kwargs: int | slice[int | None, int | None, int | None]) -> Self:
36-
"""
37-
See PythonActorMeshRef.slice for documentation.
38-
"""
39-
...
40-
41-
def new_with_shape(self, shape: Shape) -> PythonActorMeshRef:
42-
"""
43-
Return a new mesh ref with the given sliced shape. If the provided shape
44-
is not a valid slice of the current shape, an exception will be raised.
45-
"""
46-
...
47-
48-
@property
49-
def shape(self) -> Shape:
50-
"""
51-
The Shape object that describes how the rank of an actor
52-
retrieved with get corresponds to coordinates in the
53-
mesh.
54-
"""
55-
...
29+
self,
30+
message: PythonMessage,
31+
selection: str,
32+
mailbox: Mailbox,
33+
) -> None: ...
34+
def new_with_shape(self, shape: Shape) -> Self: ...
35+
def supervision_event(self) -> "Optional[Shared[Exception]]": ...
36+
def stop(self) -> PythonTask[None]: ...
37+
def initialized(self) -> PythonTask[None]: ...
5638

5739
@final
58-
class PythonActorMesh:
59-
def bind(self) -> PythonActorMeshRef:
60-
"""
61-
Bind this actor mesh. The returned mesh ref can be used to reach the
62-
mesh remotely.
63-
"""
64-
...
65-
66-
def cast(
67-
self, mailbox: Mailbox, selection: Selection, message: PythonMessage
68-
) -> None:
69-
"""
70-
Cast a message to the selected actors in the mesh.
71-
"""
72-
...
73-
74-
def slice(
75-
self, **kwargs: int | slice[int | None, int | None, int | None]
76-
) -> PythonActorMeshRef:
77-
"""
78-
Slice the mesh into a new mesh ref with the given selection. The reason
79-
it returns a mesh ref, rather than the mesh object itself, is because
80-
sliced mesh is a view of the original mesh, and does not own the mesh's
81-
resources.
82-
83-
Arguments:
84-
- `kwargs`: argument name is the label, and argument value is how to
85-
slice the mesh along the dimension of that label.
86-
"""
87-
...
88-
89-
def new_with_shape(self, shape: Shape) -> PythonActorMeshRef:
90-
"""
91-
Return a new mesh ref with the given sliced shape. If the provided shape
92-
is not a valid slice of the current shape, an exception will be raised.
93-
"""
94-
...
40+
class PythonActorMesh(ActorMeshProtocol):
41+
pass
9542

43+
class PythonActorMeshImpl:
9644
def get_supervision_event(self) -> ActorSupervisionEvent | None:
9745
"""
9846
Returns supervision event if there is any.
9947
"""
10048
...
10149

102-
def supervision_event(self) -> PythonTask[Exception]:
103-
"""
104-
Completes with an exception when there is a supervision error.
105-
"""
106-
...
107-
10850
def get(self, rank: int) -> ActorId | None:
10951
"""
11052
Get the actor id for the actor at the given rank.
11153
"""
11254
...
11355

114-
@property
115-
def client(self) -> Mailbox:
116-
"""
117-
A client that can be used to communicate with individual
118-
actors in the mesh, and also to create ports that can be
119-
broadcast across the mesh)
120-
"""
121-
...
122-
123-
@property
124-
def shape(self) -> Shape:
125-
"""
126-
The Shape object that describes how the rank of an actor
127-
retrieved with get corresponds to coordinates in the
128-
mesh.
129-
"""
130-
...
131-
132-
async def stop(self) -> None:
56+
def stop(self) -> PythonTask[None]:
13357
"""
13458
Stop all actors that are part of this mesh.
13559
Using this mesh after stop() is called will raise an Exception.
13660
"""
13761
...
13862

63+
def supervision_event(self) -> "Optional[Shared[Exception]]": ...
13964
@property
14065
def stopped(self) -> bool:
14166
"""

python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,17 @@
66

77
# pyre-strict
88

9-
from typing import AsyncIterator, final, Type
9+
from typing import AsyncIterator, final, Literal, overload, Type
1010

1111
from monarch._rust_bindings.monarch_hyperactor.actor import Actor
12-
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
12+
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import (
13+
PythonActorMesh,
14+
PythonActorMeshImpl,
15+
)
1316

1417
from monarch._rust_bindings.monarch_hyperactor.alloc import Alloc
1518
from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
16-
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
19+
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
1720

1821
from monarch._rust_bindings.monarch_hyperactor.shape import Shape
1922

@@ -42,6 +45,10 @@ class ProcMesh:
4245
"""
4346
...
4447

48+
@staticmethod
49+
def spawn_async(
50+
proc_mesh: Shared[ProcMesh], name: str, actor: Type[Actor], emulated: bool
51+
) -> PythonActorMesh: ...
4552
async def monitor(self) -> ProcMeshMonitor:
4653
"""
4754
Returns a supervision monitor for this mesh.

0 commit comments

Comments
 (0)