From 44d63ce0f6dad58875cbecefa87ba729fa6c2bc5 Mon Sep 17 00:00:00 2001 From: Sam Lurye Date: Mon, 24 Nov 2025 11:13:05 -0800 Subject: [PATCH] [monarch] The root client is just a PythonActor This diff makes the root client actor just another `PythonActor`. # Why? Right now the monarch codebase is peppered with special handling to distinguish between normal python actors and the root client "actor", which has type `()` and is actually just a detached `Instance` with no actor loop; it therefore has no message handlers and can't even process supervision events. As a result, we have to wrap the current context's instance in a special `ContextInstance` enum, and everywhere we want to use it, we either have to use the `instance_dispatch!` macro, or insert code that looks like: ```rust match instance { ContextInstance::PythonActor(ins) => { do something }, ContextInstance::Client(ins) => { do something else }, } ``` This makes the code more error-prone and harder to understand, with the added complication that the client handling is often not idiomatic w.r.t. hyperactor due to the lack of message handlers/actor loop. Some examples: - [Confusing supervision handling where `owner` might not be defined but `is_owned` is still true and so we need to call into a special `unhandled` function instead of continuing to propagate up the hierarchy](https://fburl.com/code/andy3ggr) - [The root client can't have child actors due to no supervision event handling, so they have to be spawned directly on the root client proc, and even then, there is no way for the supervision event to reach `monarch.actor.unhandled_fault_hook`](https://fburl.com/code/kqd2iwvc) - [The root client handles undeliverable messages via a bespoke tokio task/thread](https://fburl.com/code/jjgfy5d5) Making the root client a normal python actor solves these problems, because: - We don't need a `ContextInstance` enum anymore -- `PyInstance` *always* contains `Instance<PythonActor>`. 
- Supervision events follow a unified path as they bubble up through the hierarchy, and *every* unhandled event reaches `RootClientActor.__supervise__`, defined in python, without special handling. - The root client can handle undeliverable messages using `RootClientActor._handle_undeliverable_message`, defined in python, without special handling. # Navigating the code changes (guide for reviewers) There are a lot of file changes here but only some of them are important. I would recommend reviewing them in the following order: - `monarch/_src/actor/actor_mesh.py` - Defines the `RootClientActor` python class and its behavior. - `hyperactor/src/proc.rs` - Introduces `Proc::actor_instance::<A>(...)`, which returns a detached `A`-typed actor instance/handle, along with its supervision receiver, signal receiver and message receiver. - `monarch_hyperactor/src/actor.rs` - Introduces `PythonActor::bootstrap_client()`, which replaces `global_root_client()` in the root client context. This function starts the root client proc, spawns the `RootClientActor`, starts its actor loop and returns the `Instance<PythonActor>`. - The root client actor can now handle `SupervisionFailureMessage` just like every other actor in the hierarchy. - Implements `PythonActor::handle_supervision_event` to pass the event to the actor's `SupervisionFailureMessage` handler. This way, **every unhandled supervision event in the system makes its way to `RootClientActor.__supervise__` eventually**. - `monarch_hyperactor/src/v1/actor_mesh.rs` - Deletes the special handling from the actor states monitor like `is_owned` and the explicit `unhandled_fault_hook` call. If `owner` is defined, it forwards the `SupervisionFailureMessage`, or else it does nothing. - Fixes (what I think was) a bug in `send_state_change`. A supervision event should only be forwarded as `SupervisionFailureMessage` to `owner` if it represents a failure. 
With the logic before this diff, stopping an actor mesh from inside an actor endpoint would generate a supervision event that reaches `unhandled_fault_hook` and crashes the root process even if it was a healthy stop. - `monarch_hyperactor/src/context.rs` - Deletes `ContextInstance` and replaces it in `PyInstance` with `Instance`. - The rest of the changes are pretty much just cleaning up `instance_dispatch!` calls. Differential Revision: [D87296357](https://our.internmc.facebook.com/intern/diff/D87296357/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D87296357/)! [ghstack-poisoned] --- hyperactor/src/actor.rs | 6 +- hyperactor/src/proc.rs | 69 +++-- hyperactor_mesh/src/lib.rs | 2 +- monarch_extension/src/code_sync.rs | 91 +++--- monarch_extension/src/logging.rs | 13 +- monarch_extension/src/mesh_controller.rs | 27 +- monarch_hyperactor/src/actor.rs | 192 ++++++++++-- monarch_hyperactor/src/actor_mesh.rs | 24 +- .../src/code_sync/conda_sync.rs | 1 - monarch_hyperactor/src/code_sync/manager.rs | 4 +- monarch_hyperactor/src/code_sync/rsync.rs | 5 +- monarch_hyperactor/src/context.rs | 148 ++-------- monarch_hyperactor/src/mailbox.rs | 18 +- monarch_hyperactor/src/proc.rs | 5 +- monarch_hyperactor/src/proc_mesh.rs | 6 - monarch_hyperactor/src/supervision.rs | 20 +- monarch_hyperactor/src/v1/actor_mesh.rs | 273 ++++-------------- monarch_hyperactor/src/v1/host_mesh.rs | 24 +- monarch_hyperactor/src/v1/logging.rs | 61 ++-- monarch_hyperactor/src/v1/proc_mesh.rs | 37 ++- monarch_rdma/extension/lib.rs | 106 +++---- python/monarch/_src/actor/actor_mesh.py | 62 +++- python/tests/test_python_actors.py | 1 + 23 files changed, 557 insertions(+), 638 deletions(-) diff --git a/hyperactor/src/actor.rs b/hyperactor/src/actor.rs index c574cdfdb..590503c22 100644 --- a/hyperactor/src/actor.rs +++ b/hyperactor/src/actor.rs @@ -312,8 +312,10 @@ where /// with the ID of 
the actor being served. #[derive(Debug)] pub struct ActorError { - pub(crate) actor_id: Box, - pub(crate) kind: Box, + /// The ActorId for the actor that generated this error. + pub actor_id: Box, + /// The kind of error that occurred. + pub kind: Box, } /// The kinds of actor serving errors. diff --git a/hyperactor/src/proc.rs b/hyperactor/src/proc.rs index 093cf5476..b71d653d9 100644 --- a/hyperactor/src/proc.rs +++ b/hyperactor/src/proc.rs @@ -411,7 +411,9 @@ impl Proc { .map_err(|existing| anyhow::anyhow!("coordinator port is already set to {existing}")) } - fn handle_supervision_event(&self, event: ActorSupervisionEvent) { + /// Handle a supervision event received by the proc. Attempt to forward it to the + /// supervision coordinator port if one is set, otherwise crash the process. + pub fn handle_supervision_event(&self, event: ActorSupervisionEvent) { let result = match self.state().supervision_coordinator_port.get() { Some(port) => port.send(event.clone()).map_err(anyhow::Error::from), None => Err(anyhow::anyhow!( @@ -530,26 +532,46 @@ impl Proc { Ok(instance.start(actor, actor_loop_receivers.take().unwrap(), work_rx)) } - /// Create and return an actor instance and its corresponding handle. This allows actors to be - /// "inverted": the caller can use the returned [`Instance`] to send and receive messages, - /// launch child actors, etc. The actor itself does not handle any messages, and supervision events - /// are always forwarded to the proc. Otherwise the instance acts as a normal actor, and can be - /// referenced and stopped. + /// Wrapper for [`Proc::actor_instance::<()>`]. pub fn instance(&self, name: &str) -> Result<(Instance<()>, ActorHandle<()>), anyhow::Error> { + let (instance, handle, ..) = self.actor_instance(name)?; + + Ok((instance, handle)) + } + + /// Create and return an actor instance, its corresponding handle, its signal port receiver, + /// its supervision port receiver, and its message receiver. 
This allows actors to be + /// "inverted": the caller can use the returned [`Instance`] to send and receive messages, + /// launch child actors, etc. The actor itself does not handle any messages unless driven by + /// the caller. Otherwise the instance acts as a normal actor, and can be referenced and + /// stopped. + pub fn actor_instance( + &self, + name: &str, + ) -> Result< + ( + Instance, + ActorHandle, + PortReceiver, + PortReceiver, + mpsc::UnboundedReceiver>, + ), + anyhow::Error, + > { let actor_id = self.allocate_root_id(name)?; - let _ = tracing::debug_span!( + let span = tracing::debug_span!( "actor_instance", actor_name = name, - actor_type = std::any::type_name::<()>(), + actor_type = std::any::type_name::(), actor_id = actor_id.to_string(), ); - - let (instance, _, _) = Instance::new(self.clone(), actor_id.clone(), true, None); + let _guard = span.enter(); + let (instance, actor_loop_receivers, work_rx) = + Instance::new(self.clone(), actor_id.clone(), false, None); + let (signal_rx, supervision_rx) = actor_loop_receivers.unwrap(); let handle = ActorHandle::new(instance.inner.cell.clone(), instance.inner.ports.clone()); - instance.change_status(ActorStatus::Client); - - Ok((instance, handle)) + Ok((instance, handle, supervision_rx, signal_rx, work_rx)) } /// Create a child instance. Called from `Instance`. @@ -874,11 +896,11 @@ impl MailboxSender for WeakProc { /// Represents a single work item used by the instance to dispatch to /// actor handles. Specifically, this enables handler polymorphism. -struct WorkCell( +pub struct WorkCell( Box< dyn for<'a> FnOnce( &'a mut A, - &'a mut Instance, + &'a Instance, ) -> Pin> + 'a + Send>> + Send @@ -891,7 +913,7 @@ impl WorkCell { fn new( f: impl for<'a> FnOnce( &'a mut A, - &'a mut Instance, + &'a Instance, ) -> Pin> + 'a + Send>> + Send @@ -902,10 +924,10 @@ impl WorkCell { } /// Handle the message represented by this work cell. 
- fn handle<'a>( + pub fn handle<'a>( self, actor: &'a mut A, - instance: &'a mut Instance, + instance: &'a Instance, ) -> Pin> + Send + 'a>> { (self.0)(actor, instance) } @@ -1451,7 +1473,8 @@ impl Instance { Ok(()) } - async fn handle_supervision_event( + /// Handle a supervision event using the provided actor. + pub async fn handle_supervision_event( &self, actor: &mut A, supervision_event: ActorSupervisionEvent, @@ -1483,7 +1506,7 @@ impl Instance { #[hyperactor::instrument(fields(actor_id = self.self_id().to_string(), actor_name = self.self_id().name()))] async unsafe fn handle_message( - &mut self, + &self, actor: &mut A, type_info: Option<&'static TypeInfo>, headers: Attrs, @@ -1519,8 +1542,8 @@ impl Instance { actor.handle(&context, message).await } - /// Spawn on child on this instance. Currently used only by cap::CanSpawn. - pub(crate) fn spawn(&self, actor: C) -> anyhow::Result> { + /// Spawn on child on this instance. + pub fn spawn(&self, actor: C) -> anyhow::Result> { self.inner.proc.spawn_child(self.inner.cell.clone(), actor) } @@ -2041,7 +2064,7 @@ impl Ports { let port = self.mailbox.open_enqueue_port(move |headers, msg: M| { let seq_info = headers.get(SEQ_INFO).cloned(); - let work = WorkCell::new(move |actor: &mut A, instance: &mut Instance| { + let work = WorkCell::new(move |actor: &mut A, instance: &Instance| { Box::pin(async move { // SAFETY: we guarantee that the passed type_info is for type M. 
unsafe { diff --git a/hyperactor_mesh/src/lib.rs b/hyperactor_mesh/src/lib.rs index 160b6f1fb..0be83983b 100644 --- a/hyperactor_mesh/src/lib.rs +++ b/hyperactor_mesh/src/lib.rs @@ -29,7 +29,7 @@ mod metrics; pub mod proc_mesh; pub mod reference; pub mod resource; -mod router; +pub mod router; pub mod shared_cell; pub mod shortuuid; #[cfg(target_os = "linux")] diff --git a/monarch_extension/src/code_sync.rs b/monarch_extension/src/code_sync.rs index bd49534d3..060152764 100644 --- a/monarch_extension/src/code_sync.rs +++ b/monarch_extension/src/code_sync.rs @@ -18,6 +18,7 @@ use hyperactor::context; use hyperactor_mesh::Mesh; use hyperactor_mesh::RootActorMesh; use hyperactor_mesh::shared_cell::SharedCell; +use monarch_hyperactor; use monarch_hyperactor::code_sync::WorkspaceLocation; use monarch_hyperactor::code_sync::manager::CodeSyncManager; use monarch_hyperactor::code_sync::manager::CodeSyncManagerParams; @@ -27,8 +28,6 @@ use monarch_hyperactor::code_sync::manager::WorkspaceConfig; use monarch_hyperactor::code_sync::manager::WorkspaceShape; use monarch_hyperactor::code_sync::manager::code_sync_mesh; use monarch_hyperactor::context::PyInstance; -use monarch_hyperactor::instance_dispatch; -use monarch_hyperactor::instance_into_dispatch; use monarch_hyperactor::proc_mesh::PyProcMesh; use monarch_hyperactor::runtime::signal_safe_block_on; use monarch_hyperactor::v1::proc_mesh::PyProcMesh as PyProcMeshV1; @@ -279,32 +278,34 @@ impl CodeSyncMeshClient { if let Ok(v0) = proc_mesh.downcast::() { let proc_mesh = v0.borrow().try_inner()?; signal_safe_block_on(py, async move { - let actor_mesh = instance_dispatch!(client, |cx| { - proc_mesh - .spawn(cx, "code_sync_manager", &CodeSyncManagerParams {}) - .await? - }); + let actor_mesh = proc_mesh + .spawn( + client.deref(), + "code_sync_manager", + &CodeSyncManagerParams {}, + ) + .await?; Ok(Self { actor_mesh }) })? 
} else { let proc_mesh = proc_mesh.downcast::()?.borrow().mesh_ref()?; signal_safe_block_on(py, async move { - let actor_mesh = instance_dispatch!(client, |cx| { - proc_mesh - .spawn_service(cx, "code_sync_manager", &CodeSyncManagerParams {}) - .await - .map_err(|err| PyException::new_err(err.to_string()))? - }); - instance_dispatch!(client, |cx| { - actor_mesh - .cast( - cx, - SetActorMeshMessage { - actor_mesh: actor_mesh.deref().clone(), - }, - ) - .map_err(|err| PyException::new_err(err.to_string()))? - }); + let actor_mesh = proc_mesh + .spawn_service( + client.deref(), + "code_sync_manager", + &CodeSyncManagerParams {}, + ) + .await + .map_err(|err| PyException::new_err(err.to_string()))?; + actor_mesh + .cast( + client.deref(), + SetActorMeshMessage { + actor_mesh: actor_mesh.deref().clone(), + }, + ) + .map_err(|err| PyException::new_err(err.to_string()))?; Ok(Self { actor_mesh: SharedCell::from(RootActorMesh::from(actor_mesh)), }) @@ -324,19 +325,17 @@ impl CodeSyncMeshClient { ) -> PyResult> { let instance = instance.clone(); let actor_mesh = self.actor_mesh.clone(); - instance_into_dispatch!(instance, |cx| { - monarch_hyperactor::runtime::future_into_py(py, async move { - CodeSyncMeshClient::sync_workspace_( - &cx, - actor_mesh, - local, - remote, - method.into(), - auto_reload, - ) - .err_into() - .await - }) + monarch_hyperactor::runtime::future_into_py(py, async move { + CodeSyncMeshClient::sync_workspace_( + instance.deref(), + actor_mesh, + local, + remote, + method.into(), + auto_reload, + ) + .err_into() + .await }) } @@ -354,17 +353,15 @@ impl CodeSyncMeshClient { py, async move { for workspace in workspaces.into_iter() { - instance_dispatch!(instance, async |cx| { - CodeSyncMeshClient::sync_workspace_( - cx, - actor_mesh.clone(), - workspace.local, - workspace.remote, - workspace.method.into(), - auto_reload, - ) - .await - })?; + CodeSyncMeshClient::sync_workspace_( + instance.deref(), + actor_mesh.clone(), + workspace.local, + 
workspace.remote, + workspace.method.into(), + auto_reload, + ) + .await?; } anyhow::Ok(()) } diff --git a/monarch_extension/src/logging.rs b/monarch_extension/src/logging.rs index 9cd5ca51c..3c7e6270d 100644 --- a/monarch_extension/src/logging.rs +++ b/monarch_extension/src/logging.rs @@ -8,6 +8,7 @@ #![allow(unsafe_op_in_unsafe_fn)] +use std::ops::Deref; use std::time::Duration; use hyperactor::ActorHandle; @@ -22,7 +23,6 @@ use hyperactor_mesh::logging::LogForwardMessage; use hyperactor_mesh::selection::Selection; use hyperactor_mesh::shared_cell::SharedCell; use monarch_hyperactor::context::PyInstance; -use monarch_hyperactor::instance_dispatch; use monarch_hyperactor::logging::LoggerRuntimeActor; use monarch_hyperactor::logging::LoggerRuntimeMessage; use monarch_hyperactor::proc_mesh::PyProcMesh; @@ -94,13 +94,10 @@ impl LoggingMeshClient { .client_proc() .spawn("log_client", LogClientActor::default())?; let client_actor_ref = client_actor.bind(); - let forwarder_mesh = instance_dispatch!(instance, |cx| { - proc_mesh - .spawn(cx, "log_forwarder", &client_actor_ref) - .await? - }); - let logger_mesh = - instance_dispatch!(instance, |cx| { proc_mesh.spawn(cx, "logger", &()).await? 
}); + let forwarder_mesh = proc_mesh + .spawn(instance.deref(), "log_forwarder", &client_actor_ref) + .await?; + let logger_mesh = proc_mesh.spawn(instance.deref(), "logger", &()).await?; // Register flush_internal as a on-stop callback let client_actor_for_callback = client_actor.clone(); diff --git a/monarch_extension/src/mesh_controller.rs b/monarch_extension/src/mesh_controller.rs index 952defac8..49208542a 100644 --- a/monarch_extension/src/mesh_controller.rs +++ b/monarch_extension/src/mesh_controller.rs @@ -46,7 +46,6 @@ use monarch_hyperactor::actor::PythonMessage; use monarch_hyperactor::actor::PythonMessageKind; use monarch_hyperactor::buffers::FrozenBuffer; use monarch_hyperactor::context::PyInstance; -use monarch_hyperactor::instance_dispatch; use monarch_hyperactor::local_state_broker::LocalStateBrokerActor; use monarch_hyperactor::mailbox::PyPortId; use monarch_hyperactor::ndslice::PySlice; @@ -140,17 +139,14 @@ impl _Controller { let id = NEXT_ID.fetch_add(1, atomic::Ordering::Relaxed); let controller_handle: Arc>> = signal_safe_block_on(py, async move { - let controller_handle = instance_dispatch!(client, |instance| { - instance.proc().spawn( - &Name::new("mesh_controller").to_string(), - MeshControllerActor::new(MeshControllerActorParams { - proc_mesh, - id, - rank_map, - }) - .await, - )? 
- }); + let controller_handle = client.spawn( + MeshControllerActor::new(MeshControllerActorParams { + proc_mesh, + id, + rank_map, + }) + .await, + )?; Ok::<_, anyhow::Error>(Arc::new(Mutex::new(controller_handle))) })??; @@ -231,8 +227,7 @@ impl _Controller { } fn _drain_and_stop(&mut self, py: Python<'_>, instance: &PyInstance) -> PyResult<()> { - let (stop_worker_port, stop_worker_receiver) = - instance_dispatch!(instance, |cx_instance| { cx_instance.open_once_port() }); + let (stop_worker_port, stop_worker_receiver) = instance.open_once_port(); self.controller_handle .blocking_lock() @@ -817,6 +812,10 @@ impl Actor for MeshControllerActor { self.brokers = Some(brokers); Ok(()) } + + fn display_name(&self) -> Option { + Some(format!("mesh_controller_{}", self.id)) + } } impl Debug for MeshControllerActor { diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index bc6aa34b3..12f6c38ff 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -21,12 +21,15 @@ use hyperactor::Instance; use hyperactor::Named; use hyperactor::OncePortHandle; use hyperactor::PortHandle; +use hyperactor::Proc; use hyperactor::ProcId; use hyperactor::RemoteSpawn; use hyperactor::actor::ActorError; use hyperactor::actor::ActorErrorKind; use hyperactor::actor::ActorStatus; use hyperactor::attrs::Attrs; +use hyperactor::channel::ChannelAddr; +use hyperactor::mailbox::BoxableMailboxSender; use hyperactor::mailbox::MessageEnvelope; use hyperactor::mailbox::Undeliverable; use hyperactor::message::Bind; @@ -37,7 +40,9 @@ use hyperactor::supervision::ActorSupervisionEvent; use hyperactor_mesh::actor_mesh::CAST_ACTOR_MESH_ID; use hyperactor_mesh::comm::multicast::CAST_ORIGINATING_SENDER; use hyperactor_mesh::comm::multicast::CastInfo; +use hyperactor_mesh::proc_mesh::default_transport; use hyperactor_mesh::reference::ActorMeshId; +use hyperactor_mesh::router; use monarch_types::PickledPyObject; use monarch_types::SerializablePyErr; use 
pyo3::IntoPyObjectExt; @@ -75,6 +80,7 @@ use crate::proc::PyActorId; use crate::proc::PyProc; use crate::proc::PySerialized; use crate::pytokio::PythonTask; +use crate::runtime::get_tokio_runtime; use crate::runtime::signal_safe_block_on; use crate::supervision::MeshFailure; use crate::supervision::SupervisionFailureMessage; @@ -503,6 +509,23 @@ pub struct PythonActor { } impl PythonActor { + pub(crate) fn new(actor_type: PickledPyObject) -> Result { + Ok(Python::with_gil(|py| -> Result { + let unpickled = actor_type.unpickle(py)?; + let class_type: &Bound<'_, PyType> = unpickled.downcast()?; + let actor: PyObject = class_type.call0()?.into_py_any(py)?; + + // Only create per-actor TaskLocals if not using shared runtime + let task_locals = (!hyperactor::config::global::get(SHARED_ASYNCIO_RUNTIME)) + .then(|| Python::allow_threads(py, create_task_locals)); + Ok(Self { + actor, + task_locals, + instance: None, + }) + })?) + } + /// Get the TaskLocals to use for this actor. /// Returns either the shared TaskLocals or this actor's own TaskLocals based on configuration. fn get_task_locals(&self, py: Python) -> &pyo3_async_runtimes::TaskLocals { @@ -512,6 +535,126 @@ impl PythonActor { Python::allow_threads(py, || SHARED_TASK_LOCALS.get_or_init(create_task_locals)) }) } + + pub(crate) fn bootstrap_client(py: Python<'_>) -> (&'static Instance, ActorHandle) { + static ROOT_CLIENT_INSTANCE: OnceLock> = OnceLock::new(); + + let client_proc = Proc::direct_with_default( + ChannelAddr::any(default_transport()), + "mesh_root_client_proc".into(), + router::global().clone().boxed(), + ) + .unwrap(); + + // Make this proc reachable through the global router, so that we can use the + // same client in both direct-addressed and ranked-addressed modes. 
+ router::global().bind(client_proc.proc_id().clone().into(), client_proc.clone()); + + let actor_mesh_mod = py + .import("monarch._src.actor.actor_mesh") + .expect("import actor_mesh"); + let root_client_class = actor_mesh_mod + .getattr("RootClientActor") + .expect("get RootClientActor"); + + let mut actor = PythonActor::new( + PickledPyObject::pickle(&actor_mesh_mod.getattr("_Actor").expect("get _Actor")) + .expect("pickle _Actor"), + ) + .expect("create client PythonActor"); + + let (client, handle, supervision_rx, signal_rx, work_rx) = client_proc + .actor_instance( + root_client_class + .getattr("name") + .expect("get RootClientActor.name") + .extract() + .expect("extract RootClientActor.name"), + ) + .expect("root instance create"); + + ROOT_CLIENT_INSTANCE + .set(client) + .map_err(|_| "already initialized root client instance") + .unwrap(); + + handle + .send( + PythonMessage::new( + PythonMessageKind::CallMethod { + name: MethodSpecifier::Init {}, + response_port: None, + }, + root_client_class + .call_method0("_pickled_init_args") + .expect("call RootClientActor._pickled_init_args"), + ) + .expect("create RootClientActor init message"), + ) + .expect("initialize root client"); + + let instance = ROOT_CLIENT_INSTANCE.get().unwrap(); + + get_tokio_runtime().spawn(async move { + let mut signal_rx = signal_rx; + let mut supervision_rx = supervision_rx; + let mut work_rx = work_rx; + let err = 'messages: loop { + tokio::select! 
{ + work = work_rx.recv() => { + let work = work.expect("inconsistent work queue state"); + if let Err(err) = work.handle(&mut actor, instance).await { + for supervision_event in supervision_rx.drain() { + if let Err(err) = instance.handle_supervision_event(&mut actor, supervision_event).await { + break 'messages err; + } + } + let kind = ActorErrorKind::processing(err); + break ActorError { + actor_id: Box::new(instance.self_id().clone()), + kind: Box::new(kind), + }; + } + } + _ = signal_rx.recv() => { + // TODO: do we need any signal handling for the root client? + } + Ok(supervision_event) = supervision_rx.recv() => { + if let Err(err) = instance.handle_supervision_event(&mut actor, supervision_event).await { + break err; + } + } + }; + }; + let event = match *err.kind { + ActorErrorKind::UnhandledSupervisionEvent(event) => *event, + _ => { + let error_kind = ActorErrorKind::Generic(err.kind.to_string()); + let status = ActorStatus::Failed(error_kind); + ActorSupervisionEvent::new( + instance.self_id().clone(), + actor.display_name(), + status, + None, + ) + } + }; + instance.proc().handle_supervision_event(event); + }); + + (ROOT_CLIENT_INSTANCE.get().unwrap(), handle) + } +} + +pub(crate) fn root_client_actor() -> &'static Instance { + static ROOT_CLIENT_ACTOR: OnceLock<&'static Instance> = OnceLock::new(); + + ROOT_CLIENT_ACTOR.get_or_init(|| { + Python::with_gil(|py| { + let (client, _handle) = PythonActor::bootstrap_client(py); + client + }) + }) } /// An undeliverable might have its sender address set as the comm actor instead @@ -689,6 +832,24 @@ impl Actor for PythonActor { Ok(()) } } + + async fn handle_supervision_event( + &mut self, + this: &Instance, + event: &ActorSupervisionEvent, + ) -> Result { + let cx = Context::new(this, Attrs::new()); + self.handle( + &cx, + SupervisionFailureMessage { + actor_mesh_name: None, + rank: None, + event: event.clone(), + }, + ) + .await + .map(|_| true) + } } #[async_trait] @@ -696,20 +857,7 @@ impl RemoteSpawn 
for PythonActor { type Params = PickledPyObject; async fn new(actor_type: PickledPyObject) -> Result { - Ok(Python::with_gil(|py| -> Result { - let unpickled = actor_type.unpickle(py)?; - let class_type: &Bound<'_, PyType> = unpickled.downcast()?; - let actor: PyObject = class_type.call0()?.into_py_any(py)?; - - // Only create per-actor TaskLocals if not using shared runtime - let task_locals = (!hyperactor::config::global::get(SHARED_ASYNCIO_RUNTIME)) - .then(|| Python::allow_threads(py, create_task_locals)); - Ok(Self { - actor, - task_locals, - instance: None, - }) - })?) + Self::new(actor_type) } } @@ -725,6 +873,8 @@ fn create_task_locals() -> pyo3_async_runtimes::TaskLocals { let kwargs = PyDict::new(py); let target = event_loop.getattr("run_forever").unwrap(); kwargs.set_item("target", target).unwrap(); + // Need to make this a daemon thread, otherwise shutdown will hang. + kwargs.set_item("daemon", true).unwrap(); let thread = py .import("threading") .unwrap() @@ -923,7 +1073,10 @@ impl Handler for PythonActor { // this actor is now the event creator. for (actor_name, status) in [ ( - message.actor_mesh_name.as_str(), + message + .actor_mesh_name + .as_deref() + .unwrap_or_else(|| message.event.actor_id.name()), "SupervisionError::Unhandled", ), (cx.self_id().name(), "UnhandledSupervisionEvent"), @@ -942,7 +1095,7 @@ impl Handler for PythonActor { cx.self_id().clone(), self.display_name(), ActorStatus::Failed(ActorErrorKind::UnhandledSupervisionEvent( - Box::new(message.event), + Box::new(message.event.clone()), )), None, ), @@ -958,7 +1111,10 @@ impl Handler for PythonActor { // Add to caused_by chain. 
for (actor_name, status) in [ ( - message.actor_mesh_name.as_str(), + message + .actor_mesh_name + .as_deref() + .unwrap_or_else(|| message.event.actor_id.name()), "SupervisionError::__supervise__::exception", ), (cx.self_id().name(), "UnhandledSupervisionEvent"), @@ -978,7 +1134,7 @@ impl Handler for PythonActor { self.display_name(), ActorStatus::Failed(ActorErrorKind::ErrorDuringHandlingSupervision( err.to_string(), - Box::new(message.event), + Box::new(message.event.clone()), )), None, ), diff --git a/monarch_hyperactor/src/actor_mesh.rs b/monarch_hyperactor/src/actor_mesh.rs index 5d53eff39..7a507225d 100644 --- a/monarch_hyperactor/src/actor_mesh.rs +++ b/monarch_hyperactor/src/actor_mesh.rs @@ -7,6 +7,7 @@ */ use std::future::Future; +use std::ops::Deref; use std::pin::Pin; use std::sync::Arc; use std::sync::Weak; @@ -43,7 +44,6 @@ use crate::actor::PythonActor; use crate::actor::PythonMessage; use crate::actor::PythonMessageKind; use crate::context::PyInstance; -use crate::instance_dispatch; use crate::mailbox::EitherPortRef; use crate::mailbox::PyMailbox; use crate::proc::PyActorId; @@ -306,11 +306,9 @@ impl ActorMeshProtocol for PythonActorMeshImpl { } } - instance_dispatch!(instance, |cx_instance| { - self.try_inner()? - .cast(cx_instance, selection, message) - .map_err(|err| PyException::new_err(err.to_string()))?; - }); + self.try_inner()? 
+ .cast(instance.deref(), selection, message) + .map_err(|err| PyException::new_err(err.to_string()))?; Ok(()) } @@ -351,10 +349,8 @@ impl ActorMeshProtocol for PythonActorMeshImpl { .take() .await .map_err(|_| PyRuntimeError::new_err("`ActorMesh` has already been stopped"))?; - instance_dispatch!(instance, |cx_instance| { - actor_mesh.stop(cx_instance).await.map_err(|err| { - PyException::new_err(format!("Failed to stop actor mesh: {}", err)) - }) + actor_mesh.stop(instance.deref()).await.map_err(|err| { + PyException::new_err(format!("Failed to stop actor mesh: {}", err)) })?; Ok(()) }) @@ -485,11 +481,9 @@ impl ActorMeshProtocol for PythonActorMeshRef { } } - instance_dispatch!(instance, |cx_instance| { - self.inner - .cast(cx_instance, selection, message) - .map_err(|err| PyException::new_err(err.to_string()))?; - }); + self.inner + .cast(instance.deref(), selection, message) + .map_err(|err| PyException::new_err(err.to_string()))?; Ok(()) } diff --git a/monarch_hyperactor/src/code_sync/conda_sync.rs b/monarch_hyperactor/src/code_sync/conda_sync.rs index 517f27716..bb761ba53 100644 --- a/monarch_hyperactor/src/code_sync/conda_sync.rs +++ b/monarch_hyperactor/src/code_sync/conda_sync.rs @@ -19,7 +19,6 @@ use hyperactor::Bind; use hyperactor::Handler; use hyperactor::Named; use hyperactor::PortRef; -use hyperactor::RemoteSpawn; use hyperactor::Unbind; use hyperactor_mesh::actor_mesh::ActorMesh; use hyperactor_mesh::connect::Connect; diff --git a/monarch_hyperactor/src/code_sync/manager.rs b/monarch_hyperactor/src/code_sync/manager.rs index 424ef92d9..ee0539288 100644 --- a/monarch_hyperactor/src/code_sync/manager.rs +++ b/monarch_hyperactor/src/code_sync/manager.rs @@ -558,13 +558,13 @@ mod tests { use hyperactor_mesh::alloc::Allocator; use hyperactor_mesh::alloc::local::LocalAllocator; use hyperactor_mesh::proc_mesh::ProcMesh; - use hyperactor_mesh::proc_mesh::global_root_client; use ndslice::extent; use ndslice::shape; use tempfile::TempDir; use tokio::fs; use 
super::*; + use crate::actor::root_client_actor; #[test] fn test_workspace_shape_owners() { @@ -663,7 +663,7 @@ mod tests { // TODO: thread through context, or access the actual python context; // for now this is basically equivalent (arguably better) to using the proc mesh client. - let instance = global_root_client(); + let instance = root_client_actor(); // Spawn actor mesh with CodeSyncManager actors let actor_mesh = proc_mesh diff --git a/monarch_hyperactor/src/code_sync/rsync.rs b/monarch_hyperactor/src/code_sync/rsync.rs index 82f19c935..ee2203dd2 100644 --- a/monarch_hyperactor/src/code_sync/rsync.rs +++ b/monarch_hyperactor/src/code_sync/rsync.rs @@ -28,7 +28,6 @@ use hyperactor::Bind; use hyperactor::Handler; use hyperactor::Named; use hyperactor::PortRef; -use hyperactor::RemoteSpawn; use hyperactor::Unbind; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; @@ -462,13 +461,13 @@ mod tests { use hyperactor_mesh::alloc::Allocator; use hyperactor_mesh::alloc::local::LocalAllocator; use hyperactor_mesh::proc_mesh::ProcMesh; - use hyperactor_mesh::proc_mesh::global_root_client; use ndslice::extent; use tempfile::TempDir; use tokio::fs; use tokio::net::TcpListener; use super::*; + use crate::actor::root_client_actor; #[tokio::test] // TODO: OSS: Cannot assign requested address (os error 99) @@ -520,7 +519,7 @@ mod tests { // TODO: thread through context, or access the actual python context; // for now this is basically equivalent (arguably better) to using the proc mesh client. - let instance = global_root_client(); + let instance = root_client_actor(); // Spawn actor mesh with RsyncActors let actor_mesh = proc_mesh diff --git a/monarch_hyperactor/src/context.rs b/monarch_hyperactor/src/context.rs index aaee0f04b..74196b8f2 100644 --- a/monarch_hyperactor/src/context.rs +++ b/monarch_hyperactor/src/context.rs @@ -6,104 +6,23 @@ * LICENSE file in the root directory of this source tree. 
*/ +use hyperactor::Instance; +use hyperactor::context; use hyperactor_mesh::comm::multicast::CastInfo; -use hyperactor_mesh::proc_mesh::global_root_client; use ndslice::Extent; use ndslice::Point; use pyo3::prelude::*; use crate::actor::PythonActor; +use crate::actor::root_client_actor; use crate::mailbox::PyMailbox; use crate::proc::PyActorId; use crate::runtime; use crate::shape::PyPoint; -pub enum ContextInstance { - Client(hyperactor::Instance<()>), - PythonActor(hyperactor::Instance), -} - -impl ContextInstance { - fn mailbox_for_py(&self) -> &hyperactor::Mailbox { - match self { - ContextInstance::Client(ins) => ins.mailbox_for_py(), - ContextInstance::PythonActor(ins) => ins.mailbox_for_py(), - } - } - - fn self_id(&self) -> &hyperactor::ActorId { - match self { - ContextInstance::Client(ins) => ins.self_id(), - ContextInstance::PythonActor(ins) => ins.self_id(), - } - } -} - -impl Clone for ContextInstance { - fn clone(&self) -> Self { - match self { - ContextInstance::Client(ins) => ContextInstance::Client(ins.clone_for_py()), - ContextInstance::PythonActor(ins) => ContextInstance::PythonActor(ins.clone_for_py()), - } - } -} - -#[macro_export] -macro_rules! 
instance_dispatch { - ($ins:expr, |$cx:ident| $code:block) => { - match $ins.context_instance() { - $crate::context::ContextInstance::Client($cx) => $code, - $crate::context::ContextInstance::PythonActor($cx) => $code, - } - }; - ($ins:expr, |$cx:ident| $code:block) => { - match $ins.into_context_instance() { - $crate::context::ContextInstance::Client($cx) => $code, - $crate::context::ContextInstance::PythonActor($cx) => $code, - } - }; - ($ins:expr, async |$cx:ident| $code:block) => { - match $ins.context_instance() { - $crate::context::ContextInstance::Client($cx) => async $code.await, - $crate::context::ContextInstance::PythonActor($cx) => async $code.await, - } - }; - ($ins:expr, async move |$cx:ident| $code:block) => { - match $ins.context_instance() { - $crate::context::ContextInstance::Client($cx) => async move $code.await, - $crate::context::ContextInstance::PythonActor($cx) => async move $code.await, - } - }; -} - -/// Similar to `instance_dispatch!`, but moves the PyInstance into an Instance -/// instead of a borrow. -#[macro_export] -macro_rules! 
instance_into_dispatch { - ($ins:expr, |$cx:ident| $code:block) => { - match $ins.into_context_instance() { - $crate::context::ContextInstance::Client($cx) => $code, - $crate::context::ContextInstance::PythonActor($cx) => $code, - } - }; - ($ins:expr, async |$cx:ident| $code:block) => { - match $ins.into_context_instance() { - $crate::context::ContextInstance::Client($cx) => async $code.await, - $crate::context::ContextInstance::PythonActor($cx) => async $code.await, - } - }; - ($ins:expr, async move |$cx:ident| $code:block) => { - match $ins.into_context_instance() { - $crate::context::ContextInstance::Client($cx) => async move $code.await, - $crate::context::ContextInstance::PythonActor($cx) => async move $code.await, - } - }; -} - -#[derive(Clone)] #[pyclass(name = "Instance", module = "monarch._src.actor.actor_mesh")] pub struct PyInstance { - inner: ContextInstance, + inner: Instance, #[pyo3(get, set)] proc_mesh: Option, #[pyo3(get, set, name = "_controller_controller")] @@ -121,6 +40,29 @@ pub struct PyInstance { creator: Option, } +impl Clone for PyInstance { + fn clone(&self) -> Self { + PyInstance { + inner: self.inner.clone_for_py(), + proc_mesh: self.proc_mesh.clone(), + controller_controller: self.controller_controller.clone(), + rank: self.rank.clone(), + children: self.children.clone(), + name: self.name.clone(), + class_name: self.class_name.clone(), + creator: self.creator.clone(), + } + } +} + +impl std::ops::Deref for PyInstance { + type Target = Instance; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + #[pymethods] impl PyInstance { #[getter] @@ -137,43 +79,15 @@ impl PyInstance { } impl PyInstance { - pub fn context_instance(&self) -> &ContextInstance { - &self.inner - } - - pub fn into_context_instance(self) -> ContextInstance { + pub fn into_instance(self) -> Instance { self.inner } } -impl From<&hyperactor::Instance> for ContextInstance { - fn from(ins: &hyperactor::Instance) -> Self { - 
ContextInstance::PythonActor(ins.clone_for_py()) - } -} - -impl From<&hyperactor::Instance<()>> for ContextInstance { - fn from(ins: &hyperactor::Instance<()>) -> Self { - ContextInstance::Client(ins.clone_for_py()) - } -} - -impl From<&hyperactor::Context<'_, PythonActor>> for ContextInstance { - fn from(cx: &hyperactor::Context<'_, PythonActor>) -> Self { - ContextInstance::PythonActor(cx.clone_for_py()) - } -} - -impl From<&hyperactor::Context<'_, ()>> for ContextInstance { - fn from(cx: &hyperactor::Context<'_, ()>) -> Self { - ContextInstance::Client(cx.clone_for_py()) - } -} - -impl> From for PyInstance { +impl> From for PyInstance { fn from(ins: I) -> Self { PyInstance { - inner: ins.into(), + inner: ins.instance().clone_for_py(), proc_mesh: None, controller_controller: None, rank: PyPoint::new(0, Extent::unity().into()), @@ -206,7 +120,7 @@ impl PyContext { #[staticmethod] fn _root_client_context(py: Python<'_>) -> PyResult { let _guard = runtime::get_tokio_runtime().enter(); - let instance: PyInstance = global_root_client().into(); + let instance: PyInstance = root_client_actor().into(); Ok(PyContext { instance: instance.into_pyobject(py)?.into(), rank: Extent::unity().point_of_rank(0).unwrap(), diff --git a/monarch_hyperactor/src/mailbox.rs b/monarch_hyperactor/src/mailbox.rs index 8ce578482..3bf975fe6 100644 --- a/monarch_hyperactor/src/mailbox.rs +++ b/monarch_hyperactor/src/mailbox.rs @@ -9,6 +9,7 @@ use std::hash::DefaultHasher; use std::hash::Hash; use std::hash::Hasher; +use std::ops::Deref; use std::sync::Arc; use hyperactor::Mailbox; @@ -48,7 +49,6 @@ use serde::Serialize; use crate::actor::PythonMessage; use crate::actor::PythonMessageKind; use crate::context::PyInstance; -use crate::instance_dispatch; use crate::proc::PyActorId; use crate::pytokio::PyPythonTask; use crate::pytokio::PythonTask; @@ -289,11 +289,9 @@ impl PythonPortRef { } fn send(&self, instance: &PyInstance, message: PythonMessage) -> PyResult<()> { - instance_dispatch!(instance, 
|cx_instance| { - self.inner - .send(cx_instance, message) - .map_err(|err| PyErr::new::(format!("Port closed: {}", err)))?; - }); + self.inner + .send(instance.deref(), message) + .map_err(|err| PyErr::new::(format!("Port closed: {}", err)))?; Ok(()) } @@ -467,11 +465,9 @@ impl PythonOncePortRef { return Err(PyErr::new::("OncePortRef is already used")); }; - instance_dispatch!(instance, |cx_instance| { - port_ref - .send(cx_instance, message) - .map_err(|err| PyErr::new::(format!("Port closed: {}", err)))?; - }); + port_ref + .send(instance.deref(), message) + .map_err(|err| PyErr::new::(format!("Port closed: {}", err)))?; Ok(()) } diff --git a/monarch_hyperactor/src/proc.rs b/monarch_hyperactor/src/proc.rs index 542fbc50b..d9a6f382d 100644 --- a/monarch_hyperactor/src/proc.rs +++ b/monarch_hyperactor/src/proc.rs @@ -26,7 +26,6 @@ use std::time::SystemTime; use anyhow::Result; use hyperactor::ActorRef; use hyperactor::RemoteMessage; -use hyperactor::RemoteSpawn; use hyperactor::actor::Signal; use hyperactor::channel; use hyperactor::channel::ChannelAddr; @@ -149,7 +148,7 @@ impl PyProc { Ok(PythonActorHandle { inner: proc.spawn( name.as_deref().unwrap_or("anon"), - PythonActor::new(pickled_type).await?, + PythonActor::new(pickled_type)?, )?, }) }) @@ -168,7 +167,7 @@ impl PyProc { inner: signal_safe_block_on(py, async move { proc.spawn( name.as_deref().unwrap_or("anon"), - PythonActor::new(pickled_type).await?, + PythonActor::new(pickled_type)?, ) }) .map_err(|e| PyRuntimeError::new_err(e.to_string()))??, diff --git a/monarch_hyperactor/src/proc_mesh.rs b/monarch_hyperactor/src/proc_mesh.rs index 5198ee2c9..aaef3b444 100644 --- a/monarch_hyperactor/src/proc_mesh.rs +++ b/monarch_hyperactor/src/proc_mesh.rs @@ -44,7 +44,6 @@ use crate::actor_mesh::ActorMeshProtocol; use crate::actor_mesh::PythonActorMesh; use crate::actor_mesh::PythonActorMeshImpl; use crate::alloc::PyAlloc; -use crate::context::PyInstance; use crate::mailbox::PyMailbox; use 
crate::pytokio::PyPythonTask; use crate::pytokio::PyShared; @@ -400,11 +399,6 @@ impl PyProcMesh { .into()) } - #[getter] - fn client(&self) -> PyResult { - Ok(self.try_inner()?.client().into()) - } - fn __repr__(&self) -> PyResult { Ok(format!("", *self.try_inner()?)) } diff --git a/monarch_hyperactor/src/supervision.rs b/monarch_hyperactor/src/supervision.rs index 5703ea79d..dbfdffe8a 100644 --- a/monarch_hyperactor/src/supervision.rs +++ b/monarch_hyperactor/src/supervision.rs @@ -24,8 +24,8 @@ create_exception!( #[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Bind, Unbind)] pub struct SupervisionFailureMessage { - pub actor_mesh_name: String, - pub rank: usize, + pub actor_mesh_name: Option, + pub rank: Option, pub event: ActorSupervisionEvent, } @@ -36,15 +36,19 @@ pub struct SupervisionFailureMessage { module = "monarch._rust_bindings.monarch_hyperactor.supervision" )] pub struct MeshFailure { - pub mesh_name: String, - pub rank: usize, + pub mesh_name: Option, + pub rank: Option, pub event: ActorSupervisionEvent, } impl MeshFailure { - pub fn new(mesh_name: &impl ToString, rank: usize, event: ActorSupervisionEvent) -> Self { + pub fn new( + mesh_name: Option<&impl ToString>, + rank: Option, + event: ActorSupervisionEvent, + ) -> Self { Self { - mesh_name: mesh_name.to_string(), + mesh_name: mesh_name.map(|name| name.to_string()), rank, event, } @@ -66,7 +70,9 @@ impl std::fmt::Display for MeshFailure { write!( f, "MeshFailure(mesh_name={}, rank={}, event={})", - self.mesh_name, self.rank, self.event + self.mesh_name.clone().unwrap_or("".into()), + self.rank.map_or("".into(), |r| r.to_string()), + self.event ) } } diff --git a/monarch_hyperactor/src/v1/actor_mesh.rs b/monarch_hyperactor/src/v1/actor_mesh.rs index 0efeddbe5..8b04b4a9e 100644 --- a/monarch_hyperactor/src/v1/actor_mesh.rs +++ b/monarch_hyperactor/src/v1/actor_mesh.rs @@ -8,6 +8,7 @@ use std::clone::Clone; use std::collections::HashMap; +use std::ops::Deref; use std::sync::Arc; 
use std::sync::Mutex; @@ -42,7 +43,6 @@ use pyo3::IntoPyObjectExt; use pyo3::exceptions::PyException; use pyo3::exceptions::PyNotImplementedError; use pyo3::exceptions::PyRuntimeError; -use pyo3::exceptions::PySystemExit; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyBytes; @@ -54,15 +54,12 @@ use crate::actor::PythonMessage; use crate::actor_mesh::ActorMeshProtocol; use crate::actor_mesh::PyActorSupervisionEvent; use crate::actor_mesh::PythonActorMesh; -use crate::context::ContextInstance; use crate::context::PyInstance; -use crate::instance_dispatch; use crate::proc::PyActorId; use crate::pytokio::PyPythonTask; use crate::pytokio::PyShared; use crate::runtime::get_tokio_runtime; use crate::shape::PyRegion; -use crate::supervision::MeshFailure; use crate::supervision::SupervisionError; use crate::supervision::SupervisionFailureMessage; use crate::supervision::Unhealthy; @@ -178,15 +175,11 @@ impl PythonActorMeshImpl { } } - fn make_monitor( + fn make_monitor( &self, instance: PyInstance, - unhandled: F, supervision_display_name: String, - ) -> SupervisionMonitor - where - F: Fn(MeshFailure) + Send + 'static, - { + ) -> SupervisionMonitor { match self { // Owned meshes send a local message to themselves for the failures. PythonActorMeshImpl::Owned(inner) => Self::create_monitor( @@ -194,7 +187,6 @@ impl PythonActorMeshImpl { (*inner.mesh).clone(), inner.health_state.clone(), true, - unhandled, supervision_display_name, ), // Ref meshes send no message, they are only used to generate @@ -204,7 +196,6 @@ impl PythonActorMeshImpl { inner.mesh.clone(), inner.health_state.clone(), false, - unhandled, supervision_display_name, ), } @@ -213,106 +204,28 @@ impl PythonActorMeshImpl { /// Get a supervision receiver for this mesh. The passed in monitor object /// must outlive the returned receiver, or else the sender may be dropped /// and the receiver will get a closed channel. 
- fn supervision_receiver( + fn supervision_receiver( &self, instance: &PyInstance, monitor: &Arc>>, - unhandled: F, supervision_display_name: Option, - ) -> watch::Receiver> - where - F: Fn(MeshFailure) + Send + 'static, - { + ) -> watch::Receiver> { let mut guard = monitor.lock().unwrap(); guard.get_or_insert_with(move || { let instance = Python::with_gil(|_py| instance.clone()); - self.make_monitor( - instance, - unhandled, - supervision_display_name.unwrap_or_default(), - ) + self.make_monitor(instance, supervision_display_name.unwrap_or_default()) }); let monitor = guard.as_ref().unwrap(); monitor.receiver.clone() } - fn unhandled_fault_hook<'py>(py: Python<'py>) -> PyResult> { - py.import("monarch.actor")?.getattr("unhandled_fault_hook") - } - - fn get_unhandled(&self, instance: &PyInstance) -> Box { - let is_client = matches!(instance.context_instance(), ContextInstance::Client(_)); - match self { - PythonActorMeshImpl::Owned(_) => { - if is_client { - Box::new(move |failure| { - Python::with_gil(|py| { - let unhandled = Self::unhandled_fault_hook(py) - .expect("failed to fetch unhandled_fault_hook"); - let pyfailure = failure - .clone() - .into_pyobject(py) - .expect("failed to turn PyErr into PyObject"); - let result = unhandled.call1((pyfailure,)); - // Handle SystemExit and actually exit the process. - // It is normally just an exception. - if let Err(e) = result { - if e.is_instance_of::(py) { - let code = e - .into_bound_py_any(py) - .unwrap() - .getattr("code") - .unwrap() - .extract::() - .unwrap(); - tracing::error!( - "unhandled event reached unhandled_fault_hook: {}, which is exiting the process with code {}", - failure, - code - ); - std::process::exit(code); - } else { - // The callback raised some other exception, and there's - // no way to handle it. Just exit the process anyways - tracing::error!( - "unhandled event reached unhandled_fault_hook: {}, which raised an exception: {:?}. 
\ - Exiting the process with code 1", - failure, - e, - ); - std::process::exit(1); - } - } else { - tracing::warn!( - "unhandled event reached unhandled_fault_hook: {}, but that function produced no exception or crash. Ignoring the error", - failure - ); - } - }); - }) - } else { - Box::new(|_| { - // Never called if not client. - }) - } - } - PythonActorMeshImpl::Ref(_inner) => Box::new(|_| { - // Never called if not owned. - }), - } - } - - fn create_monitor( + fn create_monitor( instance: PyInstance, mesh: ActorMeshRef, health_state: Arc, is_owned: bool, - unhandled: F, supervision_display_name: String, - ) -> SupervisionMonitor - where - F: Fn(MeshFailure) + Send + 'static, - { + ) -> SupervisionMonitor { // There's a shared monitor for all whole mesh ref. Note that slices do // not share the health state. This is fine because requerying a slice // of a mesh will still return any failed state. @@ -323,47 +236,24 @@ impl PythonActorMeshImpl { // 3 seconds is chosen to not penalize short-lived successful calls, // while still able to catch issues before they look like a hang or timeout. let time_between_checks = tokio::time::Duration::from_secs(3); - match instance.context_instance() { - ContextInstance::Client(cx_instance) => { - actor_states_monitor( - cx_instance, - mesh, - // owner is always None if the owning instance is a client, - // because nothing can handle the message. - None, - is_owned, - unhandled, - health_state, - time_between_checks, - sender, - canceled, - supervision_display_name.clone(), - ) - .await; - } - ContextInstance::PythonActor(cx_instance) => { - // Only make a handle if is_owned is true, we don't want - // to send messages to the ref holder. 
- let owner = if is_owned { - Some(cx_instance.handle()) - } else { - None - }; - actor_states_monitor( - cx_instance, - mesh, - owner, - is_owned, - unhandled, - health_state, - time_between_checks, - sender, - canceled, - supervision_display_name, - ) - .await; - } + // Only make a handle if is_owned is true, we don't want + // to send messages to the ref holder. + let owner = if is_owned { + Some(instance.handle()) + } else { + None }; + actor_states_monitor( + instance.deref(), + mesh, + owner, + health_state, + time_between_checks, + sender, + canceled, + supervision_display_name, + ) + .await; }); SupervisionMonitor { cancel, receiver } } @@ -403,20 +293,14 @@ fn actor_state_to_supervision_events( (rank, events) } -fn send_state_change( +fn send_state_change( rank: usize, event: ActorSupervisionEvent, actor_mesh_name: &Name, owner: &Option>, - is_owned: bool, - is_proc_stopped: bool, - unhandled: &F, health_state: &Arc, sender: &watch::Sender>, -) where - F: Fn(MeshFailure), -{ - let failure = MeshFailure::new(actor_mesh_name, rank, event.clone()); +) { // Any supervision event that is not a failure should not generate // call "unhandled". // This includes the Stopped status, which is a state that occurs when the @@ -440,37 +324,37 @@ fn send_state_change( ); } - // Send a notification to the owning actor of this mesh, if there is one. + // Send a notification to the owning actor of this mesh, if there is one, but only if + // the supervision event is a failure. if let Some(owner) = owner { - if let Err(error) = owner.send(SupervisionFailureMessage { - actor_mesh_name: actor_mesh_name.to_string(), - rank, - event: event.clone(), - }) { - tracing::warn!( - name = "ActorMeshStatus", - status = "SupervisionError", - %event, - %error, - "failed to send supervision event to owner {}: {}. 
dropping event", - owner.actor_id(), - error - ); + if is_failed { + if let Err(error) = owner.send(SupervisionFailureMessage { + actor_mesh_name: Some(actor_mesh_name.to_string()), + rank: Some(rank), + event: event.clone(), + }) { + tracing::warn!( + name = "ActorMeshStatus", + status = "SupervisionError", + %event, + %error, + "failed to send supervision event to owner {}: {}. dropping event", + owner.actor_id(), + error + ); + } } - } else if is_owned && is_failed { - // The mesh has an owner, but it is not a PythonActor, so it must be the client. - // Call the unhandled function to let the client control what to do. - unhandled(failure); } + let mut inner_unhealthy_event = health_state .unhealthy_event .lock() .expect("unhealthy_event lock poisoned"); health_state.crashed_ranks.insert(rank, event.clone()); - *inner_unhealthy_event = if is_proc_stopped { - Unhealthy::StreamClosed - } else { + *inner_unhealthy_event = if is_failed { Unhealthy::Crashed(event.clone()) + } else { + Unhealthy::StreamClosed }; let event_actor_id = event.actor_id.clone(); let py_event = PyActorSupervisionEvent::from(event.clone()); @@ -493,21 +377,16 @@ fn send_state_change( /// created rank is the original rank of the actor on the mesh, not the rank after /// slicing. /// -/// * is_owned is true if this monitor is running on the owning instance. When true, -/// a message will be sent to "owner" if it is not None. If owner is None, -/// then a panic will be raised instead to crash the client. /// * time_between_tasks controls how frequently to poll.
#[hyperactor::instrument_infallible(fields( host_mesh=actor_mesh.proc_mesh().host_mesh_name().map(|n| n.to_string()), proc_mesh=actor_mesh.proc_mesh().name().to_string(), actor_name=actor_mesh.name().to_string(), ))] -async fn actor_states_monitor( +async fn actor_states_monitor( cx: &impl context::Actor, actor_mesh: ActorMeshRef, owner: Option>, - is_owned: bool, - unhandled: F, health_state: Arc, time_between_checks: tokio::time::Duration, sender: watch::Sender>, @@ -516,7 +395,6 @@ async fn actor_states_monitor( ) where A: Actor + RemoteSpawn + Referable, A::Params: RemoteMessage, - F: Fn(MeshFailure), { // This implementation polls every "time_between_checks" duration, checking // for changes in the actor states. It can be improved in two ways: @@ -546,9 +424,6 @@ async fn actor_states_monitor( ), actor_mesh.name(), &owner, - is_owned, - false, - &unhandled, &health_state, &sender, ); @@ -602,9 +477,6 @@ async fn actor_states_monitor( ), actor_mesh.name(), &owner, - is_owned, - true, - &unhandled, &health_state, &sender, ); @@ -628,9 +500,6 @@ async fn actor_states_monitor( ), actor_mesh.name(), &owner, - is_owned, - false, - &unhandled, &health_state, &sender, ); @@ -656,9 +525,6 @@ async fn actor_states_monitor( events[0].clone(), actor_mesh.name(), &owner, - is_owned, - false, - &unhandled, &health_state, &sender, ); @@ -680,9 +546,6 @@ async fn actor_states_monitor( events[0].clone(), actor_mesh.name(), &owner, - is_owned, - false, - &unhandled, &health_state, &sender, ); @@ -744,9 +607,8 @@ impl ActorMeshProtocol for PythonActorMeshImpl { fn supervision_event(&self, instance: &PyInstance) -> PyResult> { // Make a clone so each endpoint can get the same supervision events. 
- let unhandled = self.get_unhandled(instance); let monitor = self.monitor().clone(); - let mut receiver = self.supervision_receiver(instance, &monitor, unhandled, None); + let mut receiver = self.supervision_receiver(instance, &monitor, None); PyPythonTask::new(async move { receiver.changed().await.map_err(|e| { PyValueError::new_err(format!("Waiting for supervision event change: {}", e)) @@ -776,14 +638,8 @@ impl ActorMeshProtocol for PythonActorMeshImpl { supervision_display_name: String, ) -> PyResult<()> { // Fetch the receiver once, this will initialize the monitor task. - let unhandled = self.get_unhandled(instance); let monitor = self.monitor().clone(); - self.supervision_receiver( - instance, - &monitor, - unhandled, - Some(supervision_display_name), - ); + self.supervision_receiver(instance, &monitor, Some(supervision_display_name)); Ok(()) } @@ -799,13 +655,10 @@ impl ActorMeshProtocol for PythonActorMeshImpl { let (slf, instance) = Python::with_gil(|_py| (self.clone(), instance.clone())); match slf { PythonActorMeshImpl::Owned(mesh) => PyPythonTask::new(async move { - instance_dispatch!(instance, |cx_instance| { - mesh.mesh - .stop(cx_instance) - .await - .map_err(|err| PyValueError::new_err(err.to_string()))? 
- }); - Ok(()) + mesh.mesh + .stop(instance.deref()) + .await + .map_err(|err| PyValueError::new_err(err.to_string())) }), PythonActorMeshImpl::Ref(_) => Err(PyErr::new::( "Cannot call stop on an ActorMeshRef, requires an owned ActorMesh", @@ -826,10 +679,8 @@ impl ActorMeshProtocol for ActorMeshRef { instance: &PyInstance, ) -> PyResult<()> { if structurally_equal(&selection, &Selection::All(Box::new(Selection::True))) { - instance_dispatch!(instance, |cx_instance| { - self.cast(cx_instance, message.clone()) - .map_err(|err| PyException::new_err(err.to_string()))?; - }); + self.cast(instance.deref(), message.clone()) + .map_err(|err| PyException::new_err(err.to_string()))?; } else if structurally_equal(&selection, &Selection::Any(Box::new(Selection::True))) { let region = Ranked::region(self); let random_rank = fastrand::usize(0..region.num_ranks()); @@ -841,11 +692,9 @@ impl ActorMeshProtocol for ActorMeshRef { Vec::new(), Slice::new(offset, Vec::new(), Vec::new()).map_err(anyhow::Error::from)?, ); - instance_dispatch!(instance, |cx_instance| { - self.sliced(singleton_region) - .cast(cx_instance, message.clone()) - .map_err(|err| PyException::new_err(err.to_string()))?; - }); + self.sliced(singleton_region) + .cast(instance.deref(), message.clone()) + .map_err(|err| PyException::new_err(err.to_string()))?; } else { return Err(PyRuntimeError::new_err(format!( "invalid selection: {:?}", diff --git a/monarch_hyperactor/src/v1/host_mesh.rs b/monarch_hyperactor/src/v1/host_mesh.rs index c73c1ffcb..ee660c82b 100644 --- a/monarch_hyperactor/src/v1/host_mesh.rs +++ b/monarch_hyperactor/src/v1/host_mesh.rs @@ -7,6 +7,7 @@ */ use std::collections::HashMap; +use std::ops::Deref; use std::path::PathBuf; use hyperactor_mesh::bootstrap::BootstrapCommand; @@ -26,7 +27,6 @@ use pyo3::types::PyType; use crate::actor::to_py_error; use crate::alloc::PyAlloc; use crate::context::PyInstance; -use crate::instance_dispatch; use crate::pytokio::PyPythonTask; use crate::shape::PyExtent; 
use crate::shape::PyRegion; @@ -141,10 +141,9 @@ impl PyHostMesh { }; let instance = instance.clone(); PyPythonTask::new(async move { - let mesh = instance_dispatch!(instance, async move |cx_instance| { - HostMesh::allocate(cx_instance, alloc, &name, bootstrap_params).await - }) - .map_err(|err| PyException::new_err(err.to_string()))?; + let mesh = HostMesh::allocate(instance.deref(), alloc, &name, bootstrap_params) + .await + .map_err(|err| PyException::new_err(err.to_string()))?; Ok(Self::new_owned(mesh)) }) } @@ -159,10 +158,10 @@ impl PyHostMesh { let instance = instance.clone(); let per_host = per_host.clone().into(); let mesh_impl = async move { - let proc_mesh = instance_dispatch!(instance, async move |cx_instance| { - host_mesh.spawn(cx_instance, &name, per_host).await - }) - .map_err(to_py_error)?; + let proc_mesh = host_mesh + .spawn(instance.deref(), &name, per_host) + .await + .map_err(to_py_error)?; Ok(PyProcMesh::new_owned(proc_mesh)) }; PyPythonTask::new(mesh_impl) @@ -198,12 +197,7 @@ impl PyHostMesh { PyHostMesh::Owned(inner) => { let instance = instance.clone(); let mesh_borrow = inner.0.borrow().map_err(anyhow::Error::from)?; - let fut = async move { - instance_dispatch!(instance, |cx_instance| { - mesh_borrow.shutdown(cx_instance).await - })?; - Ok(()) - }; + let fut = async move { Ok(mesh_borrow.shutdown(instance.deref()).await?) 
}; PyPythonTask::new(fut) } PyHostMesh::Ref(_) => Err(PyRuntimeError::new_err( diff --git a/monarch_hyperactor/src/v1/logging.rs b/monarch_hyperactor/src/v1/logging.rs index b47906b49..d9986300e 100644 --- a/monarch_hyperactor/src/v1/logging.rs +++ b/monarch_hyperactor/src/v1/logging.rs @@ -18,7 +18,6 @@ use hyperactor_mesh::logging::LogClientMessage; use hyperactor_mesh::logging::LogForwardActor; use hyperactor_mesh::logging::LogForwardMessage; use hyperactor_mesh::v1::ActorMesh; -use hyperactor_mesh::v1::Name; use hyperactor_mesh::v1::actor_mesh::ActorMeshRef; use ndslice::View; use pyo3::Bound; @@ -27,7 +26,6 @@ use pyo3::types::PyModule; use pyo3::types::PyString; use crate::context::PyInstance; -use crate::instance_dispatch; use crate::logging::LoggerRuntimeActor; use crate::logging::LoggerRuntimeMessage; use crate::pytokio::PyPythonTask; @@ -202,12 +200,7 @@ impl LoggingMeshClient { // 1. Spawn the client-side coordinator actor (lives in // the caller's process). let client_actor: ActorHandle = - instance_dispatch!(instance, async move |cx_instance| { - cx_instance.proc().spawn( - &Name::new("log_client").to_string(), - LogClientActor::default(), - ) - })?; + instance.spawn(LogClientActor::default())?; let client_actor_ref = client_actor.bind(); // Read config to decide if we stand up per-proc @@ -218,12 +211,10 @@ impl LoggingMeshClient { // (stdout/stderr forwarders). let forwarder_mesh = if forwarding_enabled { // Spawn a `LogFwdActor` on every proc. - let mesh = instance_dispatch!(instance, async |cx_instance| { - proc_mesh - .spawn(cx_instance, "log_forwarder", &client_actor_ref) - .await - }) - .map_err(anyhow::Error::from)?; + let mesh = proc_mesh + .spawn(instance.deref(), "log_forwarder", &client_actor_ref) + .await + .map_err(anyhow::Error::from)?; Some(mesh) } else { @@ -231,10 +222,10 @@ impl LoggingMeshClient { }; // 3. Always spawn a `LoggerRuntimeActor` on every proc. 
- let logger_mesh = instance_dispatch!(instance, async |cx_instance| { - proc_mesh.spawn(cx_instance, "logger", &()).await - }) - .map_err(anyhow::Error::from)?; + let logger_mesh = proc_mesh + .spawn(instance.deref(), "logger", &()) + .await + .map_err(anyhow::Error::from)?; Ok(Self { forwarder_mesh, @@ -293,10 +284,14 @@ impl LoggingMeshClient { // Forwarders exist (config enabled at startup). We can // toggle live. (Some(fwd_mesh), _) => { - instance_dispatch!(instance, |cx_instance| { - fwd_mesh.cast(cx_instance, LogForwardMessage::SetMode { stream_to_client }) - }) - .map_err(|e| PyErr::new::(e.to_string()))?; + fwd_mesh + .cast( + instance.deref(), + LogForwardMessage::SetMode { stream_to_client }, + ) + .map_err(|e| { + PyErr::new::(e.to_string()) + })?; } // Forwarders were never spawned (global forwarding @@ -318,11 +313,9 @@ impl LoggingMeshClient { } // Always update the per-proc Python logging level. - instance_dispatch!(instance, |cx_instance| { - self.logger_mesh - .cast(cx_instance, LoggerRuntimeMessage::SetLogging { level }) - }) - .map_err(|e| PyErr::new::(e.to_string()))?; + self.logger_mesh + .cast(instance.deref(), LoggerRuntimeMessage::SetLogging { level }) + .map_err(|e| PyErr::new::(e.to_string()))?; // Always update the client actor's aggregation window. self.client_actor @@ -356,10 +349,9 @@ impl LoggingMeshClient { return Ok(()); }; - instance_dispatch!(instance, async move |cx_instance| { - Self::flush_internal(cx_instance, client_actor, forwarder_mesh).await - }) - .map_err(|e| PyErr::new::(e.to_string())) + Self::flush_internal(instance.deref(), client_actor, forwarder_mesh) + .await + .map_err(|e| PyErr::new::(e.to_string())) }) } } @@ -479,20 +471,21 @@ mod tests { use ndslice::View; // .region(), .num_ranks() etc. use super::*; + use crate::actor::PythonActor; use crate::pytokio::AwaitPyExt; use crate::pytokio::ensure_python; /// Bring up a minimal "world" suitable for integration-style /// tests. 
- pub async fn test_world() -> Result<(Proc, Instance<()>, HostMesh, ProcMesh)> { + pub async fn test_world() -> Result<(Proc, Instance, HostMesh, ProcMesh)> { ensure_python(); let proc = Proc::direct(ChannelTransport::Unix.any(), "root".to_string()) .await .expect("failed to start root Proc"); - let (instance, _handle) = proc - .instance("client") + let (instance, ..) = proc + .actor_instance("client") .expect("failed to create proc Instance"); let host_mesh = HostMesh::local_with_bootstrap( diff --git a/monarch_hyperactor/src/v1/proc_mesh.rs b/monarch_hyperactor/src/v1/proc_mesh.rs index 296447f19..89bf7f3e0 100644 --- a/monarch_hyperactor/src/v1/proc_mesh.rs +++ b/monarch_hyperactor/src/v1/proc_mesh.rs @@ -6,6 +6,8 @@ * LICENSE file in the root directory of this source tree. */ +use std::ops::Deref; + use hyperactor_mesh::shared_cell::SharedCell; use hyperactor_mesh::v1::proc_mesh::ProcMesh; use hyperactor_mesh::v1::proc_mesh::ProcMeshRef; @@ -26,7 +28,6 @@ use crate::actor_mesh::ActorMeshProtocol; use crate::actor_mesh::PythonActorMesh; use crate::alloc::PyAlloc; use crate::context::PyInstance; -use crate::instance_dispatch; use crate::pytokio::PyPythonTask; use crate::pytokio::PyShared; use crate::runtime::get_tokio_runtime; @@ -83,10 +84,9 @@ impl PyProcMesh { }; let instance = instance.clone(); PyPythonTask::new(async move { - let mesh = instance_dispatch!(instance, async move |cx_instance| { - ProcMesh::allocate(cx_instance, alloc, &name).await - }) - .map_err(|err| PyException::new_err(err.to_string()))?; + let mesh = ProcMesh::allocate(instance.deref(), alloc, &name) + .await + .map_err(|err| PyException::new_err(err.to_string()))?; Ok(Self::new_owned(mesh)) }) } @@ -101,10 +101,10 @@ impl PyProcMesh { let proc_mesh = self.mesh_ref()?.clone(); let instance = instance.clone(); let mesh_impl = async move { - let actor_mesh = instance_dispatch!(instance, async move |cx_instance| { - proc_mesh.spawn(cx_instance, &name, &pickled_type).await - }) - 
.map_err(to_py_error)?; + let actor_mesh = proc_mesh + .spawn(instance.deref(), &name, &pickled_type) + .await + .map_err(to_py_error)?; Ok(PythonActorMesh::from_impl(Box::new( PythonActorMeshImpl::new_owned(actor_mesh), ))) @@ -131,10 +131,10 @@ impl PyProcMesh { Ok((slf.mesh_ref()?.clone(), pickled_type)) })?; - let actor_mesh = instance_dispatch!(instance, async move |cx_instance| { - proc_mesh.spawn(cx_instance, &name, &pickled_type).await - }) - .map_err(anyhow::Error::from)?; + let actor_mesh = proc_mesh + .spawn(instance.deref(), &name, &pickled_type) + .await + .map_err(anyhow::Error::from)?; Ok::<_, PyErr>(Box::new(PythonActorMeshImpl::new_owned(actor_mesh))) }; if emulated { @@ -200,13 +200,10 @@ impl PyProcMesh { PyPythonTask::new(async move { let mesh = owned_inner.0.take().await; match mesh { - Ok(mut mesh) => { - instance_dispatch!(instance, async move |cx_instance| { - mesh.stop(cx_instance).await.map_err(|e| { - PyValueError::new_err(format!("error stopping mesh: {}", e)) - }) - }) - } + Ok(mut mesh) => mesh + .stop(instance.deref()) + .await + .map_err(|e| PyValueError::new_err(format!("error stopping mesh: {}", e))), Err(e) => { // Don't return an exception, silently ignore the stop request // because it was already done. 
diff --git a/monarch_rdma/extension/lib.rs b/monarch_rdma/extension/lib.rs index 5fa6d510b..a350b0b62 100644 --- a/monarch_rdma/extension/lib.rs +++ b/monarch_rdma/extension/lib.rs @@ -7,6 +7,8 @@ */ #![allow(unsafe_op_in_unsafe_fn)] +use std::ops::Deref; + use hyperactor::ActorId; use hyperactor::ActorRef; use hyperactor::Named; @@ -14,7 +16,6 @@ use hyperactor::ProcId; use hyperactor_mesh::RootActorMesh; use hyperactor_mesh::shared_cell::SharedCell; use monarch_hyperactor::context::PyInstance; -use monarch_hyperactor::instance_dispatch; use monarch_hyperactor::proc_mesh::PyProcMesh; use monarch_hyperactor::pytokio::PyPythonTask; use monarch_hyperactor::runtime::signal_safe_block_on; @@ -63,11 +64,9 @@ async fn create_rdma_buffer( let owner_ref: ActorRef = ActorRef::attest(owner_id); // Create the RdmaBuffer - let buffer = instance_dispatch!(client, |cx_instance| { - owner_ref - .request_buffer_deprecated(&cx_instance, addr, size) - .await? - }); + let buffer = owner_ref + .request_buffer_deprecated(client.deref(), addr, size) + .await?; Ok(PyRdmaBuffer { buffer, owner_ref }) } @@ -150,24 +149,16 @@ impl PyRdmaBuffer { ) -> PyResult { let (local_owner_ref, buffer) = setup_rdma_context(self, local_proc_id); PyPythonTask::new(async move { - let local_buffer = instance_dispatch!(client, |cx_instance| { - local_owner_ref - .request_buffer_deprecated(cx_instance, addr, size) - .await? - }); - instance_dispatch!(client, |cx_instance| { - local_buffer - .write_from(cx_instance, buffer, timeout) - .await - .map_err(|e| { - PyException::new_err(format!("failed to read into buffer: {}", e)) - })? - }); - instance_dispatch!(client, |cx_instance| { - local_owner_ref - .release_buffer_deprecated(cx_instance, local_buffer) - .await? 
- }); + let local_buffer = local_owner_ref + .request_buffer_deprecated(client.deref(), addr, size) + .await?; + local_buffer + .write_from(client.deref(), buffer, timeout) + .await + .map_err(|e| PyException::new_err(format!("failed to read into buffer: {}", e)))?; + local_owner_ref + .release_buffer_deprecated(client.deref(), local_buffer) + .await?; Ok(()) }) } @@ -196,24 +187,16 @@ impl PyRdmaBuffer { ) -> PyResult { let (local_owner_ref, buffer) = setup_rdma_context(self, local_proc_id); PyPythonTask::new(async move { - let local_buffer = instance_dispatch!(client, |cx_instance| { - local_owner_ref - .request_buffer_deprecated(cx_instance, addr, size) - .await? - }); - instance_dispatch!(&client, |cx_instance| { - local_buffer - .read_into(cx_instance, buffer, timeout) - .await - .map_err(|e| { - PyException::new_err(format!("failed to write from buffer: {}", e)) - })? - }); - instance_dispatch!(client, |cx_instance| { - local_owner_ref - .release_buffer_deprecated(cx_instance, local_buffer) - .await? - }); + let local_buffer = local_owner_ref + .request_buffer_deprecated(client.deref(), addr, size) + .await?; + local_buffer + .read_into(client.deref(), buffer, timeout) + .await + .map_err(|e| PyException::new_err(format!("failed to write from buffer: {}", e)))?; + local_owner_ref + .release_buffer_deprecated(client.deref(), local_buffer) + .await?; Ok(()) }) } @@ -249,13 +232,10 @@ impl PyRdmaBuffer { ) -> PyResult { let (_local_owner_ref, buffer) = setup_rdma_context(self, local_proc_id); PyPythonTask::new(async move { - // Call the drop method on the buffer to release remote handles - instance_dispatch!(client, |cx_instance| { - buffer - .drop_buffer(cx_instance) - .await - .map_err(|e| PyException::new_err(format!("Failed to drop buffer: {}", e)))? 
- }); + buffer + .drop_buffer(client.deref()) + .await + .map_err(|e| PyException::new_err(format!("Failed to drop buffer: {}", e)))?; Ok(()) }) } @@ -298,14 +278,12 @@ impl PyRdmaManager { PyPythonTask::new(async move { // Spawns the `RdmaManagerActor` on the target proc_mesh. // This allows the `RdmaController` to run on any node while real RDMA operations occur on appropriate hardware. - let actor_mesh = instance_dispatch!(client, |cx| { - tracked_proc_mesh - // Pass None to use default config - RdmaManagerActor will use default IbverbsConfig - // TODO - make IbverbsConfig configurable - .spawn::(cx, "rdma_manager", &None) - .await - .map_err(|err| PyException::new_err(err.to_string()))? - }); + let actor_mesh = tracked_proc_mesh + // Pass None to use default config - RdmaManagerActor will use default IbverbsConfig + // TODO - make IbverbsConfig configurable + .spawn::(client.deref(), "rdma_manager", &None) + .await + .map_err(|err| PyException::new_err(err.to_string()))?; // Use placeholder device name since actual device is determined on remote node Ok(Some(PyRdmaManager { @@ -316,14 +294,12 @@ impl PyRdmaManager { } else { let proc_mesh = proc_mesh.downcast::()?.borrow().mesh_ref()?; PyPythonTask::new(async move { - let actor_mesh = instance_dispatch!(client, |cx| { - proc_mesh - // Pass None to use default config - RdmaManagerActor will use default IbverbsConfig - // TODO - make IbverbsConfig configurable - .spawn_service::(cx, "rdma_manager", &None) - .await - .map_err(|err| PyException::new_err(err.to_string()))? 
- }); + let actor_mesh = proc_mesh + // Pass None to use default config - RdmaManagerActor will use default IbverbsConfig + // TODO - make IbverbsConfig configurable + .spawn_service::(client.deref(), "rdma_manager", &None) + .await + .map_err(|err| PyException::new_err(err.to_string()))?; let actor_mesh = RootActorMesh::from(actor_mesh); let actor_mesh = SharedCell::from(actor_mesh); diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 1d7221f4c..6feada9a4 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -72,6 +72,7 @@ Region, Shape, ) +from monarch._rust_bindings.monarch_hyperactor.supervision import MeshFailure from monarch._rust_bindings.monarch_hyperactor.v1.logging import log_endpoint_exception from monarch._rust_bindings.monarch_hyperactor.value_mesh import ( ValueMesh as HyValueMesh, @@ -942,6 +943,16 @@ def _process(self, msg: PythonMessage) -> Tuple[int, R]: MESSAGES_HANDLED: Counter = METER.create_counter("py_mesages_handled") +@dataclass +class ActorInitArgs: + Class: Type["Actor"] + proc_mesh: Optional["ProcMesh"] + controller_controller: Optional["_ControllerController"] + name: str + creator: Optional[CreatorInstance] + args: Tuple[Any, ...] + + class _Actor: """ This is the message handling implementation of a Python actor. 
@@ -985,14 +996,16 @@ async def handle( match method: case MethodSpecifier.Init(): ins = ctx.actor_instance - ( - Class, - ins.proc_mesh, - ins._controller_controller, - ins.name, - ins.creator, - *args, - ) = args + (args,) = args + init_args = cast(ActorInitArgs, args) + Class = init_args.Class + ins.proc_mesh = cast("ProcMesh", init_args.proc_mesh) + ins._controller_controller = cast( + "_ControllerController", init_args.controller_controller + ) + ins.name = init_args.name + ins.creator = init_args.creator + args = init_args.args ins.rank = ctx.message_rank ins.class_name = f"{Class.__module__}.{Class.__qualname__}" try: @@ -1365,12 +1378,14 @@ async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None: send( ep, ( - mesh._class, - proc_mesh, - controller_controller, - name, - context().actor_instance._as_creator(), - *args, + ActorInitArgs( + cast(Type[Actor], mesh._class), + proc_mesh, + controller_controller, + name, + context().actor_instance._as_creator(), + args, + ), ), kwargs, ) @@ -1459,3 +1474,22 @@ def current_rank() -> Point: def current_size() -> Dict[str, int]: r = context().message_rank.extent return {k: r[k] for k in r} + + +class RootClientActor(Actor): + name: str = "client" + + def __supervise__(self, failure: MeshFailure) -> object: + from monarch.actor import unhandled_fault_hook # pyre-ignore + + unhandled_fault_hook(failure) # pyre-ignore + return True + + @staticmethod + def _pickled_init_args() -> FrozenBuffer: + args = ( + ActorInitArgs(RootClientActor, None, None, RootClientActor.name, None, ()), + ) + kwargs = {} + _, buffer = flatten((args, kwargs), _is_ref_or_mailbox) + return buffer diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 47cafffa5..ade60362a 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -111,6 +111,7 @@ async def call_value(self, c: Counter) -> int: @pytest.mark.timeout(60) async def test_choose(): + print(f"THIS 
PID {os.getpid()}") proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2}) v = proc.spawn("counter", Counter, 3) i = proc.spawn("indirect", Indirect)