Skip to content

Commit d0d5203

Browse files
committed
Move casting protocol into rust
This moves the logic for deferring sends into Rust to avoid slowing down actor sends. Increases send throughput ~30% from the python approach. It is just a little faster than what came before the python approach because this also eliminates a layer of python wrappers for dispatching sends. I suspect there is tons more to to be gained removing more layers of python between the user calling an endpoint and issuing a send. Differential Revision: [D80037834](https://our.internmc.facebook.com/intern/diff/D80037834/) ghstack-source-id: 302731839 Pull Request resolved: #851
1 parent 67c170a commit d0d5203

File tree

12 files changed

+521
-582
lines changed

12 files changed

+521
-582
lines changed

monarch_hyperactor/src/actor_mesh.rs

Lines changed: 335 additions & 189 deletions
Large diffs are not rendered by default.

monarch_hyperactor/src/proc_mesh.rs

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use hyperactor_mesh::shared_cell::SharedCellPool;
2929
use hyperactor_mesh::shared_cell::SharedCellRef;
3030
use monarch_types::PickledPyObject;
3131
use ndslice::Shape;
32+
use pyo3::IntoPyObjectExt;
3233
use pyo3::exceptions::PyException;
3334
use pyo3::exceptions::PyRuntimeError;
3435
use pyo3::prelude::*;
@@ -38,10 +39,13 @@ use tokio::sync::Mutex;
3839
use tokio::sync::mpsc;
3940

4041
use crate::actor_mesh::PythonActorMesh;
42+
use crate::actor_mesh::PythonActorMeshImpl;
4143
use crate::alloc::PyAlloc;
4244
use crate::mailbox::PyMailbox;
4345
use crate::pytokio::PyPythonTask;
46+
use crate::pytokio::PyShared;
4447
use crate::pytokio::PythonTask;
48+
use crate::runtime::get_tokio_runtime;
4549
use crate::shape::PyShape;
4650
use crate::supervision::SupervisionError;
4751
use crate::supervision::Unhealthy;
@@ -273,26 +277,76 @@ impl PyProcMesh {
273277
&self,
274278
name: String,
275279
actor: &Bound<'py, PyType>,
276-
) -> PyResult<PyPythonTask> {
280+
emulated: bool,
281+
) -> PyResult<PyObject> {
277282
let unhealthy_event = Arc::clone(&self.unhealthy_event);
278283
let pickled_type = PickledPyObject::pickle(actor.as_any())?;
279284
let proc_mesh = self.try_inner()?;
280285
let keepalive = self.keepalive.clone();
281-
PyPythonTask::new(async move {
286+
let meshimpl = async move {
282287
ensure_mesh_healthy(&unhealthy_event).await?;
283-
284288
let mailbox = proc_mesh.client().clone();
285289
let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?;
286290
let actor_events = actor_mesh.with_mut(|a| a.events()).await.unwrap().unwrap();
287-
Ok(PythonActorMesh::monitored(
291+
Ok(PythonActorMeshImpl::new(
288292
actor_mesh,
289293
PyMailbox { inner: mailbox },
290294
keepalive,
291295
actor_events,
292296
))
293-
})
297+
};
298+
if emulated {
299+
// we give up on doing mesh spawn async for the emulated old version
300+
// it is too complicated to make both work.
301+
let r = get_tokio_runtime().block_on(meshimpl)?;
302+
Python::with_gil(|py| r.into_py_any(py))
303+
} else {
304+
let r = PythonActorMesh::new(meshimpl);
305+
Python::with_gil(|py| r.into_py_any(py))
306+
}
294307
}
295308

309+
#[staticmethod]
310+
fn spawn_async(
311+
proc_mesh: &mut PyShared,
312+
name: String,
313+
actor: Py<PyType>,
314+
emulated: bool,
315+
) -> PyResult<PyObject> {
316+
let task = proc_mesh.task()?.take_task()?;
317+
let meshimpl = async move {
318+
let proc_mesh = task.await?;
319+
let (proc_mesh, pickled_type, unhealthy_event, keepalive) =
320+
Python::with_gil(|py| -> PyResult<_> {
321+
let slf: Bound<PyProcMesh> = proc_mesh.extract(py)?;
322+
let slf = slf.borrow();
323+
let unhealthy_event = Arc::clone(&slf.unhealthy_event);
324+
let pickled_type = PickledPyObject::pickle(actor.bind(py).as_any())?;
325+
let proc_mesh = slf.try_inner()?;
326+
let keepalive = slf.keepalive.clone();
327+
Ok((proc_mesh, pickled_type, unhealthy_event, keepalive))
328+
})?;
329+
ensure_mesh_healthy(&unhealthy_event).await?;
330+
let mailbox = proc_mesh.client().clone();
331+
let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?;
332+
let actor_events = actor_mesh.with_mut(|a| a.events()).await.unwrap().unwrap();
333+
Ok(PythonActorMeshImpl::new(
334+
actor_mesh,
335+
PyMailbox { inner: mailbox },
336+
keepalive,
337+
actor_events,
338+
))
339+
};
340+
if emulated {
341+
// we give up on doing mesh spawn async for the emulated old version
342+
// it is too complicated to make both work.
343+
let r = get_tokio_runtime().block_on(meshimpl)?;
344+
Python::with_gil(|py| r.into_py_any(py))
345+
} else {
346+
let r = PythonActorMesh::new(meshimpl);
347+
Python::with_gil(|py| r.into_py_any(py))
348+
}
349+
}
296350
// User can call this to monitor the proc mesh events. This will override
297351
// the default monitor that exits the client on process crash, so user can
298352
// handle the process crash in their own way.

monarch_hyperactor/src/pytokio.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
use std::error::Error;
1010
use std::future::Future;
11+
use std::ops::Deref;
1112
use std::pin::Pin;
1213

1314
use hyperactor::clock::Clock;
@@ -19,6 +20,7 @@ use pyo3::exceptions::PyStopIteration;
1920
use pyo3::exceptions::PyTimeoutError;
2021
use pyo3::exceptions::PyValueError;
2122
use pyo3::prelude::*;
23+
use pyo3::types::PyType;
2224
use tokio::sync::Mutex;
2325
use tokio::sync::watch;
2426

@@ -119,7 +121,7 @@ where
119121
}
120122

121123
impl PyPythonTask {
122-
fn take_task(
124+
pub(crate) fn take_task(
123125
&mut self,
124126
) -> PyResult<Pin<Box<dyn Future<Output = Result<Py<PyAny>, PyErr>> + Send + 'static>>> {
125127
self.inner
@@ -158,7 +160,7 @@ impl PyPythonTask {
158160
signal_safe_block_on(py, task)?
159161
}
160162

161-
fn spawn(&mut self) -> PyResult<PyShared> {
163+
pub(crate) fn spawn(&mut self) -> PyResult<PyShared> {
162164
let (tx, rx) = watch::channel(None);
163165
let task = self.take_task()?;
164166
get_tokio_runtime().spawn(async move {
@@ -266,6 +268,11 @@ impl PyPythonTask {
266268
result.map(|r| (r, index))
267269
})
268270
}
271+
272+
#[classmethod]
273+
fn __class_getitem__(cls: &Bound<'_, PyType>, _arg: PyObject) -> PyObject {
274+
cls.clone().unbind().into()
275+
}
269276
}
270277

271278
#[pyclass(
@@ -277,7 +284,7 @@ pub struct PyShared {
277284
}
278285
#[pymethods]
279286
impl PyShared {
280-
fn task(&mut self) -> PyResult<PyPythonTask> {
287+
pub(crate) fn task(&mut self) -> PyResult<PyPythonTask> {
281288
// watch channels start unchanged, and when a value is sent to them signal
282289
// the receivers `changed` future.
283290
// By cloning the rx before awaiting it,
@@ -306,6 +313,11 @@ impl PyShared {
306313
drop(slf);
307314
signal_safe_block_on(py, task)?
308315
}
316+
317+
#[classmethod]
318+
fn __class_getitem__(cls: &Bound<'_, PyType>, _arg: PyObject) -> PyObject {
319+
cls.clone().unbind().into()
320+
}
309321
}
310322

311323
#[pyfunction]

python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi

Lines changed: 18 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# pyre-strict
88

9-
from typing import AsyncIterator, final, NoReturn
9+
from typing import AsyncIterator, final, NoReturn, Optional, Protocol
1010

1111
from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage
1212
from monarch._rust_bindings.monarch_hyperactor.mailbox import (
@@ -15,127 +15,52 @@ from monarch._rust_bindings.monarch_hyperactor.mailbox import (
1515
PortReceiver,
1616
)
1717
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
18-
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
18+
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
1919
from monarch._rust_bindings.monarch_hyperactor.selection import Selection
2020
from monarch._rust_bindings.monarch_hyperactor.shape import Shape
2121
from typing_extensions import Self
2222

23-
@final
24-
class PythonActorMeshRef:
23+
class ActorMeshProtocol(Protocol):
2524
"""
26-
A reference to a remote actor mesh over which PythonMessages can be sent.
25+
Protocol defining the common interface for actor mesh, mesh ref and _ActorMeshRefImpl.
2726
"""
2827

2928
def cast(
30-
self, mailbox: Mailbox, selection: Selection, message: PythonMessage
31-
) -> None:
32-
"""Cast a message to the selected actors in the mesh."""
33-
...
34-
35-
def slice(self, **kwargs: int | slice[int | None, int | None, int | None]) -> Self:
36-
"""
37-
See PythonActorMeshRef.slice for documentation.
38-
"""
39-
...
40-
41-
def new_with_shape(self, shape: Shape) -> PythonActorMeshRef:
42-
"""
43-
Return a new mesh ref with the given sliced shape. If the provided shape
44-
is not a valid slice of the current shape, an exception will be raised.
45-
"""
46-
...
47-
48-
@property
49-
def shape(self) -> Shape:
50-
"""
51-
The Shape object that describes how the rank of an actor
52-
retrieved with get corresponds to coordinates in the
53-
mesh.
54-
"""
55-
...
29+
self,
30+
message: PythonMessage,
31+
selection: str,
32+
mailbox: Mailbox,
33+
) -> None: ...
34+
def new_with_shape(self, shape: Shape) -> Self: ...
35+
def supervision_event(self) -> "Optional[Shared[Exception]]": ...
36+
def stop(self) -> PythonTask[None]: ...
37+
def initialized(self) -> PythonTask[None]: ...
5638

5739
@final
58-
class PythonActorMesh:
59-
def bind(self) -> PythonActorMeshRef:
60-
"""
61-
Bind this actor mesh. The returned mesh ref can be used to reach the
62-
mesh remotely.
63-
"""
64-
...
65-
66-
def cast(
67-
self, mailbox: Mailbox, selection: Selection, message: PythonMessage
68-
) -> None:
69-
"""
70-
Cast a message to the selected actors in the mesh.
71-
"""
72-
...
73-
74-
def slice(
75-
self, **kwargs: int | slice[int | None, int | None, int | None]
76-
) -> PythonActorMeshRef:
77-
"""
78-
Slice the mesh into a new mesh ref with the given selection. The reason
79-
it returns a mesh ref, rather than the mesh object itself, is because
80-
sliced mesh is a view of the original mesh, and does not own the mesh's
81-
resources.
82-
83-
Arguments:
84-
- `kwargs`: argument name is the label, and argument value is how to
85-
slice the mesh along the dimension of that label.
86-
"""
87-
...
88-
89-
def new_with_shape(self, shape: Shape) -> PythonActorMeshRef:
90-
"""
91-
Return a new mesh ref with the given sliced shape. If the provided shape
92-
is not a valid slice of the current shape, an exception will be raised.
93-
"""
94-
...
40+
class PythonActorMesh(ActorMeshProtocol):
41+
pass
9542

43+
class PythonActorMeshImpl:
9644
def get_supervision_event(self) -> ActorSupervisionEvent | None:
9745
"""
9846
Returns supervision event if there is any.
9947
"""
10048
...
10149

102-
def supervision_event(self) -> PythonTask[Exception]:
103-
"""
104-
Completes with an exception when there is a supervision error.
105-
"""
106-
...
107-
10850
def get(self, rank: int) -> ActorId | None:
10951
"""
11052
Get the actor id for the actor at the given rank.
11153
"""
11254
...
11355

114-
@property
115-
def client(self) -> Mailbox:
116-
"""
117-
A client that can be used to communicate with individual
118-
actors in the mesh, and also to create ports that can be
119-
broadcast across the mesh)
120-
"""
121-
...
122-
123-
@property
124-
def shape(self) -> Shape:
125-
"""
126-
The Shape object that describes how the rank of an actor
127-
retrieved with get corresponds to coordinates in the
128-
mesh.
129-
"""
130-
...
131-
132-
async def stop(self) -> None:
56+
def stop(self) -> PythonTask[None]:
13357
"""
13458
Stop all actors that are part of this mesh.
13559
Using this mesh after stop() is called will raise an Exception.
13660
"""
13761
...
13862

63+
def supervision_event(self) -> "Optional[Shared[Exception]]": ...
13964
@property
14065
def stopped(self) -> bool:
14166
"""

python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,17 @@
66

77
# pyre-strict
88

9-
from typing import AsyncIterator, final, Type
9+
from typing import AsyncIterator, final, Literal, overload, Type
1010

1111
from monarch._rust_bindings.monarch_hyperactor.actor import Actor
12-
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
12+
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import (
13+
PythonActorMesh,
14+
PythonActorMeshImpl,
15+
)
1316

1417
from monarch._rust_bindings.monarch_hyperactor.alloc import Alloc
1518
from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
16-
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
19+
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
1720

1821
from monarch._rust_bindings.monarch_hyperactor.shape import Shape
1922

@@ -42,6 +45,10 @@ class ProcMesh:
4245
"""
4346
...
4447

48+
@staticmethod
49+
def spawn_async(
50+
proc_mesh: Shared[ProcMesh], name: str, actor: Type[Actor], emulated: bool
51+
) -> PythonActorMesh: ...
4552
async def monitor(self) -> ProcMeshMonitor:
4653
"""
4754
Returns a supervision monitor for this mesh.

0 commit comments

Comments
 (0)