diff --git a/monarch_hyperactor/src/actor_mesh.rs b/monarch_hyperactor/src/actor_mesh.rs index 41bc82061..a18aeec7a 100644 --- a/monarch_hyperactor/src/actor_mesh.rs +++ b/monarch_hyperactor/src/actor_mesh.rs @@ -6,9 +6,15 @@ * LICENSE file in the root directory of this source tree. */ +use std::error::Error; +use std::future::Future; +use std::pin::Pin; use std::sync::Arc; +use futures::future::FutureExt; +use futures::future::Shared; use hyperactor::ActorRef; +use hyperactor::Mailbox; use hyperactor::id; use hyperactor::supervision::ActorSupervisionEvent; use hyperactor_mesh::Mesh; @@ -16,35 +22,149 @@ use hyperactor_mesh::RootActorMesh; use hyperactor_mesh::actor_mesh::ActorMesh; use hyperactor_mesh::actor_mesh::ActorSupervisionEvents; use hyperactor_mesh::reference::ActorMeshRef; +use hyperactor_mesh::sel; use hyperactor_mesh::shared_cell::SharedCell; use hyperactor_mesh::shared_cell::SharedCellRef; +use ndslice::Selection; +use pyo3::IntoPyObjectExt; use pyo3::exceptions::PyException; use pyo3::exceptions::PyNotImplementedError; use pyo3::exceptions::PyRuntimeError; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyBytes; -use pyo3::types::PyDict; -use pyo3::types::PySlice; use serde::Deserialize; use serde::Serialize; +use tokio::sync::mpsc::UnboundedSender; +use tokio::sync::mpsc::unbounded_channel; use crate::actor::PythonActor; use crate::actor::PythonMessage; +use crate::actor::PythonMessageKind; +use crate::mailbox::EitherPortRef; use crate::mailbox::PyMailbox; use crate::proc::PyActorId; use crate::proc_mesh::Keepalive; use crate::pytokio::PyPythonTask; -use crate::selection::PySelection; +use crate::pytokio::PyShared; +use crate::runtime::get_tokio_runtime; use crate::shape::PyShape; use crate::supervision::SupervisionError; use crate::supervision::Unhealthy; +/// Trait defining the common interface for actor mesh, mesh ref and actor mesh implementations. +/// This corresponds to the Python ActorMeshProtocol ABC. +trait ActorMeshProtocol: Send + Sync { + /// Cast a message to actors selected by the given selection using the specified mailbox. + fn cast(&self, message: PythonMessage, selection: Selection, mailbox: Mailbox) -> PyResult<()>; + + /// Create a new actor mesh with the specified shape. + fn new_with_shape(&self, shape: PyShape) -> PyResult>; + + fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)>; + + /// Get supervision events for this actor mesh. + /// Returns None by default for implementations that don't support supervision events. + fn supervision_event(&self) -> PyResult> { + Ok(None) + } + + /// Stop the actor mesh asynchronously. + /// Default implementation raises NotImplementedError for types that don't support stopping. + fn stop(&self) -> PyResult { + Err(PyNotImplementedError::new_err(format!( + "stop() is not supported for {}", + std::any::type_name::() + ))) + } + + /// Initialize the actor mesh asynchronously. + /// Default implementation returns None (no initialization needed). + fn initialized<'py>(&self) -> PyResult { + PyPythonTask::new(async { Ok(None::<()>) }) + } +} + +/// This just forwards to the rust trait that can implement these bindings #[pyclass( name = "PythonActorMesh", module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh" )] -pub struct PythonActorMesh { +pub(crate) struct PythonActorMesh { + inner: Box, +} + +impl PythonActorMesh { + pub(crate) fn new(f: F) -> Self + where + F: Future> + Send + 'static, + { + PythonActorMesh { + inner: Box::new(AsyncActorMesh::new_queue(async { + let b: Box = Box::new(f.await?); + Ok(b) + })), + } + } + pub(crate) fn from_impl(im: PythonActorMeshImpl) -> Self { + PythonActorMesh { + inner: Box::new(im), + } + } +} + +fn to_hy_sel(selection: &str) -> PyResult { + match selection { + "choose" => Ok(sel!(?)), + "all" => Ok(sel!(*)), + _ => Err(PyErr::new::(format!( + "Invalid selection: {}", + selection + ))), + } +} + +#[pymethods] +impl PythonActorMesh { + fn cast(&self, message: &PythonMessage, selection: &str, mailbox: &PyMailbox) -> PyResult<()> { + let sel = to_hy_sel(selection)?; + self.inner.cast(message.clone(), sel, mailbox.inner.clone()) + } + + fn new_with_shape(&self, shape: PyShape) -> PyResult { + let inner = self.inner.new_with_shape(shape)?; + Ok(PythonActorMesh { inner }) + } + + fn supervision_event(&self) -> PyResult> { + self.inner.supervision_event() + } + + fn stop(&self) -> PyResult { + self.inner.stop() + } + + fn initialized(&self) -> PyResult { + self.inner.initialized() + } + + fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> { + self.inner.__reduce__(py) + } + + #[staticmethod] + fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult { + let r: PyResult = bincode::deserialize(bytes.as_bytes()) + .map_err(|e| PyErr::new::(e.to_string())); + r.map(|r| PythonActorMesh { inner: Box::new(r) }) + } +} + +#[pyclass( + name = "PythonActorMeshImpl", + module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh" +)] +pub(crate) struct PythonActorMeshImpl { inner: SharedCell>, client: PyMailbox, _keepalive: Keepalive, @@ -53,10 +173,10 @@ pub struct PythonActorMesh { monitor: tokio::task::JoinHandle<()>, } -impl PythonActorMesh { +impl PythonActorMeshImpl { /// Create a new [`PythonActorMesh`] with a monitor that will observe supervision /// errors for this mesh, and update its state properly. - pub(crate) fn monitored( + pub(crate) fn new( inner: SharedCell>, client: PyMailbox, keepalive: Keepalive, @@ -65,12 +185,12 @@ impl PythonActorMesh { let (user_monitor_sender, _) = tokio::sync::broadcast::channel::>(1); let unhealthy_event = Arc::new(std::sync::Mutex::new(Unhealthy::SoFarSoGood)); - let monitor = tokio::spawn(Self::actor_mesh_monitor( + let monitor = tokio::spawn(PythonActorMeshImpl::actor_mesh_monitor( events, user_monitor_sender.clone(), Arc::clone(&unhealthy_event), )); - Self { + PythonActorMeshImpl { inner, client, _keepalive: keepalive, @@ -79,7 +199,6 @@ impl PythonActorMesh { monitor, } } - /// Monitor of the actor mesh. It processes supervision errors for the mesh, and keeps mesh /// health state up to date. async fn actor_mesh_monitor( @@ -115,23 +234,14 @@ impl PythonActorMesh { }) } - fn pickling_err(&self) -> PyErr { - PyErr::new::( - "PythonActorMesh cannot be pickled. If applicable, use bind() \ - to get a PythonActorMeshRef, and use that instead." - .to_string(), - ) + fn bind(&self) -> PyResult { + let mesh = self.try_inner()?; + Ok(PythonActorMeshRef { inner: mesh.bind() }) } } -#[pymethods] -impl PythonActorMesh { - fn cast( - &self, - mailbox: &PyMailbox, - selection: &PySelection, - message: &PythonMessage, - ) -> PyResult<()> { +impl ActorMeshProtocol for PythonActorMeshImpl { + fn cast(&self, message: PythonMessage, selection: Selection, mailbox: Mailbox) -> PyResult<()> { let unhealthy_event = self .unhealthy_event .lock() @@ -153,44 +263,11 @@ impl PythonActorMesh { } self.try_inner()? - .cast(&mailbox.inner, selection.inner().clone(), message.clone()) + .cast(&mailbox, selection, message.clone()) .map_err(|err| PyException::new_err(err.to_string()))?; Ok(()) } - - fn bind(&self) -> PyResult { - let mesh = self.try_inner()?; - Ok(PythonActorMeshRef { inner: mesh.bind() }) - } - - fn get_supervision_event(&self) -> PyResult> { - let unhealthy_event = self - .unhealthy_event - .lock() - .expect("failed to acquire unhealthy_event lock"); - - match &*unhealthy_event { - Unhealthy::SoFarSoGood => Ok(None), - Unhealthy::StreamClosed => Ok(Some(PyActorSupervisionEvent { - // Dummy actor as place holder to indicate the whole mesh is stopped - // TODO(albertli): remove this when pushing all supervision logic to rust. - actor_id: id!(default[0].actor[0]).into(), - actor_status: "actor mesh is stopped due to proc mesh shutdown".to_string(), - })), - Unhealthy::Crashed(event) => Ok(Some(PyActorSupervisionEvent::from(event.clone()))), - } - } - - // Consider defining a "PythonActorRef", which carries specifically - // a reference to python message actors. - fn get(&self, rank: usize) -> PyResult> { - Ok(self - .try_inner()? - .get(rank) - .map(ActorRef::into_actor_id) - .map(PyActorId::from)) - } - fn supervision_event(&self) -> PyResult { + fn supervision_event(&self) -> PyResult> { let mut receiver = self.user_monitor_sender.subscribe(); PyPythonTask::new(async move { let event = receiver.recv().await; @@ -208,36 +285,12 @@ impl PythonActorMesh { event.actor_id, event.actor_status ))) }) + .map(|mut x| x.spawn().map(Some))? } - - #[pyo3(signature = (**kwargs))] - fn slice(&self, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult { - self.bind()?.slice(kwargs) - } - - fn new_with_shape(&self, shape: PyShape) -> PyResult { + fn new_with_shape(&self, shape: PyShape) -> PyResult> { self.bind()?.new_with_shape(shape) } - #[getter] - pub fn client(&self) -> PyMailbox { - self.client.clone() - } - - #[getter] - fn shape(&self) -> PyResult { - Ok(PyShape::from(self.try_inner()?.shape().clone())) - } - - // Override the pickling methods to provide a meaningful error message. - fn __reduce__(&self) -> PyResult<()> { - Err(self.pickling_err()) - } - - fn __reduce_ex__(&self, _proto: u8) -> PyResult<()> { - Err(self.pickling_err()) - } - fn stop<'py>(&self) -> PyResult { let actor_mesh = self.inner.clone(); PyPythonTask::new(async move { @@ -251,6 +304,46 @@ impl PythonActorMesh { Ok(()) }) } + fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> { + self.bind()?.__reduce__(py) + } +} + +#[pymethods] +impl PythonActorMeshImpl { + fn get_supervision_event(&self) -> PyResult> { + let unhealthy_event = self + .unhealthy_event + .lock() + .expect("failed to acquire unhealthy_event lock"); + + match &*unhealthy_event { + Unhealthy::SoFarSoGood => Ok(None), + Unhealthy::StreamClosed => Ok(Some(PyActorSupervisionEvent { + // Dummy actor as place holder to indicate the whole mesh is stopped + // TODO(albertli): remove this when pushing all supervision logic to rust. + actor_id: id!(default[0].actor[0]).into(), + actor_status: "actor mesh is stopped due to proc mesh shutdown".to_string(), + })), + Unhealthy::Crashed(event) => Ok(Some(PyActorSupervisionEvent::from(event.clone()))), + } + } + + fn supervision_event(&self) -> PyResult> { + ActorMeshProtocol::supervision_event(self) + } + fn stop(&self) -> PyResult { + ActorMeshProtocol::stop(self) + } + // Consider defining a "PythonActorRef", which carries specifically + // a reference to python message actors. + fn get(&self, rank: usize) -> PyResult> { + Ok(self + .try_inner()? + .get(rank) + .map(ActorRef::into_actor_id) + .map(PyActorId::from)) + } #[getter] fn stopped(&self) -> PyResult { @@ -258,142 +351,196 @@ impl PythonActorMesh { } } -#[pyclass( - frozen, - name = "PythonActorMeshRef", - module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh" -)] #[derive(Debug, Serialize, Deserialize)] -pub(super) struct PythonActorMeshRef { +struct PythonActorMeshRef { inner: ActorMeshRef, } -#[pymethods] -impl PythonActorMeshRef { - fn cast( - &self, - client: &PyMailbox, - selection: &PySelection, - message: &PythonMessage, - ) -> PyResult<()> { +impl ActorMeshProtocol for PythonActorMeshRef { + fn cast(&self, message: PythonMessage, selection: Selection, client: Mailbox) -> PyResult<()> { self.inner - .cast(&client.inner, selection.inner().clone(), message.clone()) + .cast(&client, selection, message.clone()) .map_err(|err| PyException::new_err(err.to_string()))?; Ok(()) } - #[pyo3(signature = (**kwargs))] - fn slice(&self, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult { - // When the input type is `int`, convert it into `ndslice::Range`. - fn convert_int(index: isize) -> PyResult { - if index < 0 { - return Err(PyException::new_err(format!( - "does not support negative index in selection: {}", - index - ))); - } - Ok(ndslice::Range::from(index as usize)) - } - - // When the input type is `slice`, convert it into `ndslice::Range`. - fn convert_py_slice<'py>(s: &Bound<'py, PySlice>) -> PyResult { - fn get_attr<'py>(s: &Bound<'py, PySlice>, attr: &str) -> PyResult> { - let v = s.getattr(attr)?.extract::>()?; - if v.is_some() && v.unwrap() < 0 { - return Err(PyException::new_err(format!( - "does not support negative {} in slice: {}", - attr, - v.unwrap(), - ))); - } - Ok(v) - } + fn new_with_shape(&self, shape: PyShape) -> PyResult> { + let sliced = self + .inner + .new_with_shape(shape.get_inner().clone()) + .map_err(|e| PyErr::new::(e.to_string()))?; + Ok(Box::new(Self { inner: sliced })) + } - let start = get_attr(s, "start")?.unwrap_or(0); - let stop: Option = get_attr(s, "stop")?; - let step = get_attr(s, "step")?.unwrap_or(1); - Ok(ndslice::Range( - start as usize, - stop.map(|s| s as usize), - step as usize, - )) - } + fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> { + let bytes = + bincode::serialize(self).map_err(|e| PyErr::new::(e.to_string()))?; + let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap(); + let module = py + .import("monarch._rust_bindings.monarch_hyperactor.actor_mesh") + .unwrap(); + let from_bytes = module + .getattr("PythonActorMesh") + .unwrap() + .getattr("from_bytes") + .unwrap(); + Ok((from_bytes, py_bytes)) + } +} - if kwargs.is_none() || kwargs.unwrap().is_empty() { - return Err(PyException::new_err("selection cannot be empty")); +impl Drop for PythonActorMeshImpl { + fn drop(&mut self) { + if let Ok(mesh) = self.inner.borrow() { + tracing::debug!("Dropping PythonActorMesh: {}", mesh.name()); + } else { + tracing::debug!( + "Dropping stopped PythonActorMesh. The underlying mesh is already stopped." + ); } + self.monitor.abort(); + } +} +struct ClonePyErr { + inner: PyErr, +} - let mut sliced = self.inner.clone(); +impl From for PyErr { + fn from(value: ClonePyErr) -> PyErr { + value.inner + } +} +impl From for ClonePyErr { + fn from(inner: PyErr) -> ClonePyErr { + ClonePyErr { inner } + } +} - for entry in kwargs.unwrap().items() { - let label = entry.get_item(0)?.str()?; - let label_str = label.to_str()?; +impl Clone for ClonePyErr { + fn clone(&self) -> Self { + Python::with_gil(|py| self.inner.clone_ref(py).into()) + } +} - let value = entry.get_item(1)?; +type ActorMeshResult = Result, ClonePyErr>; +struct AsyncActorMesh { + mesh: Shared + Send>>>, + queue: UnboundedSender + Send + 'static>>>, + supervised: bool, +} - let range = if let Ok(index) = value.extract::() { - convert_int(index)? - } else if let Ok(s) = value.downcast::() { - convert_py_slice(s)? - } else { - return Err(PyException::new_err( - "selection only supports type int or slice", - )); - }; - sliced = sliced.select(label_str, range).map_err(|err| { - PyException::new_err(format!( - "failed to select label {}; error is: {}", - label_str, err - )) - })?; +impl AsyncActorMesh { + fn new_queue(f: F) -> AsyncActorMesh + where + F: Future>> + Send + 'static, + { + let (queue, mut recv) = unbounded_channel(); + + get_tokio_runtime().spawn(async move { + loop { + let r = recv.recv().await; + if let Some(r) = r { + r.await; + } else { + return; + } + } + }); + AsyncActorMesh::new(queue, true, f) + } + fn new( + queue: UnboundedSender + Send + 'static>>>, + supervised: bool, + f: F, + ) -> AsyncActorMesh + where + F: Future>> + Send + 'static, + { + let mesh = async { Ok(Arc::from(f.await?)) }.boxed().shared(); + AsyncActorMesh { + mesh, + queue, + supervised, } + } - Ok(Self { inner: sliced }) + fn push(&self, f: F) + where + F: Future + Send + 'static, + { + self.queue.send(f.boxed()).unwrap(); } +} - fn new_with_shape(&self, shape: PyShape) -> PyResult { - let sliced = self - .inner - .new_with_shape(shape.get_inner().clone()) - .map_err(|e| PyErr::new::(e.to_string()))?; - Ok(Self { inner: sliced }) +impl ActorMeshProtocol for AsyncActorMesh { + fn cast(&self, message: PythonMessage, selection: Selection, client: Mailbox) -> PyResult<()> { + let mesh = self.mesh.clone(); + self.push(async { + let port = match &message.kind { + PythonMessageKind::CallMethod { response_port, .. } => response_port.clone(), + _ => None, + }; + let result = async { mesh.await?.cast(message, selection, client.clone()) }.await; + match (port, result) { + (Some(p), Err(pyerr)) => Python::with_gil(|py: Python<'_>| { + let port_ref = match p { + EitherPortRef::Once(p) => p.into_bound_py_any(py), + EitherPortRef::Unbounded(p) => p.into_bound_py_any(py), + } + .unwrap(); + let port = py + .import("monarch._src.actor.actor_mesh") + .unwrap() + .call_method1("Port", (port_ref, PyMailbox { inner: client }, 0)) + .unwrap(); + port.call_method1("exception", (pyerr.value(py),)).unwrap(); + }), + _ => (), + } + }); + Ok(()) } - #[getter] - fn shape(&self) -> PyShape { - PyShape::from(self.inner.shape().clone()) + fn new_with_shape(&self, shape: PyShape) -> PyResult> { + let mesh = self.mesh.clone(); + Ok(Box::new(AsyncActorMesh::new( + self.queue.clone(), + false, + async { Ok(mesh.await?.new_with_shape(shape)?) }, + ))) } - #[staticmethod] - fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult { - bincode::deserialize(bytes.as_bytes()) - .map_err(|e| PyErr::new::(e.to_string())) + fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> { + let mesh = self.mesh.clone(); + let mesh = py.allow_threads(|| get_tokio_runtime().block_on(mesh)); + mesh?.__reduce__(py) } - fn __reduce__<'py>( - slf: &Bound<'py, Self>, - ) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> { - let bytes = bincode::serialize(&*slf.borrow()) - .map_err(|e| PyErr::new::(e.to_string()))?; - let py_bytes = PyBytes::new(slf.py(), &bytes); - Ok((slf.as_any().getattr("from_bytes")?, (py_bytes,))) + fn supervision_event(&self) -> PyResult> { + if !self.supervised { + return Ok(None); + } + let mesh = self.mesh.clone(); + PyPythonTask::new(async { + let mut event = mesh.await?.supervision_event()?.unwrap(); + event.task()?.take_task()?.await + }) + .map(|mut x| x.spawn().map(Some))? } - fn __repr__(&self) -> String { - format!("{:?}", self) + fn stop(&self) -> PyResult { + let mesh = self.mesh.clone(); + PyPythonTask::new(async { + let task = mesh.await?.stop()?.take_task()?; + task.await + }) } -} -impl Drop for PythonActorMesh { - fn drop(&mut self) { - if let Ok(mesh) = self.inner.borrow() { - tracing::debug!("Dropping PythonActorMesh: {}", mesh.name()); - } else { - tracing::debug!( - "Dropping stopped PythonActorMesh. The underlying mesh is already stopped." - ); - } - self.monitor.abort(); + fn initialized<'py>(&self) -> PyResult { + let mesh = self.mesh.clone(); + PyPythonTask::new(async { + mesh.await?; + Ok(None::<()>) + }) } } @@ -433,7 +580,7 @@ impl From for PyActorSupervisionEvent { pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> { hyperactor_mod.add_class::()?; - hyperactor_mod.add_class::()?; + hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; Ok(()) } diff --git a/monarch_hyperactor/src/proc_mesh.rs b/monarch_hyperactor/src/proc_mesh.rs index f66d68cc2..1e2837f77 100644 --- a/monarch_hyperactor/src/proc_mesh.rs +++ b/monarch_hyperactor/src/proc_mesh.rs @@ -29,6 +29,7 @@ use hyperactor_mesh::shared_cell::SharedCellPool; use hyperactor_mesh::shared_cell::SharedCellRef; use monarch_types::PickledPyObject; use ndslice::Shape; +use pyo3::IntoPyObjectExt; use pyo3::exceptions::PyException; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; @@ -38,10 +39,13 @@ use tokio::sync::Mutex; use tokio::sync::mpsc; use crate::actor_mesh::PythonActorMesh; +use crate::actor_mesh::PythonActorMeshImpl; use crate::alloc::PyAlloc; use crate::mailbox::PyMailbox; use crate::pytokio::PyPythonTask; +use crate::pytokio::PyShared; use crate::pytokio::PythonTask; +use crate::runtime::get_tokio_runtime; use crate::shape::PyShape; use crate::supervision::SupervisionError; use crate::supervision::Unhealthy; @@ -278,21 +282,63 @@ impl PyProcMesh { let pickled_type = PickledPyObject::pickle(actor.as_any())?; let proc_mesh = self.try_inner()?; let keepalive = self.keepalive.clone(); - PyPythonTask::new(async move { + let meshimpl = async move { ensure_mesh_healthy(&unhealthy_event).await?; + let mailbox = proc_mesh.client().clone(); + let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?; + let actor_events = actor_mesh.with_mut(|a| a.events()).await.unwrap().unwrap(); + let im = PythonActorMeshImpl::new( + actor_mesh, + PyMailbox { inner: mailbox }, + keepalive, + actor_events, + ); + Ok(PythonActorMesh::from_impl(im)) + }; + PyPythonTask::new(meshimpl) + } + #[staticmethod] + fn spawn_async( + proc_mesh: &mut PyShared, + name: String, + actor: Py, + emulated: bool, + ) -> PyResult { + let task = proc_mesh.task()?.take_task()?; + let meshimpl = async move { + let proc_mesh = task.await?; + let (proc_mesh, pickled_type, unhealthy_event, keepalive) = + Python::with_gil(|py| -> PyResult<_> { + let slf: Bound = proc_mesh.extract(py)?; + let slf = slf.borrow(); + let unhealthy_event = Arc::clone(&slf.unhealthy_event); + let pickled_type = PickledPyObject::pickle(actor.bind(py).as_any())?; + let proc_mesh = slf.try_inner()?; + let keepalive = slf.keepalive.clone(); + Ok((proc_mesh, pickled_type, unhealthy_event, keepalive)) + })?; + ensure_mesh_healthy(&unhealthy_event).await?; let mailbox = proc_mesh.client().clone(); let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?; let actor_events = actor_mesh.with_mut(|a| a.events()).await.unwrap().unwrap(); - Ok(PythonActorMesh::monitored( + Ok(PythonActorMeshImpl::new( actor_mesh, PyMailbox { inner: mailbox }, keepalive, actor_events, )) - }) + }; + if emulated { + // we give up on doing mesh spawn async for the emulated old version + // it is too complicated to make both work. + let r = get_tokio_runtime().block_on(meshimpl)?; + Python::with_gil(|py| r.into_py_any(py)) + } else { + let r = PythonActorMesh::new(meshimpl); + Python::with_gil(|py| r.into_py_any(py)) + } } - // User can call this to monitor the proc mesh events. This will override // the default monitor that exits the client on process crash, so user can // handle the process crash in their own way. diff --git a/monarch_hyperactor/src/pytokio.rs b/monarch_hyperactor/src/pytokio.rs index 8f5887987..0cf23629a 100644 --- a/monarch_hyperactor/src/pytokio.rs +++ b/monarch_hyperactor/src/pytokio.rs @@ -8,6 +8,7 @@ use std::error::Error; use std::future::Future; +use std::ops::Deref; use std::pin::Pin; use hyperactor::clock::Clock; @@ -19,6 +20,7 @@ use pyo3::exceptions::PyStopIteration; use pyo3::exceptions::PyTimeoutError; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use pyo3::types::PyType; use tokio::sync::Mutex; use tokio::sync::watch; @@ -119,7 +121,7 @@ where } impl PyPythonTask { - fn take_task( + pub(crate) fn take_task( &mut self, ) -> PyResult, PyErr>> + Send + 'static>>> { self.inner @@ -158,7 +160,7 @@ impl PyPythonTask { signal_safe_block_on(py, task)? } - fn spawn(&mut self) -> PyResult { + pub(crate) fn spawn(&mut self) -> PyResult { let (tx, rx) = watch::channel(None); let task = self.take_task()?; get_tokio_runtime().spawn(async move { @@ -266,6 +268,11 @@ impl PyPythonTask { result.map(|r| (r, index)) }) } + + #[classmethod] + fn __class_getitem__(cls: &Bound<'_, PyType>, _arg: PyObject) -> PyObject { + cls.clone().unbind().into() + } } #[pyclass( @@ -277,7 +284,7 @@ pub struct PyShared { } #[pymethods] impl PyShared { - fn task(&mut self) -> PyResult { + pub(crate) fn task(&mut self) -> PyResult { // watch channels start unchanged, and when a value is sent to them signal // the receivers `changed` future. // By cloning the rx before awaiting it, @@ -306,6 +313,11 @@ impl PyShared { drop(slf); signal_safe_block_on(py, task)? } + + #[classmethod] + fn __class_getitem__(cls: &Bound<'_, PyType>, _arg: PyObject) -> PyObject { + cls.clone().unbind().into() + } } #[pyfunction] diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi index c30f54e08..7ce339179 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi @@ -6,7 +6,7 @@ # pyre-strict -from typing import AsyncIterator, final, NoReturn +from typing import AsyncIterator, final, NoReturn, Optional, Protocol from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage from monarch._rust_bindings.monarch_hyperactor.mailbox import ( @@ -15,127 +15,52 @@ from monarch._rust_bindings.monarch_hyperactor.mailbox import ( PortReceiver, ) from monarch._rust_bindings.monarch_hyperactor.proc import ActorId -from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask +from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared from monarch._rust_bindings.monarch_hyperactor.selection import Selection from monarch._rust_bindings.monarch_hyperactor.shape import Shape from typing_extensions import Self -@final -class PythonActorMeshRef: +class ActorMeshProtocol(Protocol): """ - A reference to a remote actor mesh over which PythonMessages can be sent. + Protocol defining the common interface for actor mesh, mesh ref and _ActorMeshRefImpl. """ def cast( - self, mailbox: Mailbox, selection: Selection, message: PythonMessage - ) -> None: - """Cast a message to the selected actors in the mesh.""" - ... - - def slice(self, **kwargs: int | slice[int | None, int | None, int | None]) -> Self: - """ - See PythonActorMeshRef.slice for documentation. - """ - ... - - def new_with_shape(self, shape: Shape) -> PythonActorMeshRef: - """ - Return a new mesh ref with the given sliced shape. If the provided shape - is not a valid slice of the current shape, an exception will be raised. - """ - ... - - @property - def shape(self) -> Shape: - """ - The Shape object that describes how the rank of an actor - retrieved with get corresponds to coordinates in the - mesh. - """ - ... + self, + message: PythonMessage, + selection: str, + mailbox: Mailbox, + ) -> None: ... + def new_with_shape(self, shape: Shape) -> Self: ... + def supervision_event(self) -> "Optional[Shared[Exception]]": ... + def stop(self) -> PythonTask[None]: ... + def initialized(self) -> PythonTask[None]: ... @final -class PythonActorMesh: - def bind(self) -> PythonActorMeshRef: - """ - Bind this actor mesh. The returned mesh ref can be used to reach the - mesh remotely. - """ - ... - - def cast( - self, mailbox: Mailbox, selection: Selection, message: PythonMessage - ) -> None: - """ - Cast a message to the selected actors in the mesh. - """ - ... - - def slice( - self, **kwargs: int | slice[int | None, int | None, int | None] - ) -> PythonActorMeshRef: - """ - Slice the mesh into a new mesh ref with the given selection. The reason - it returns a mesh ref, rather than the mesh object itself, is because - sliced mesh is a view of the original mesh, and does not own the mesh's - resources. - - Arguments: - - `kwargs`: argument name is the label, and argument value is how to - slice the mesh along the dimension of that label. - """ - ... - - def new_with_shape(self, shape: Shape) -> PythonActorMeshRef: - """ - Return a new mesh ref with the given sliced shape. If the provided shape - is not a valid slice of the current shape, an exception will be raised. - """ - ... +class PythonActorMesh(ActorMeshProtocol): + pass +class PythonActorMeshImpl: def get_supervision_event(self) -> ActorSupervisionEvent | None: """ Returns supervision event if there is any. """ ... - def supervision_event(self) -> PythonTask[Exception]: - """ - Completes with an exception when there is a supervision error. - """ - ... - def get(self, rank: int) -> ActorId | None: """ Get the actor id for the actor at the given rank. """ ... - @property - def client(self) -> Mailbox: - """ - A client that can be used to communicate with individual - actors in the mesh, and also to create ports that can be - broadcast across the mesh) - """ - ... - - @property - def shape(self) -> Shape: - """ - The Shape object that describes how the rank of an actor - retrieved with get corresponds to coordinates in the - mesh. - """ - ... - - async def stop(self) -> None: + def stop(self) -> PythonTask[None]: """ Stop all actors that are part of this mesh. Using this mesh after stop() is called will raise an Exception. """ ... + def supervision_event(self) -> "Optional[Shared[Exception]]": ... @property def stopped(self) -> bool: """ diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi index b73dc052b..435f4ed2f 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi @@ -6,14 +6,17 @@ # pyre-strict -from typing import AsyncIterator, final, Type +from typing import AsyncIterator, final, Literal, overload, Type from monarch._rust_bindings.monarch_hyperactor.actor import Actor -from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh +from monarch._rust_bindings.monarch_hyperactor.actor_mesh import ( + PythonActorMesh, + PythonActorMeshImpl, +) from monarch._rust_bindings.monarch_hyperactor.alloc import Alloc from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox -from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask +from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared from monarch._rust_bindings.monarch_hyperactor.shape import Shape @@ -42,6 +45,10 @@ class ProcMesh: """ ... + @staticmethod + def spawn_async( + proc_mesh: Shared[ProcMesh], name: str, actor: Type[Actor], emulated: bool + ) -> PythonActorMesh: ... async def monitor(self) -> ProcMeshMonitor: """ Returns a supervision monitor for this mesh. diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 33f2dd48b..c787a795d 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -50,7 +50,7 @@ from monarch._rust_bindings.monarch_hyperactor.actor_mesh import ( PythonActorMesh, - PythonActorMeshRef, + PythonActorMeshImpl, ) from monarch._rust_bindings.monarch_hyperactor.mailbox import ( Mailbox, @@ -96,6 +96,7 @@ if TYPE_CHECKING: from monarch._rust_bindings.monarch_hyperactor.actor import PortProtocol + from monarch._rust_bindings.monarch_hyperactor.actor_mesh import ActorMeshProtocol from monarch._rust_bindings.monarch_hyperactor.mailbox import PortReceiverBase from monarch._src.actor.proc_mesh import ProcMesh @@ -170,147 +171,17 @@ def set(debug_context: "DebugContext") -> None: _load_balancing_seed = random.Random(4) -def to_hy_sel(selection: Selection) -> HySelection: - if selection == "choose": - return HySelection.any() - elif selection == "all": - return HySelection.all() - else: - raise ValueError(f"invalid selection: {selection}") - - -# A temporary gate used by the PythonActorMesh/PythonActorMeshRef migration. -# We can use this gate to quickly roll back to using _ActorMeshRefImpl, if we -# encounter any issues with the migration. -# -# This should be removed once we confirm PythonActorMesh/PythonActorMeshRef is -# working correctly in production. -def _use_standin_mesh() -> bool: - return bool(os.getenv("USE_STANDIN_ACTOR_MESH", default=False)) - - -class ActorMeshProtocol(ABC): - """ - Protocol defining the common interface for actor mesh, mesh ref and _ActorMeshRefImpl. - """ - - @abstractmethod - def cast( - self, - message: PythonMessage, - selection: Selection, - mailbox: Mailbox, - ) -> None: ... - - @abstractmethod - def new_with_shape(self, shape: Shape) -> Self: ... - - def supervision_event(self) -> "Optional[Shared[Exception]]": - return None - - async def stop(self) -> None: - raise NotImplementedError(f"stop() is not supported for {type(self)}") - - async def initialized(self): - return None - - -class _PythonActorMeshAdapter(ActorMeshProtocol): - """ - Adapter for PythonActorMesh to implement the normalized ActorMeshProtocol - interface. This adapter also provides a convenient way to add states to - the mesh on the python side, without changing the rust side implementation. - - Since PythonActorMesh cannot be pickled, this adapter also provides a - custom pickling logic which bind the mesh to PythonActorMeshRef during - pickling. - """ - - def __init__(self, inner: PythonActorMesh) -> None: - if _use_standin_mesh(): - raise ValueError( - "_PythonActorMeshAdapter should only be used when USE_STANDIN_ACTOR_MESH is not set" - ) - self._inner = inner - - def cast( - self, - message: PythonMessage, - selection: Selection, - mailbox: Mailbox, - ) -> None: - self._inner.cast(mailbox, to_hy_sel(selection), message) - - def new_with_shape(self, shape: Shape) -> "ActorMeshProtocol": - sliced: PythonActorMeshRef = self._inner.new_with_shape(shape) - return _PythonActorMeshRefAdapter(sliced) - - def supervision_event(self) -> "Optional[Shared[Exception]]": - return self._inner.supervision_event().spawn() - - async def stop(self) -> None: - await self._inner.stop() - - def __reduce_ex__(self, protocol: ...) -> Tuple[Any, Tuple[Any, ...]]: - """ - Automatically pickle as a PythonActorMeshRef by binding the mesh. - Unpicklable states such as proc_mesh are dropped as well. - """ - mesh_ref = self._inner.bind() - return _PythonActorMeshRefAdapter, (mesh_ref,) - - -class _PythonActorMeshRefAdapter(ActorMeshProtocol): - """ - Adapter for PythonActorMeshRef to implement the normalized ActorMeshProtocol interface. It is - also used to store unpickable states such as proc_mesh. - """ - - def __init__(self, inner: PythonActorMeshRef) -> None: - if _use_standin_mesh(): - raise ValueError( - "_PythonActorMeshRefAdapter should only be used when USE_STANDIN_ACTOR_MESH is not set" - ) - self._inner = inner - - def cast( - self, - message: PythonMessage, - selection: Selection, - mailbox: Mailbox, - ) -> None: - self._inner.cast(mailbox, to_hy_sel(selection), message) - - def new_with_shape(self, shape: Shape) -> "ActorMeshProtocol": - sliced: PythonActorMeshRef = self._inner.new_with_shape(shape) - return _PythonActorMeshRefAdapter(sliced) - - def __reduce_ex__(self, protocol: ...) -> Tuple[Any, Tuple[Any, ...]]: - """ - Dropping all unpickable states. - """ - return _PythonActorMeshRefAdapter, (self._inner,) - - -class _SingletonActorAdapator(ActorMeshProtocol): +class _SingletonActorAdapator: def __init__(self, inner: ActorId, shape: Optional[Shape] = None) -> None: self._inner: ActorId = inner if shape is None: shape = singleton_shape self._shape = shape - @property - def shape(self) -> Shape: - return singleton_shape - - @property - def proc_mesh(self) -> Optional["ProcMesh"]: - return None - def cast( self, message: PythonMessage, - selection: Selection, + selection: str, mailbox: Mailbox, ) -> None: mailbox.post(self._inner, message) @@ -318,23 +189,31 @@ def cast( def new_with_shape(self, shape: Shape) -> "ActorMeshProtocol": return _SingletonActorAdapator(self._inner, self._shape) + def supervision_event(self) -> "Optional[Shared[Exception]]": + return None + + def stop(self) -> "PythonTask[None]": + raise NotImplementedError("stop()") + + def initialized(self) -> "PythonTask[None]": + async def empty(): + pass + + return PythonTask.from_coroutine(empty()) + # standin class for whatever is the serializable python object we use # to name an actor mesh. Hacked up today because ActorMesh # isn't plumbed to non-clients -class _ActorMeshRefImpl(ActorMeshProtocol): +class _ActorMeshRefImpl: def __init__( self, mailbox: Mailbox, - hy_actor_mesh: Optional[PythonActorMesh], + hy_actor_mesh: Optional[PythonActorMeshImpl], proc_mesh: "Optional[ProcMesh]", shape: Shape, actor_ids: List[ActorId], ) -> None: - if not _use_standin_mesh(): - raise ValueError( - "ActorMeshRefImpl should only be used when USE_STANDIN_ACTOR_MESH is set" - ) self._mailbox = mailbox self._actor_mesh = hy_actor_mesh # actor meshes do not have a way to look this up at the moment, @@ -345,14 +224,16 @@ def __init__( @staticmethod def from_hyperactor_mesh( - mailbox: Mailbox, hy_actor_mesh: PythonActorMesh, proc_mesh: "ProcMesh" + mailbox: Mailbox, + shape: Shape, + hy_actor_mesh: PythonActorMeshImpl, + proc_mesh: "ProcMesh", ) -> "_ActorMeshRefImpl": - shape: Shape = hy_actor_mesh.shape return _ActorMeshRefImpl( mailbox, hy_actor_mesh, proc_mesh, - hy_actor_mesh.shape, + shape, [cast(ActorId, hy_actor_mesh.get(i)) for i in range(len(shape))], ) @@ -387,7 +268,7 @@ def _check_state(self) -> None: def cast( self, message: PythonMessage, - selection: Selection, + selection: str, mailbox: Mailbox, ) -> None: self._check_state() @@ -454,13 +335,23 @@ def new_with_shape(self, shape: Shape) -> "_ActorMeshRefImpl": def supervision_event(self) -> "Optional[Shared[Exception]]": if self._actor_mesh is None: return None - return self._actor_mesh.supervision_event().spawn() + return self._actor_mesh.supervision_event() - async def stop(self): - await self._actor_mesh.stop() + def stop(self) -> PythonTask[None]: + async def task(): + if self._actor_mesh is not None: + self._actor_mesh.stop() + + return PythonTask.from_coroutine(task()) + + def initialized(self) -> PythonTask[None]: + async def task(): + pass + + return PythonTask.from_coroutine(task()) -class SharedProtocolAdapter(ActorMeshProtocol): +class SharedProtocolAdapter: def __init__(self, inner: "Shared[ActorMeshProtocol]", supervise: bool): self._inner = inner self._supervise = supervise @@ -468,7 +359,7 @@ def __init__(self, inner: "Shared[ActorMeshProtocol]", supervise: bool): def cast( self, message: PythonMessage, - selection: Selection, + selection: str, mailbox: Mailbox, ) -> None: ctx = MonarchContext.get() @@ -505,11 +396,14 @@ async def task(): return PythonTask.from_coroutine(task()).spawn() - async def stop(self) -> None: - await (await self._inner).stop() + def stop(self) -> "PythonTask[None]": + async def task(): + await (await self._inner).stop() + + return PythonTask.from_coroutine(task()) @staticmethod - def _restore(inner: ActorMeshProtocol) -> ActorMeshProtocol: + def _restore(inner: "ActorMeshProtocol") -> "ActorMeshProtocol": return inner def __reduce_ex__(self, protocol): @@ -530,7 +424,7 @@ async def initialized(self): class ActorEndpoint(Endpoint[P, R]): def __init__( self, - actor_mesh: ActorMeshProtocol, + actor_mesh: "ActorMeshProtocol", shape: Shape, proc_mesh: "Optional[ProcMesh]", name: MethodSpecifier, @@ -1059,14 +953,14 @@ class ActorMesh(MeshTrait, Generic[T], DeprecatedNotAFuture): def __init__( self, Class: Type[T], - inner: ActorMeshProtocol, + inner: "ActorMeshProtocol", mailbox: Mailbox, shape: Shape, proc_mesh: "Optional[ProcMesh]", ) -> None: self.__name__: str = Class.__name__ self._class: Type[T] = Class - self._inner: ActorMeshProtocol = inner + self._inner: "ActorMeshProtocol" = inner self._mailbox: Mailbox = mailbox self._shape = shape self._proc_mesh = proc_mesh @@ -1117,7 +1011,7 @@ def _endpoint( def _create( cls, Class: Type[T], - actor_mesh: "PythonTask[PythonActorMesh]", + actor_mesh: "PythonActorMesh | PythonActorMeshImpl", mailbox: Mailbox, shape: Shape, proc_mesh: "ProcMesh", @@ -1126,17 +1020,12 @@ def _create( *args: Any, **kwargs: Any, ) -> "ActorMesh[T]": - async def task(): - if _use_standin_mesh(): - return _ActorMeshRefImpl.from_hyperactor_mesh( - mailbox, await actor_mesh, proc_mesh - ) - else: - return _PythonActorMeshAdapter(await actor_mesh) + if isinstance(actor_mesh, PythonActorMeshImpl): + actor_mesh = _ActorMeshRefImpl.from_hyperactor_mesh( + mailbox, shape, actor_mesh, proc_mesh + ) - shared = PythonTask.from_coroutine(task()).spawn() - inner = SharedProtocolAdapter(shared, True) - mesh = cls(Class, inner, mailbox, shape, proc_mesh) + mesh = cls(Class, actor_mesh, mailbox, shape, proc_mesh) async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None: return None diff --git a/python/monarch/_src/actor/endpoint.py b/python/monarch/_src/actor/endpoint.py index ea2a0892b..3ec7c26c9 100644 --- a/python/monarch/_src/actor/endpoint.py +++ b/python/monarch/_src/actor/endpoint.py @@ -46,7 +46,7 @@ P = ParamSpec("P") R = TypeVar("R") -Selection = Literal["all", "choose"] | int +Selection = Literal["all", "choose"] class Extent: diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py index 9e55e0dbf..5c1e1147e 100644 --- a/python/monarch/_src/actor/proc_mesh.py +++ b/python/monarch/_src/actor/proc_mesh.py @@ -8,10 +8,13 @@ import asyncio import logging +import os import sys import threading import warnings from contextlib import AbstractContextManager + +from functools import cache from pathlib import Path from typing import ( @@ -123,6 +126,17 @@ async def setup(self) -> None: IN_PAR = False +# A temporary gate used by the PythonActorMesh/PythonActorMeshRef migration. +# We can use this gate to quickly roll back to using _ActorMeshRefImpl, if we +# encounter any issues with the migration. +# +# This should be removed once we confirm PythonActorMesh/PythonActorMeshRef is +# working correctly in production. +@cache +def _use_standin_mesh() -> bool: + return os.getenv("USE_STANDIN_ACTOR_MESH", default="0") != "0" + + class ProcMesh(MeshTrait, DeprecatedNotAFuture): def __init__( self, @@ -325,10 +339,7 @@ def _spawn_nonblocking_on( f"{Class} must subclass monarch.service.Actor to spawn it." ) - async def task() -> "PythonActorMesh": - return await (await pm).spawn_nonblocking(name, _Actor) - - actor_mesh = PythonTask.from_coroutine(task()) + actor_mesh = HyProcMesh.spawn_async(pm, name, _Actor, _use_standin_mesh()) service = ActorMesh._create( Class, actor_mesh, diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index ca23f57b6..2da27cd40 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -277,7 +277,7 @@ def __str__(self): def _cast_call_method_indirect( endpoint: ActorEndpoint, - selection: Selection, + selection: str, client: MeshClient, seq: Seq, args_kwargs_tuple: bytes, @@ -303,7 +303,7 @@ def actor_send( args_kwargs_tuple: bytes, refs: Sequence[Any], port: Optional[Port[Any]], - selection: Selection, + selection: str, ): tensors = [ref for ref in refs if isinstance(ref, Tensor)] # we have some monarch references, we need to ensure their @@ -352,7 +352,7 @@ def _actor_send( args_kwargs_tuple: bytes, refs: Sequence[Any], port: Optional[Port[Any]], - selection: Selection, + selection: str, client: MeshClient, mesh: DeviceMesh, tensors: List[Tensor], diff --git a/python/tests/_monarch/test_actor_mesh.py b/python/tests/_monarch/test_actor_mesh.py index ef4116afc..52f2ccb5b 100644 --- a/python/tests/_monarch/test_actor_mesh.py +++ b/python/tests/_monarch/test_actor_mesh.py @@ -18,10 +18,7 @@ PythonMessage, PythonMessageKind, ) -from monarch._rust_bindings.monarch_hyperactor.actor_mesh import ( - PythonActorMesh, - PythonActorMeshRef, -) +from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh from monarch._rust_bindings.monarch_hyperactor.alloc import ( # @manual=//monarch/monarch_extension:monarch_extension AllocConstraints, @@ -34,7 +31,6 @@ from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox, PortReceiver from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask -from monarch._rust_bindings.monarch_hyperactor.selection import Selection from monarch._rust_bindings.monarch_hyperactor.shape import Shape @@ -99,16 +95,11 @@ async def test_bind_and_pickling() -> None: async def run() -> None: proc_mesh = await allocate() actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor) - with pytest.raises(NotImplementedError, match="use bind()"): - pickle.dumps(actor_mesh) + pickle.dumps(actor_mesh) - actor_mesh_ref = actor_mesh.bind() - assert actor_mesh_ref.shape == actor_mesh.shape + actor_mesh_ref = actor_mesh.new_with_shape(proc_mesh.shape) obj = pickle.dumps(actor_mesh_ref) unpickled = pickle.loads(obj) - assert repr(actor_mesh_ref) == repr(unpickled) - assert actor_mesh_ref.shape == unpickled.shape - await proc_mesh.stop_nonblocking() run() @@ -125,25 +116,24 @@ async def spawn_actor_mesh(proc_mesh: ProcMesh) -> PythonActorMesh: PythonMessageKind.CallMethod(MethodSpecifier.Init(), port_ref), pickle.dumps(None), ) - actor_mesh.cast(proc_mesh.client, Selection.all(), message) + actor_mesh.cast(message, "all", proc_mesh.client) # wait for init to complete - for _ in range(len(actor_mesh.shape.ndslice)): + for _ in range(len(proc_mesh.shape.ndslice)): await receiver.recv_task() return actor_mesh async def cast_to_call( - actor_mesh: PythonActorMesh | PythonActorMeshRef, + actor_mesh: PythonActorMesh, mailbox: Mailbox, message: PythonMessage, ) -> None: - sel = Selection.all() - actor_mesh.cast(mailbox, sel, message) + actor_mesh.cast(message, "all", mailbox) async def verify_cast_to_call( - actor_mesh: PythonActorMesh | PythonActorMeshRef, + actor_mesh: PythonActorMesh, mailbox: Mailbox, root_ranks: List[int], ) -> None: @@ -203,7 +193,7 @@ async def test_cast_ref() -> None: async def run() -> None: proc_mesh = await allocate() actor_mesh = await spawn_actor_mesh(proc_mesh) - actor_mesh_ref = actor_mesh.bind() + actor_mesh_ref = actor_mesh.new_with_shape(proc_mesh.shape) await verify_cast_to_call( actor_mesh_ref, proc_mesh.client, list(range(3 * 8 * 8)) ) @@ -211,93 +201,3 @@ async def run() -> None: await proc_mesh.stop_nonblocking() run() - - -async def verify_slice( - actor_mesh: PythonActorMesh | PythonActorMeshRef, - mailbox: Mailbox, -) -> None: - sliced_mesh = actor_mesh.slice( - gpus=slice(2, 8, 2), - replicas=slice(None, 2), - hosts=slice(3, 7), - ) - sliced_shape = sliced_mesh.shape - # fmt: off - # turn off formatting to make the following list more readable - replica_0_ranks = [ - # gpus=2,4,6 - 24 + 2, 24 + 4, 24 + 6, # hosts=3 - 32 + 2, 32 + 4, 32 + 6, # hosts=4 - 40 + 2, 40 + 4, 40 + 6, # hosts=5 - 48 + 2, 48 + 4, 48 + 6, # hosts=6 - ] - # fmt: on - replica_1_ranks = [rank + 64 for rank in replica_0_ranks] - assert ( - sliced_shape.ranks() == replica_0_ranks + replica_1_ranks - ), f"left is {sliced_shape.ranks()}" - await verify_cast_to_call(sliced_mesh, mailbox, sliced_shape.ranks()) - - assert sliced_shape.labels == ["replicas", "hosts", "gpus"] - assert sliced_shape.ndslice.sizes == [2, 4, 3] - # When slicing a sliced mesh, the user treats this sliced mesh as a - # continuous mesh, and calculates the dimensions based on that assumption, - # without considering the original mesh. - # - # e.g, the following slicing operation selects index 0 and 2 of the hosts - # dimension on the sliced mesh. But corresponding index on the original - # mesh is 3 and 5. - sliced_again = sliced_mesh.slice( - replicas=1, - hosts=slice(None, None, 2), - gpus=slice(1, 3), - ) - again_shape = sliced_again.shape - assert again_shape.labels == ["replicas", "hosts", "gpus"] - assert again_shape.ndslice.sizes == [1, 2, 2] - # fmt: off - # turn off formatting to make the following list more readable - selected_ranks = [ - rank + 64 for rank in - [ - # gpus=4,6 - 24 + 4, 24 + 6, # hosts=3 - 40 + 4, 40 + 6, # hosts=5 - ] - ] - # fmt: on - assert again_shape.ranks() == selected_ranks, f"left is {sliced_shape.ranks()}" - - -# TODO - re-enable after resolving T232206970 -@pytest.mark.oss_skip -@pytest.mark.timeout(30) -async def test_slice_actor_mesh_handle() -> None: - @run_on_tokio - async def run() -> None: - proc_mesh = await allocate() - actor_mesh = await spawn_actor_mesh(proc_mesh) - - await verify_slice(actor_mesh, proc_mesh.client) - - await proc_mesh.stop_nonblocking() - - run() - - -# TODO - re-enable after resolving T232206970 -@pytest.mark.oss_skip -@pytest.mark.timeout(30) -async def test_slice_actor_mesh_ref() -> None: - @run_on_tokio - async def run() -> None: - proc_mesh = await allocate() - actor_mesh = await spawn_actor_mesh(proc_mesh) - - actor_mesh_ref = actor_mesh.bind() - await verify_slice(actor_mesh_ref, proc_mesh.client) - - await proc_mesh.stop_nonblocking() - - run() diff --git a/python/tests/_monarch/test_hyperactor.py b/python/tests/_monarch/test_hyperactor.py index 60655c1e4..a44061d33 100644 --- a/python/tests/_monarch/test_hyperactor.py +++ b/python/tests/_monarch/test_hyperactor.py @@ -97,8 +97,4 @@ async def test_actor_mesh() -> None: proc_mesh = await ProcMesh.allocate_nonblocking(alloc) actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor) - assert actor_mesh.get(0) is not None - assert actor_mesh.get(1) is not None - assert actor_mesh.get(2) is None - - assert isinstance(actor_mesh.client, Mailbox) + await actor_mesh.initialized() diff --git a/python/tests/_monarch/test_mailbox.py b/python/tests/_monarch/test_mailbox.py index b6e6e3ee6..0a64938f1 100644 --- a/python/tests/_monarch/test_mailbox.py +++ b/python/tests/_monarch/test_mailbox.py @@ -188,14 +188,14 @@ def my_reduce(state: str, update: str) -> str: port_ref = handle.bind() actor_mesh.cast( - proc_mesh.client, - Selection.from_string("*"), PythonMessage( PythonMessageKind.CallMethod( MethodSpecifier.ReturnsResponse("echo"), port_ref ), pickle.dumps("start"), ), + "all", + proc_mesh.client, ) messge = await receiver.recv_task().with_timeout(seconds=5)