diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs
index 5d6bbb98e..1857a5549 100644
--- a/hyperactor_mesh/src/actor_mesh.rs
+++ b/hyperactor_mesh/src/actor_mesh.rs
@@ -53,6 +53,7 @@ use crate::comm::multicast::CastMessage;
 use crate::comm::multicast::CastMessageEnvelope;
 use crate::comm::multicast::Uslice;
 use crate::metrics;
+use crate::proc_mesh::ActorEventRouter;
 use crate::proc_mesh::ProcMesh;
 use crate::reference::ActorMeshId;
 use crate::reference::ActorMeshRef;
@@ -248,6 +249,7 @@ pub struct RootActorMesh<'a, A: RemoteActor> {
     // The receiver of supervision events. It is None if it has been transferred to
     // an actor event observer.
     actor_supervision_rx: Option<mpsc::UnboundedReceiver<ActorSupervisionEvent>>,
+    actor_event_router: ActorEventRouter,
 }

 impl<'a, A: RemoteActor> RootActorMesh<'a, A> {
@@ -256,12 +258,14 @@ impl<'a, A: RemoteActor> RootActorMesh<'a, A> {
         name: String,
         actor_supervision_rx: mpsc::UnboundedReceiver<ActorSupervisionEvent>,
         ranks: Vec<ActorRef<A>>,
+        actor_event_router: ActorEventRouter,
     ) -> Self {
         Self {
             proc_mesh: ProcMeshRef::Borrowed(proc_mesh),
             name,
             ranks,
             actor_supervision_rx: Some(actor_supervision_rx),
+            actor_event_router,
         }
     }

@@ -270,11 +274,13 @@ impl<'a, A: RemoteActor> RootActorMesh<'a, A> {
         name: String,
         actor_supervision_rx: mpsc::UnboundedReceiver<ActorSupervisionEvent>,
         ranks: Vec<ActorRef<A>>,
+        actor_event_router: ActorEventRouter,
     ) -> Self {
         Self {
             proc_mesh: ProcMeshRef::Shared(Box::new(proc_mesh)),
             name,
             ranks,
+            actor_event_router,
             actor_supervision_rx: Some(actor_supervision_rx),
         }
     }
@@ -294,6 +300,10 @@ impl<'a, A: RemoteActor> RootActorMesh<'a, A> {
             mesh_id: self.id(),
         })
     }
+
+    pub fn actor_event_router(&self) -> &ActorEventRouter {
+        &self.actor_event_router
+    }
 }

 /// Supervision event stream for actor mesh. It emits actor supervision events.
@@ -315,6 +325,16 @@ impl ActorSupervisionEvents {
         }
         result
     }
+
+    pub fn new(
+        actor_supervision_rx: mpsc::UnboundedReceiver<ActorSupervisionEvent>,
+        mesh_id: ActorMeshId,
+    ) -> Self {
+        Self {
+            actor_supervision_rx,
+            mesh_id,
+        }
+    }
 }

 #[async_trait]
diff --git a/hyperactor_mesh/src/proc_mesh.rs b/hyperactor_mesh/src/proc_mesh.rs
index add967338..cbf4826b1 100644
--- a/hyperactor_mesh/src/proc_mesh.rs
+++ b/hyperactor_mesh/src/proc_mesh.rs
@@ -99,7 +99,33 @@ pub fn global_mailbox() -> Mailbox {
         .clone()
 }

-type ActorEventRouter = Arc<DashMap<ActorMeshName, mpsc::UnboundedSender<ActorSupervisionEvent>>>;
+#[derive(Clone, Debug)]
+pub struct ActorEventRouter {
+    inner: Arc<DashMap<ActorMeshName, Vec<mpsc::UnboundedSender<ActorSupervisionEvent>>>>,
+}
+
+impl ActorEventRouter {
+    pub fn bind(&self, name: ActorMeshName) -> mpsc::UnboundedReceiver<ActorSupervisionEvent> {
+        let (tx, rx) = mpsc::unbounded_channel();
+        self.inner.entry(name).or_insert(vec![]).push(tx);
+        rx
+    }
+
+    fn new() -> Self {
+        Self {
+            inner: Arc::new(DashMap::new()),
+        }
+    }
+}
+
+impl Deref for ActorEventRouter {
+    type Target = DashMap<ActorMeshName, Vec<mpsc::UnboundedSender<ActorSupervisionEvent>>>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.inner
+    }
+}
+
 /// A ProcMesh maintains a mesh of procs whose lifecycles are managed by
 /// an allocator.
 pub struct ProcMesh {
@@ -335,7 +361,7 @@ impl ProcMesh {
                 alloc: Box::new(alloc),
                 supervision_events,
             }),
-            actor_event_router: Arc::new(DashMap::new()),
+            actor_event_router: ActorEventRouter::new(),
             shape,
             ranks: proc_ids
                 .into_iter()
@@ -434,16 +460,13 @@ impl ProcMesh {
     where
         A::Params: RemoteMessage,
     {
-        let (tx, rx) = mpsc::unbounded_channel::<ActorSupervisionEvent>();
-        {
-            // Instantiate supervision routing BEFORE spawning the actor mesh.
-            self.actor_event_router.insert(actor_name.to_string(), tx);
-        }
+        let rx = self.actor_event_router.bind(actor_name.to_string());
         let root_mesh = RootActorMesh::new(
             self,
             actor_name.to_string(),
             rx,
             Self::spawn_on_procs::<A>(&self.client, self.agents(), actor_name, params).await?,
+            self.actor_event_router.clone(),
         );
         Ok(root_mesh)
     }
@@ -515,10 +538,17 @@ impl ProcMesh {
                 }
             }
         }
+        self.actor_event_router.remove(&mesh_name.to_string());
         Ok(())
     }
 }

+impl Drop for ProcMesh {
+    fn drop(&mut self) {
+        self.actor_event_router.clear();
+    }
+}
+
 /// Proc lifecycle events.
 #[derive(Debug, Clone)]
 pub enum ProcEvent {
@@ -596,8 +626,10 @@ impl ProcEvents {
                             message_headers: None,
                             caused_by: None,
                         };
-                        if entry.value().send(event).is_err() {
-                            tracing::warn!("unable to transmit supervision event to actor {}", entry.key());
+                        for tx in entry.value().iter() {
+                            if tx.send(event.clone()).is_err() {
+                                tracing::warn!("unable to transmit supervision event to actor {}", entry.key());
+                            }
                         }
                     }

@@ -631,9 +663,11 @@ impl ProcEvents {
                     };
                     // transmit to the correct root actor mesh.
                     {
-                        if let Some(tx) = self.actor_event_router.get(actor_id.name()) {
-                            if tx.send(event).is_err() {
-                                tracing::warn!("unable to transmit supervision event to actor {}", actor_id);
+                        if let Some(txs) = self.actor_event_router.get(actor_id.name()) {
+                            for tx in txs.iter() {
+                                if tx.send(event.clone()).is_err() {
+                                    tracing::warn!("unable to transmit supervision event to actor {}", actor_id);
+                                }
                             }
                         } else {
                             tracing::warn!("received supervision event for unregistered actor {}", actor_id);
@@ -683,18 +717,16 @@ impl<D: Deref<Target = ProcMesh> + Send + Sync + 'static> SharedSpawnable for D
     where
         A::Params: RemoteMessage,
     {
-        let (tx, rx) = mpsc::unbounded_channel::<ActorSupervisionEvent>();
-        {
-            // Instantiate supervision routing BEFORE spawning the actor mesh.
-            self.actor_event_router.insert(actor_name.to_string(), tx);
-        }
+        let rx = self.actor_event_router.bind(actor_name.to_string());
         let ranks =
             ProcMesh::spawn_on_procs::<A>(&self.client, self.agents(), actor_name, params).await?;
+        let actor_event_router = self.actor_event_router.clone();
         Ok(RootActorMesh::new_shared(
             self,
             actor_name.to_string(),
             rx,
             ranks,
+            actor_event_router,
         ))
     }
 }
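Note on the routing change above: `ActorEventRouter::bind` now fans each mesh's supervision events out to every registered receiver, where the old `DashMap<_, UnboundedSender<_>>` kept exactly one sender per mesh name. A minimal, self-contained sketch of these semantics follows, assuming only the `tokio` and `dashmap` crates; `Router` and the plain `String` event are illustrative stand-ins for `ActorEventRouter` and `ActorSupervisionEvent`, not the crate's real types:

```rust
use std::sync::Arc;

use dashmap::DashMap;
use tokio::sync::mpsc;

type Event = String; // stand-in for ActorSupervisionEvent

#[derive(Clone, Default)]
struct Router {
    inner: Arc<DashMap<String, Vec<mpsc::UnboundedSender<Event>>>>,
}

impl Router {
    // Every bind() registers one more subscriber under the mesh name;
    // with a single-sender map, a second bind would have replaced the first.
    fn bind(&self, name: &str) -> mpsc::UnboundedReceiver<Event> {
        let (tx, rx) = mpsc::unbounded_channel();
        self.inner
            .entry(name.to_string())
            .or_insert_with(Vec::new)
            .push(tx);
        rx
    }

    // Delivery walks all senders for the name, mirroring the ProcEvents loop.
    fn deliver(&self, name: &str, event: Event) {
        if let Some(txs) = self.inner.get(name) {
            for tx in txs.iter() {
                let _ = tx.send(event.clone());
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let router = Router::default();
    let mut root = router.bind("error"); // e.g. the root mesh's monitor
    let mut sliced = router.bind("error"); // e.g. a sliced mesh ref's monitor
    router.deliver("error", "actor 3 crashed".to_string());
    assert_eq!(root.recv().await.unwrap(), "actor 3 crashed");
    assert_eq!(sliced.recv().await.unwrap(), "actor 3 crashed");
}
```

This fan-out is what lets each mesh ref produced by `bind()`, `slice()`, or `new_with_shape()` below run its own monitor against the same underlying event stream.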
diff --git a/monarch_hyperactor/src/actor_mesh.rs b/monarch_hyperactor/src/actor_mesh.rs
index 41bc82061..25584858c 100644
--- a/monarch_hyperactor/src/actor_mesh.rs
+++ b/monarch_hyperactor/src/actor_mesh.rs
@@ -15,6 +15,7 @@ use hyperactor_mesh::Mesh;
 use hyperactor_mesh::RootActorMesh;
 use hyperactor_mesh::actor_mesh::ActorMesh;
 use hyperactor_mesh::actor_mesh::ActorSupervisionEvents;
+use hyperactor_mesh::proc_mesh::ActorEventRouter;
 use hyperactor_mesh::reference::ActorMeshRef;
 use hyperactor_mesh::shared_cell::SharedCell;
 use hyperactor_mesh::shared_cell::SharedCellRef;
@@ -35,6 +36,7 @@ use crate::mailbox::PyMailbox;
 use crate::proc::PyActorId;
 use crate::proc_mesh::Keepalive;
 use crate::pytokio::PyPythonTask;
+use crate::runtime::get_tokio_runtime;
 use crate::selection::PySelection;
 use crate::shape::PyShape;
 use crate::supervision::SupervisionError;
@@ -61,6 +63,7 @@ impl PythonActorMesh {
         client: PyMailbox,
         keepalive: Keepalive,
         events: ActorSupervisionEvents,
+        mesh_shape: ndslice::Shape,
     ) -> Self {
         let (user_monitor_sender, _) =
             tokio::sync::broadcast::channel::<Option<ActorSupervisionEvent>>(1);
@@ -69,6 +72,7 @@ impl PythonActorMesh {
             events,
             user_monitor_sender.clone(),
             Arc::clone(&unhealthy_event),
+            mesh_shape,
         ));
         Self {
             inner,
@@ -86,6 +90,7 @@ impl PythonActorMesh {
         mut events: ActorSupervisionEvents,
         user_sender: tokio::sync::broadcast::Sender<Option<ActorSupervisionEvent>>,
         unhealthy_event: Arc<std::sync::Mutex<Unhealthy<ActorSupervisionEvent>>>,
+        mesh_shape: ndslice::Shape,
     ) {
         loop {
             let event = events.next().await;
@@ -93,7 +98,16 @@ impl PythonActorMesh {
             let mut inner_unhealthy_event = unhealthy_event.lock().unwrap();
             match &event {
                 None => *inner_unhealthy_event = Unhealthy::StreamClosed,
-                Some(event) => *inner_unhealthy_event = Unhealthy::Crashed(event.clone()),
+                Some(event) => {
+                    // Ignore the event if the crashed actor is not a part of the mesh.
+                    if mesh_shape
+                        .slice()
+                        .iter()
+                        .any(|index| index == event.actor_id.rank())
+                    {
+                        *inner_unhealthy_event = Unhealthy::Crashed(event.clone())
+                    }
+                }
             }

             // Ignore the sender error when there is no receiver,
@@ -160,7 +174,27 @@ impl PythonActorMesh {
     fn bind(&self) -> PyResult<PythonActorMeshRef> {
         let mesh = self.try_inner()?;
-        Ok(PythonActorMeshRef { inner: mesh.bind() })
+        let unhealthy_event = self.unhealthy_event.clone();
+        let actor_event_router = mesh.actor_event_router().clone();
+        let monitor = get_tokio_runtime().spawn(Self::actor_mesh_monitor(
+            ActorSupervisionEvents::new(
+                actor_event_router.bind(mesh.name().to_string()),
+                mesh.id(),
+            ),
+            self.user_monitor_sender.clone(),
+            unhealthy_event.clone(),
+            mesh.shape().clone(),
+        ));
+        let mesh_monitor = Some(PythonActorMeshRefMonitor {
+            user_monitor_sender: self.user_monitor_sender.clone(),
+            monitor,
+            unhealthy_event,
+            actor_event_router,
+        });
+        Ok(PythonActorMeshRef {
+            inner: mesh.bind(),
+            mesh_monitor,
+        })
     }

     fn get_supervision_event(&self) -> PyResult<Option<PyActorSupervisionEvent>> {
@@ -258,6 +292,16 @@ impl PythonActorMesh {
     }
 }

+#[derive(Debug)]
+struct PythonActorMeshRefMonitor {
+    user_monitor_sender: tokio::sync::broadcast::Sender<Option<ActorSupervisionEvent>>,
+    /// background task listening to the stream of supervision events
+    monitor: tokio::task::JoinHandle<()>,
+    /// state updated by the monitor
+    unhealthy_event: Arc<std::sync::Mutex<Unhealthy<ActorSupervisionEvent>>>,
+    actor_event_router: ActorEventRouter,
+}
+
 #[pyclass(
     frozen,
     name = "PythonActorMeshRef",
@@ -266,6 +310,10 @@ impl PythonActorMesh {
 #[derive(Debug, Serialize, Deserialize)]
 pub(super) struct PythonActorMeshRef {
     inner: ActorMeshRef<PythonActor>,
+    #[serde(skip)]
+    /// Monitors the mesh ref if and only if the mesh ref was created locally.
+    /// We cannot monitor a mesh ref that we have received remotely.
+    mesh_monitor: Option<PythonActorMeshRefMonitor>,
 }

 #[pymethods]
@@ -276,6 +324,12 @@ impl PythonActorMeshRef {
         selection: &PySelection,
         message: &PythonMessage,
     ) -> PyResult<()> {
+        if let Some(e) = self.get_supervision_event()? {
+            return Err(SupervisionError::new_err(format!(
+                "Actor {:?} is unhealthy with reason: {}",
+                e.actor_id, e.actor_status
+            )));
+        }
         self.inner
             .cast(&client.inner, selection.inner().clone(), message.clone())
             .map_err(|err| PyException::new_err(err.to_string()))?;
@@ -347,8 +401,91 @@ impl PythonActorMeshRef {
                 ))
             })?;
         }
+        let mesh_monitor = match &self.mesh_monitor {
+            Some(PythonActorMeshRefMonitor {
+                user_monitor_sender,
+                actor_event_router,
+                ..
+            }) => {
+                let user_monitor_sender = user_monitor_sender.clone();
+                let unhealthy_event = Arc::new(std::sync::Mutex::new(Unhealthy::SoFarSoGood));
+                let rx = actor_event_router.bind(self.inner.mesh_id().1.clone());
+                let monitor = tokio::spawn(PythonActorMesh::actor_mesh_monitor(
+                    ActorSupervisionEvents::new(rx, self.inner.mesh_id().clone()),
+                    user_monitor_sender.clone(),
+                    unhealthy_event.clone(),
+                    sliced.shape().clone(),
+                ));
+                Some(PythonActorMeshRefMonitor {
+                    unhealthy_event,
+                    monitor,
+                    user_monitor_sender,
+                    actor_event_router: actor_event_router.clone(),
+                })
+            }
+            None => None,
+        };

-        Ok(Self { inner: sliced })
+        Ok(Self {
+            inner: sliced,
+            mesh_monitor,
+        })
+    }
+
+    fn get_supervision_event(&self) -> PyResult<Option<PyActorSupervisionEvent>> {
+        match &self.mesh_monitor {
+            Some(PythonActorMeshRefMonitor {
+                unhealthy_event, ..
+            }) => {
+                let unhealthy_event = unhealthy_event
+                    .lock()
+                    .expect("failed to acquire unhealthy_event lock");
+
+                match &*unhealthy_event {
+                    Unhealthy::SoFarSoGood => Ok(None),
+                    Unhealthy::StreamClosed => Ok(Some(PyActorSupervisionEvent {
+                        // Dummy actor as placeholder to indicate the whole mesh is stopped.
+                        // TODO(albertli): remove this when pushing all supervision logic to rust.
+                        actor_id: id!(default[0].actor[0]).into(),
+                        actor_status: "actor mesh is stopped due to proc mesh shutdown".to_string(),
+                    })),
+                    Unhealthy::Crashed(event) => {
+                        Ok(Some(PyActorSupervisionEvent::from(event.clone())))
+                    }
+                }
+            }
+            None => Ok(None),
+        }
+    }
+
+    fn supervision_event(&self) -> PyResult<Option<PyPythonTask>> {
+        match &self.mesh_monitor {
+            Some(PythonActorMeshRefMonitor {
+                user_monitor_sender,
+                ..
+            }) => {
+                let mut receiver = user_monitor_sender.subscribe();
+
+                Ok(Some(PyPythonTask::new(async move {
+                    let event = receiver.recv().await;
+                    let event = match event {
+                        Ok(Some(event)) => PyActorSupervisionEvent::from(event.clone()),
+                        Ok(None) | Err(_) => PyActorSupervisionEvent {
+                            // Dummy actor as placeholder to indicate the whole mesh is stopped.
+                            // TODO(albertli): remove this when pushing all supervision logic to rust.
+                            actor_id: id!(default[0].actor[0]).into(),
+                            actor_status: "actor mesh is stopped due to proc mesh shutdown"
+                                .to_string(),
+                        },
+                    };
+                    Ok(PyErr::new::<SupervisionError, _>(format!(
+                        "supervision error: {:?}",
+                        event
+                    )))
+                })?))
+            }
+            None => Ok(None),
+        }
     }

     fn new_with_shape(&self, shape: PyShape) -> PyResult<Self> {
@@ -356,7 +493,56 @@ impl PythonActorMeshRef {
             .inner
             .new_with_shape(shape.get_inner().clone())
             .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
-        Ok(Self { inner: sliced })
+
+        let mesh_monitor = match &self.mesh_monitor {
+            Some(PythonActorMeshRefMonitor {
+                user_monitor_sender,
+                actor_event_router,
+                unhealthy_event,
+                ..
+            }) => {
+                let user_monitor_sender = user_monitor_sender.clone();
+                let unhealthy_event = Arc::new(std::sync::Mutex::new(
+                    match &*unhealthy_event.lock().unwrap_or_else(|e| e.into_inner()) {
+                        Unhealthy::SoFarSoGood => Unhealthy::SoFarSoGood,
+                        Unhealthy::Crashed(event) => {
+                            if sliced
+                                .shape()
+                                .slice()
+                                .iter()
+                                .any(|index| index == event.actor_id.rank())
+                            {
+                                Unhealthy::Crashed(event.clone())
+                            } else {
+                                Unhealthy::SoFarSoGood
+                            }
+                        }
+                        Unhealthy::StreamClosed => Unhealthy::StreamClosed,
+                    },
+                ));
+                let monitor = get_tokio_runtime().spawn(PythonActorMesh::actor_mesh_monitor(
+                    ActorSupervisionEvents::new(
+                        actor_event_router.bind(self.inner.mesh_id().1.clone()),
+                        self.inner.mesh_id().clone(),
+                    ),
+                    user_monitor_sender.clone(),
+                    unhealthy_event.clone(),
+                    sliced.shape().clone(),
+                ));
+                Some(PythonActorMeshRefMonitor {
+                    unhealthy_event,
+                    monitor,
+                    user_monitor_sender,
+                    actor_event_router: actor_event_router.clone(),
+                })
+            }
+            None => None,
+        };
+
+        Ok(Self {
+            inner: sliced,
+            mesh_monitor,
+        })
     }

     #[getter]
@@ -380,7 +566,19 @@ impl PythonActorMeshRef {
     }

     fn __repr__(&self) -> String {
-        format!("{:?}", self)
+        format!(
+            "PythonActorMeshRef {{ inner: {:?}, mesh_monitor: <..> }}",
+            self.inner
+        )
+    }
+}
+
+impl Drop for PythonActorMeshRef {
+    fn drop(&mut self) {
+        match &self.mesh_monitor {
+            Some(PythonActorMeshRefMonitor { monitor, .. }) => monitor.abort(),
+            None => {}
+        }
     }
 }
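The recurring `mesh_shape.slice().iter().any(|index| index == event.actor_id.rank())` check above is the heart of slice-aware supervision: a monitor only marks its mesh unhealthy when the crashed rank is covered by the (possibly sliced) shape it watches. A stand-alone sketch of that filter, with a plain `Vec<usize>` of ranks standing in for `ndslice::Shape` iteration and a simplified `Unhealthy` enum in place of the real one:

```rust
// Simplified stand-in for the crate's Unhealthy<ActorSupervisionEvent>.
#[derive(Debug)]
enum Unhealthy {
    SoFarSoGood,
    Crashed { rank: usize },
    StreamClosed,
}

// Mirror of the monitor's match arm: only a crash on a rank inside the
// watched slice flips the state to Crashed; crashes elsewhere are ignored.
fn apply_event(slice_ranks: &[usize], state: &mut Unhealthy, crashed_rank: usize) {
    if slice_ranks.contains(&crashed_rank) {
        *state = Unhealthy::Crashed { rank: crashed_rank };
    }
}

fn main() {
    // slice(gpus=slice(2, 4)) covers ranks 2 and 3, so it sees a crash on rank 3.
    let mut covering = Unhealthy::SoFarSoGood;
    apply_event(&[2, 3], &mut covering, 3);
    assert!(matches!(covering, Unhealthy::Crashed { rank: 3 }));

    // slice(gpus=2) covers only rank 2 and stays healthy.
    let mut disjoint = Unhealthy::SoFarSoGood;
    apply_event(&[2], &mut disjoint, 3);
    assert!(matches!(disjoint, Unhealthy::SoFarSoGood));
}
```

The same filter runs in two places: continuously in `actor_mesh_monitor`, and once in `new_with_shape` to decide whether an already-recorded crash carries over into the new slice's initial state.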
diff --git a/monarch_hyperactor/src/proc_mesh.rs b/monarch_hyperactor/src/proc_mesh.rs
index f66d68cc2..78ba762ee 100644
--- a/monarch_hyperactor/src/proc_mesh.rs
+++ b/monarch_hyperactor/src/proc_mesh.rs
@@ -284,11 +284,13 @@ impl PyProcMesh {
             let mailbox = proc_mesh.client().clone();
             let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?;
             let actor_events = actor_mesh.with_mut(|a| a.events()).await.unwrap().unwrap();
+            let shape = proc_mesh.shape().clone();
             Ok(PythonActorMesh::monitored(
                 actor_mesh,
                 PyMailbox { inner: mailbox },
                 keepalive,
                 actor_events,
+                shape,
             ))
         })
     }
diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi
index c30f54e08..f0ecc47c3 100644
--- a/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi
+++ b/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi
@@ -54,6 +54,19 @@ class PythonActorMeshRef:
         """
         ...

+    def get_supervision_event(self) -> ActorSupervisionEvent | None:
+        # TODO: remove this when old casting is removed from the python API
+        """
+        Returns a supervision event if there is any.
+        """
+        ...
+
+    def supervision_event(self) -> PythonTask[Exception] | None:
+        """
+        Completes with an exception when there is a supervision error.
+        """
+        ...
+
 @final
 class PythonActorMesh:
     def bind(self) -> PythonActorMeshRef:
diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py
index 33f2dd48b..e9b36c57b 100644
--- a/python/monarch/_src/actor/actor_mesh.py
+++ b/python/monarch/_src/actor/actor_mesh.py
@@ -285,6 +285,15 @@ def new_with_shape(self, shape: Shape) -> "ActorMeshProtocol":
         sliced: PythonActorMeshRef = self._inner.new_with_shape(shape)
         return _PythonActorMeshRefAdapter(sliced)

+    def supervision_event(self) -> "Optional[Shared[Exception]]":
+        supervision_event = self._inner.supervision_event()
+        if supervision_event is None:
+            return None
+        return supervision_event.spawn()
+
+    async def stop(self) -> None:
+        raise NotImplementedError("PythonActorMeshRef.stop() is not supported")
+
     def __reduce_ex__(self, protocol: ...) -> Tuple[Any, Tuple[Any, ...]]:
         """
         Dropping all unpickable states.
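The Python-facing `supervision_event` wrapper above resolves once a supervision error reaches the mesh ref's broadcast channel. The underlying pattern is a one-shot wait on a `tokio::sync::broadcast` subscription, sketched below with `String` events in place of `ActorSupervisionEvent` and the `PyPythonTask` plumbing left out:

```rust
use tokio::sync::broadcast;

// Each waiter holds its own broadcast receiver; the first event (or a
// closed stream) resolves it into an error message. In the real code the
// receiver comes from user_monitor_sender.subscribe() at call time.
async fn next_supervision_error(mut rx: broadcast::Receiver<Option<String>>) -> String {
    match rx.recv().await {
        // A real event: surface it as a supervision error.
        Ok(Some(event)) => format!("supervision error: {event}"),
        // `None` or a closed channel: the whole mesh is gone.
        Ok(None) | Err(_) => "actor mesh is stopped due to proc mesh shutdown".to_string(),
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = broadcast::channel(16);
    let waiter = tokio::spawn(next_supervision_error(rx));
    tx.send(Some("actor 3 crashed".to_string())).unwrap();
    assert_eq!(waiter.await.unwrap(), "supervision error: actor 3 crashed");
}
```

Because the broadcast sender is cloned into every `PythonActorMeshRefMonitor`, each sliced ref can await its own copy of the stream without stealing events from the root mesh, which the tests below exercise.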
diff --git a/python/tests/test_actor_error.py b/python/tests/test_actor_error.py
index 868e46128..cf8ffee2e 100644
--- a/python/tests/test_actor_error.py
+++ b/python/tests/test_actor_error.py
@@ -685,3 +685,38 @@ async def test_supervision_with_sending_error():
         await actor_mesh.check.call()
     with pytest.raises(SupervisionError, match="Actor .* is unhealthy with reason"):
         await actor_mesh.check_with_payload.call(payload="a")
+
+
+async def test_slice_supervision() -> None:
+    pm = await local_proc_mesh(gpus=4)
+    healthy_mesh = await pm.spawn("healthy", HealthyActor)
+    error_mesh = await pm.spawn("error", ErrorActor)
+    slice_1 = error_mesh.slice(gpus=slice(2, 4))
+    slice_2 = error_mesh.slice(gpus=2)
+    slice_3 = error_mesh.slice(gpus=3)
+
+    # Trigger a supervision error on gpus=3
+    with pytest.raises(SupervisionError, match="supervision error:"):
+        await slice_3.fail_with_supervision_error.call()
+
+    # Mesh containing all gpus is unhealthy
+    with pytest.raises(SupervisionError, match="Actor .* is unhealthy with reason:"):
+        await error_mesh.check.call()
+
+    # Slice containing only gpus=3 is unhealthy
+    with pytest.raises(SupervisionError, match="Actor .* is unhealthy with reason:"):
+        await slice_3.check.call()
+
+    # Slice containing gpus=3 is unhealthy
+    with pytest.raises(SupervisionError, match="Actor .* is unhealthy with reason:"):
+        await slice_1.check.call()
+
+    # Slice not containing gpus=3 is healthy
+    check = await slice_2.check.call()
+    for _, item in check.items():
+        assert item == "this is a healthy check"
+
+    # Other actor mesh on the same proc mesh is healthy
+    check = await healthy_mesh.check.call()
+    for _, item in check.items():
+        assert item == "this is a healthy check"