Skip to content

Move casting protocol into rust #851

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
529 changes: 338 additions & 191 deletions monarch_hyperactor/src/actor_mesh.rs

Large diffs are not rendered by default.

54 changes: 50 additions & 4 deletions monarch_hyperactor/src/proc_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use hyperactor_mesh::shared_cell::SharedCellPool;
use hyperactor_mesh::shared_cell::SharedCellRef;
use monarch_types::PickledPyObject;
use ndslice::Shape;
use pyo3::IntoPyObjectExt;
use pyo3::exceptions::PyException;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
Expand All @@ -38,10 +39,13 @@ use tokio::sync::Mutex;
use tokio::sync::mpsc;

use crate::actor_mesh::PythonActorMesh;
use crate::actor_mesh::PythonActorMeshImpl;
use crate::alloc::PyAlloc;
use crate::mailbox::PyMailbox;
use crate::pytokio::PyPythonTask;
use crate::pytokio::PyShared;
use crate::pytokio::PythonTask;
use crate::runtime::get_tokio_runtime;
use crate::shape::PyShape;
use crate::supervision::SupervisionError;
use crate::supervision::Unhealthy;
Expand Down Expand Up @@ -278,21 +282,63 @@ impl PyProcMesh {
let pickled_type = PickledPyObject::pickle(actor.as_any())?;
let proc_mesh = self.try_inner()?;
let keepalive = self.keepalive.clone();
PyPythonTask::new(async move {
let meshimpl = async move {
ensure_mesh_healthy(&unhealthy_event).await?;
let mailbox = proc_mesh.client().clone();
let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?;
let actor_events = actor_mesh.with_mut(|a| a.events()).await.unwrap().unwrap();
let im = PythonActorMeshImpl::new(
actor_mesh,
PyMailbox { inner: mailbox },
keepalive,
actor_events,
);
Ok(PythonActorMesh::from_impl(im))
};
PyPythonTask::new(meshimpl)
}

#[staticmethod]
fn spawn_async(
proc_mesh: &mut PyShared,
name: String,
actor: Py<PyType>,
emulated: bool,
) -> PyResult<PyObject> {
let task = proc_mesh.task()?.take_task()?;
let meshimpl = async move {
let proc_mesh = task.await?;
let (proc_mesh, pickled_type, unhealthy_event, keepalive) =
Python::with_gil(|py| -> PyResult<_> {
let slf: Bound<PyProcMesh> = proc_mesh.extract(py)?;
let slf = slf.borrow();
let unhealthy_event = Arc::clone(&slf.unhealthy_event);
let pickled_type = PickledPyObject::pickle(actor.bind(py).as_any())?;
let proc_mesh = slf.try_inner()?;
let keepalive = slf.keepalive.clone();
Ok((proc_mesh, pickled_type, unhealthy_event, keepalive))
})?;
ensure_mesh_healthy(&unhealthy_event).await?;
let mailbox = proc_mesh.client().clone();
let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?;
let actor_events = actor_mesh.with_mut(|a| a.events()).await.unwrap().unwrap();
Ok(PythonActorMesh::monitored(
Ok(PythonActorMeshImpl::new(
actor_mesh,
PyMailbox { inner: mailbox },
keepalive,
actor_events,
))
})
};
if emulated {
// we give up on doing mesh spawn async for the emulated old version
// it is too complicated to make both work.
let r = get_tokio_runtime().block_on(meshimpl)?;
Python::with_gil(|py| r.into_py_any(py))
} else {
let r = PythonActorMesh::new(meshimpl);
Python::with_gil(|py| r.into_py_any(py))
}
}

// Users can call this to monitor proc mesh events. Doing so overrides
// the default monitor, which exits the client on process crash, so users can
// handle a process crash in their own way.
Expand Down
18 changes: 15 additions & 3 deletions monarch_hyperactor/src/pytokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

use std::error::Error;
use std::future::Future;
use std::ops::Deref;
use std::pin::Pin;

use hyperactor::clock::Clock;
Expand All @@ -19,6 +20,7 @@ use pyo3::exceptions::PyStopIteration;
use pyo3::exceptions::PyTimeoutError;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyType;
use tokio::sync::Mutex;
use tokio::sync::watch;

Expand Down Expand Up @@ -119,7 +121,7 @@ where
}

impl PyPythonTask {
fn take_task(
pub(crate) fn take_task(
&mut self,
) -> PyResult<Pin<Box<dyn Future<Output = Result<Py<PyAny>, PyErr>> + Send + 'static>>> {
self.inner
Expand Down Expand Up @@ -158,7 +160,7 @@ impl PyPythonTask {
signal_safe_block_on(py, task)?
}

fn spawn(&mut self) -> PyResult<PyShared> {
pub(crate) fn spawn(&mut self) -> PyResult<PyShared> {
let (tx, rx) = watch::channel(None);
let task = self.take_task()?;
get_tokio_runtime().spawn(async move {
Expand Down Expand Up @@ -266,6 +268,11 @@ impl PyPythonTask {
result.map(|r| (r, index))
})
}

#[classmethod]
fn __class_getitem__(cls: &Bound<'_, PyType>, _arg: PyObject) -> PyObject {
cls.clone().unbind().into()
}
}

#[pyclass(
Expand All @@ -277,7 +284,7 @@ pub struct PyShared {
}
#[pymethods]
impl PyShared {
fn task(&mut self) -> PyResult<PyPythonTask> {
pub(crate) fn task(&mut self) -> PyResult<PyPythonTask> {
// watch channels start unchanged, and when a value is sent to them signal
// the receivers `changed` future.
// By cloning the rx before awaiting it,
Expand Down Expand Up @@ -306,6 +313,11 @@ impl PyShared {
drop(slf);
signal_safe_block_on(py, task)?
}

#[classmethod]
fn __class_getitem__(cls: &Bound<'_, PyType>, _arg: PyObject) -> PyObject {
cls.clone().unbind().into()
}
}

#[pyfunction]
Expand Down
111 changes: 18 additions & 93 deletions python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# pyre-strict

from typing import AsyncIterator, final, NoReturn
from typing import AsyncIterator, final, NoReturn, Optional, Protocol

from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage
from monarch._rust_bindings.monarch_hyperactor.mailbox import (
Expand All @@ -15,127 +15,52 @@ from monarch._rust_bindings.monarch_hyperactor.mailbox import (
PortReceiver,
)
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
from monarch._rust_bindings.monarch_hyperactor.selection import Selection
from monarch._rust_bindings.monarch_hyperactor.shape import Shape
from typing_extensions import Self

@final
class PythonActorMeshRef:
class ActorMeshProtocol(Protocol):
"""
A reference to a remote actor mesh over which PythonMessages can be sent.
Protocol defining the common interface for actor mesh, mesh ref and _ActorMeshRefImpl.
"""

def cast(
self, mailbox: Mailbox, selection: Selection, message: PythonMessage
) -> None:
"""Cast a message to the selected actors in the mesh."""
...

def slice(self, **kwargs: int | slice[int | None, int | None, int | None]) -> Self:
"""
See PythonActorMeshRef.slice for documentation.
"""
...

def new_with_shape(self, shape: Shape) -> PythonActorMeshRef:
"""
Return a new mesh ref with the given sliced shape. If the provided shape
is not a valid slice of the current shape, an exception will be raised.
"""
...

@property
def shape(self) -> Shape:
"""
The Shape object that describes how the rank of an actor
retrieved with get corresponds to coordinates in the
mesh.
"""
...
self,
message: PythonMessage,
selection: str,
mailbox: Mailbox,
) -> None: ...
def new_with_shape(self, shape: Shape) -> Self: ...
def supervision_event(self) -> "Optional[Shared[Exception]]": ...
def stop(self) -> PythonTask[None]: ...
def initialized(self) -> PythonTask[None]: ...

@final
class PythonActorMesh:
def bind(self) -> PythonActorMeshRef:
"""
Bind this actor mesh. The returned mesh ref can be used to reach the
mesh remotely.
"""
...

def cast(
self, mailbox: Mailbox, selection: Selection, message: PythonMessage
) -> None:
"""
Cast a message to the selected actors in the mesh.
"""
...

def slice(
self, **kwargs: int | slice[int | None, int | None, int | None]
) -> PythonActorMeshRef:
"""
Slice the mesh into a new mesh ref with the given selection. The reason
it returns a mesh ref, rather than the mesh object itself, is because
sliced mesh is a view of the original mesh, and does not own the mesh's
resources.

Arguments:
- `kwargs`: argument name is the label, and argument value is how to
slice the mesh along the dimension of that label.
"""
...

def new_with_shape(self, shape: Shape) -> PythonActorMeshRef:
"""
Return a new mesh ref with the given sliced shape. If the provided shape
is not a valid slice of the current shape, an exception will be raised.
"""
...
class PythonActorMesh(ActorMeshProtocol):
pass

class PythonActorMeshImpl:
def get_supervision_event(self) -> ActorSupervisionEvent | None:
"""
Returns supervision event if there is any.
"""
...

def supervision_event(self) -> PythonTask[Exception]:
"""
Completes with an exception when there is a supervision error.
"""
...

def get(self, rank: int) -> ActorId | None:
"""
Get the actor id for the actor at the given rank.
"""
...

@property
def client(self) -> Mailbox:
"""
A client that can be used to communicate with individual
actors in the mesh, and also to create ports that can be
broadcast across the mesh)
"""
...

@property
def shape(self) -> Shape:
"""
The Shape object that describes how the rank of an actor
retrieved with get corresponds to coordinates in the
mesh.
"""
...

async def stop(self) -> None:
def stop(self) -> PythonTask[None]:
"""
Stop all actors that are part of this mesh.
Using this mesh after stop() is called will raise an Exception.
"""
...

def supervision_event(self) -> "Optional[Shared[Exception]]": ...
@property
def stopped(self) -> bool:
"""
Expand Down
13 changes: 10 additions & 3 deletions python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@

# pyre-strict

from typing import AsyncIterator, final, Type
from typing import AsyncIterator, final, Literal, overload, Type

from monarch._rust_bindings.monarch_hyperactor.actor import Actor
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import (
PythonActorMesh,
PythonActorMeshImpl,
)

from monarch._rust_bindings.monarch_hyperactor.alloc import Alloc
from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared

from monarch._rust_bindings.monarch_hyperactor.shape import Shape

Expand Down Expand Up @@ -42,6 +45,10 @@ class ProcMesh:
"""
...

# Spawn an actor mesh on a proc mesh that is still being produced by a
# shared task. The return type depends on `emulated`: the Rust binding blocks
# and returns the PythonActorMeshImpl when it is true, and returns a
# PythonActorMesh wrapping the pending spawn otherwise.
@staticmethod
@overload
def spawn_async(
    proc_mesh: Shared[ProcMesh],
    name: str,
    actor: Type[Actor],
    emulated: Literal[True],
) -> PythonActorMeshImpl: ...
@staticmethod
@overload
def spawn_async(
    proc_mesh: Shared[ProcMesh],
    name: str,
    actor: Type[Actor],
    emulated: Literal[False],
) -> PythonActorMesh: ...
@staticmethod
@overload
def spawn_async(
    proc_mesh: Shared[ProcMesh],
    name: str,
    actor: Type[Actor],
    emulated: bool,
) -> PythonActorMesh | PythonActorMeshImpl: ...
async def monitor(self) -> ProcMeshMonitor:
"""
Returns a supervision monitor for this mesh.
Expand Down
Loading
Loading