From 7e434241764e8d45d7ff4daaab9b24624991fc26 Mon Sep 17 00:00:00 2001 From: Marius Eriksen Date: Tue, 7 Oct 2025 11:44:36 -0700 Subject: [PATCH] [hyperactor] mesh: v0 shims for v1 meshes In this change, we implement: ``` From for ProcMesh ``` and ``` From> for RootActorMesh ``` This will allow us to temporarily use v1 instances through a v0 API, while we transition the usage sites. Differential Revision: [D84081478](https://our.internmc.facebook.com/intern/diff/D84081478/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D84081478/)! [ghstack-poisoned] --- hyperactor_mesh/src/actor_mesh.rs | 325 ++++++++++++--- hyperactor_mesh/src/alloc/sim.rs | 4 +- hyperactor_mesh/src/comm.rs | 14 +- hyperactor_mesh/src/proc_mesh.rs | 430 ++++++++++++++------ hyperactor_mesh/src/reference.rs | 3 + hyperactor_mesh/src/v1/actor_mesh.rs | 17 +- hyperactor_mesh/src/v1/proc_mesh.rs | 16 +- monarch_hyperactor/src/code_sync/manager.rs | 7 +- monarch_hyperactor/src/code_sync/rsync.rs | 9 +- monarch_hyperactor/src/proc_mesh.rs | 20 +- ndslice/src/shape.rs | 8 + 11 files changed, 645 insertions(+), 208 deletions(-) diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs index 22e2253a5..f09f5c8bb 100644 --- a/hyperactor_mesh/src/actor_mesh.rs +++ b/hyperactor_mesh/src/actor_mesh.rs @@ -10,6 +10,7 @@ use std::collections::BTreeSet; use std::ops::Deref; +use std::sync::OnceLock; use async_trait::async_trait; use hyperactor::Actor; @@ -39,6 +40,7 @@ use ndslice::Selection; use ndslice::Shape; use ndslice::ShapeError; use ndslice::SliceError; +use ndslice::View; use ndslice::reshape::Limit; use ndslice::reshape::ReshapeError; use ndslice::reshape::ReshapeSliceExt; @@ -47,6 +49,7 @@ use ndslice::selection; use ndslice::selection::EvalOpts; use ndslice::selection::ReifySlice; use ndslice::selection::normal; +use ndslice::view::ViewExt; use serde::Deserialize; use serde::Serialize; use serde_multipart::Part; @@ -62,7 +65,8 @@ use crate::metrics; use crate::proc_mesh::ProcMesh; use crate::reference::ActorMeshId; use crate::reference::ActorMeshRef; -use crate::reference::ProcMeshId; +use crate::sel; +use crate::v1; declare_attrs! { /// Which mesh this message was cast to. Used for undeliverable message @@ -199,9 +203,21 @@ pub trait ActorMesh: Mesh { message: M, ) -> Result<(), CastError> where - Self::Actor: RemoteHandles>, - M: Castable + RemoteMessage, + Self::Actor: RemoteHandles + RemoteHandles>, + M: Castable + RemoteMessage + Clone, { + if let Some(v1) = self.v1() { + if !selection::structurally_equal(&selection, &sel!(*)) { + return Err(CastError::SelectionNotSupported(format!( + "ActorMesh::cast: selection {} not supported; for v1 meshes supports only universal selection", + selection + ))); + } + return v1 + .cast(cx, message) + .map_err(anyhow::Error::from) + .map_err(CastError::from); + } actor_mesh_cast::( cx, // actor context self.id(), // actor mesh id (destination mesh) @@ -224,10 +240,20 @@ pub trait ActorMesh: Mesh { } /// Iterate over all `ActorRef` in this mesh. - fn iter_actor_refs(&self) -> impl Iterator> { + fn iter_actor_refs(&self) -> Box>> { + if let Some(v1) = self.v1() { + // We collect() here to ensure that the data are owned. Since this is a short-lived + // shim, we'll live with it. + return Box::new( + v1.iter() + .map(|(_point, actor_ref)| actor_ref.clone()) + .collect::>() + .into_iter(), + ); + } let gang: GangRef = GangId(self.proc_mesh().world_id().clone(), self.name().to_string()).into(); - self.shape().slice().iter().map(move |rank| gang.rank(rank)) + Box::new(self.shape().slice().iter().map(move |rank| gang.rank(rank))) } async fn stop(&self) -> Result<(), anyhow::Error> { @@ -237,14 +263,14 @@ pub trait ActorMesh: Mesh { /// Get a serializeable reference to this mesh similar to ActorHandle::bind fn bind(&self) -> ActorMeshRef { ActorMeshRef::attest( - ActorMeshId::V0( - ProcMeshId(self.world_id().to_string()), - self.name().to_string(), - ), + self.id(), self.shape().clone(), self.proc_mesh().comm_actor().clone(), ) } + + /// Retrieves the v1 mesh for this v0 ActorMesh, if it is available. + fn v1(&self) -> Option>; } /// Abstracts over shared and borrowed references to a [`ProcMesh`]. @@ -276,12 +302,40 @@ impl Deref for ProcMeshRef<'_> { /// `ActorRef` handles (see `ranks`), and `ActorRef` is only /// defined for `A: RemoteActor`. pub struct RootActorMesh<'a, A: RemoteActor> { - proc_mesh: ProcMeshRef<'a>, - name: String, - pub(crate) ranks: Vec>, // temporary until we remove `ArcActorMesh`. - // The receiver of supervision events. It is None if it has been transferred to - // an actor event observer. - actor_supervision_rx: Option>, + inner: ActorMeshKind<'a, A>, + shape: OnceLock, + proc_mesh: OnceLock, + name: OnceLock, +} + +enum ActorMeshKind<'a, A: RemoteActor> { + V0 { + proc_mesh: ProcMeshRef<'a>, + name: String, + ranks: Vec>, // temporary until we remove `ArcActorMesh`. + // The receiver of supervision events. It is None if it has been transferred to + // an actor event observer. + actor_supervision_rx: Option>, + }, + + V1(v1::ActorMeshRef), +} + +impl<'a, A: RemoteActor> From> for RootActorMesh<'a, A> { + fn from(actor_mesh: v1::ActorMeshRef) -> Self { + Self { + inner: ActorMeshKind::V1(actor_mesh), + shape: OnceLock::new(), + proc_mesh: OnceLock::new(), + name: OnceLock::new(), + } + } +} + +impl<'a, A: RemoteActor> From> for RootActorMesh<'a, A> { + fn from(actor_mesh: v1::ActorMesh) -> Self { + actor_mesh.detach().into() + } } impl<'a, A: RemoteActor> RootActorMesh<'a, A> { @@ -292,10 +346,24 @@ impl<'a, A: RemoteActor> RootActorMesh<'a, A> { ranks: Vec>, ) -> Self { Self { - proc_mesh: ProcMeshRef::Borrowed(proc_mesh), - name, - ranks, - actor_supervision_rx: Some(actor_supervision_rx), + inner: ActorMeshKind::V0 { + proc_mesh: ProcMeshRef::Borrowed(proc_mesh), + name, + ranks, + actor_supervision_rx: Some(actor_supervision_rx), + }, + shape: OnceLock::new(), + proc_mesh: OnceLock::new(), + name: OnceLock::new(), + } + } + + pub(crate) fn new_v1(actor_mesh: v1::ActorMeshRef) -> Self { + Self { + inner: ActorMeshKind::V1(actor_mesh), + shape: OnceLock::new(), + proc_mesh: OnceLock::new(), + name: OnceLock::new(), } } @@ -306,27 +374,50 @@ impl<'a, A: RemoteActor> RootActorMesh<'a, A> { ranks: Vec>, ) -> Self { Self { - proc_mesh: ProcMeshRef::Shared(Box::new(proc_mesh)), - name, - ranks, - actor_supervision_rx: Some(actor_supervision_rx), + inner: ActorMeshKind::V0 { + proc_mesh: ProcMeshRef::Shared(Box::new(proc_mesh)), + name, + ranks, + actor_supervision_rx: Some(actor_supervision_rx), + }, + shape: OnceLock::new(), + proc_mesh: OnceLock::new(), + name: OnceLock::new(), } } /// Open a port on this ActorMesh. pub fn open_port(&self) -> (PortHandle, PortReceiver) { - self.proc_mesh.client().open_port() + match &self.inner { + ActorMeshKind::V0 { proc_mesh, .. } => proc_mesh.client().open_port(), + ActorMeshKind::V1(_actor_mesh) => unimplemented!("unsupported operation"), + } } /// An event stream of actor events. Each RootActorMesh can produce only one such /// stream, returning None after the first call. pub fn events(&mut self) -> Option { - self.actor_supervision_rx - .take() - .map(|actor_supervision_rx| ActorSupervisionEvents { + match &mut self.inner { + ActorMeshKind::V0 { actor_supervision_rx, - mesh_id: self.id(), - }) + .. + } => actor_supervision_rx + .take() + .map(|actor_supervision_rx| ActorSupervisionEvents { + actor_supervision_rx, + mesh_id: self.id(), + }), + ActorMeshKind::V1(_actor_mesh) => unimplemented!("unsupported operation"), + } + } + + /// Access the ranks field (temporary until we remove `ArcActorMesh`). + #[cfg(test)] + pub(crate) fn ranks(&self) -> &Vec> { + match &self.inner { + ActorMeshKind::V0 { ranks, .. } => ranks, + ActorMeshKind::V1(_actor_mesh) => unimplemented!("unsupported operation"), + } } } @@ -361,7 +452,10 @@ impl<'a, A: RemoteActor> Mesh for RootActorMesh<'a, A> { 'a: 'b; fn shape(&self) -> &Shape { - self.proc_mesh.shape() + self.shape.get_or_init(|| match &self.inner { + ActorMeshKind::V0 { proc_mesh, .. } => proc_mesh.shape().clone(), + ActorMeshKind::V1(actor_mesh) => actor_mesh.region().into(), + }) } fn select>( @@ -373,11 +467,19 @@ impl<'a, A: RemoteActor> Mesh for RootActorMesh<'a, A> { } fn get(&self, rank: usize) -> Option> { - self.ranks.get(rank).cloned() + match &self.inner { + ActorMeshKind::V0 { ranks, .. } => ranks.get(rank).cloned(), + ActorMeshKind::V1(actor_mesh) => actor_mesh.get(rank), + } } fn id(&self) -> Self::Id { - ActorMeshId::V0(self.proc_mesh.id(), self.name.clone()) + match &self.inner { + ActorMeshKind::V0 { + proc_mesh, name, .. + } => ActorMeshId::V0(proc_mesh.id(), name.clone()), + ActorMeshKind::V1(actor_mesh) => ActorMeshId::V1(actor_mesh.name().clone()), + } } } @@ -385,11 +487,28 @@ impl ActorMesh for RootActorMesh<'_, A> { type Actor = A; fn proc_mesh(&self) -> &ProcMesh { - &self.proc_mesh + match &self.inner { + ActorMeshKind::V0 { proc_mesh, .. } => proc_mesh, + ActorMeshKind::V1(actor_mesh) => self + .proc_mesh + .get_or_init(|| actor_mesh.proc_mesh().clone().into()), + } } fn name(&self) -> &str { - &self.name + match &self.inner { + ActorMeshKind::V0 { name, .. } => name, + ActorMeshKind::V1(actor_mesh) => { + self.name.get_or_init(|| actor_mesh.name().to_string()) + } + } + } + + fn v1(&self) -> Option> { + match &self.inner { + ActorMeshKind::V0 { .. } => None, + ActorMeshKind::V1(actor_mesh) => Some(actor_mesh.clone()), + } } } @@ -439,11 +558,11 @@ impl ActorMesh for SlicedActorMesh<'_, A> { type Actor = A; fn proc_mesh(&self) -> &ProcMesh { - &self.0.proc_mesh + self.0.proc_mesh() } fn name(&self) -> &str { - &self.0.name + self.0.name() } #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`. @@ -462,6 +581,12 @@ impl ActorMesh for SlicedActorMesh<'_, A> { /*root_mesh_shape=*/ self.0.shape(), ) } + + fn v1(&self) -> Option> { + self.0 + .v1() + .map(|actor_mesh| actor_mesh.subset(self.shape().into()).unwrap()) + } } /// The type of error of casting operations. @@ -473,6 +598,9 @@ pub enum CastError { #[error("send on rank {0}: {1}")] MailboxSenderError(usize, MailboxSenderError), + #[error("unsupported selection: {0}")] + SelectionNotSupported(String), + #[error(transparent)] RootMailboxSenderError(#[from] MailboxSenderError), @@ -503,6 +631,7 @@ pub(crate) mod test_util { use anyhow::ensure; use hyperactor::Context; use hyperactor::Handler; + use hyperactor::Instance; use hyperactor::PortRef; use ndslice::extent; @@ -625,8 +754,8 @@ pub(crate) mod test_util { ], )] pub struct ProxyActor { - proc_mesh: Arc, - actor_mesh: RootActorMesh<'static, TestActor>, + proc_mesh: &'static Arc, + actor_mesh: Option>, } impl fmt::Debug for ProxyActor { @@ -664,13 +793,16 @@ pub(crate) mod test_util { .unwrap(); let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap()); let leaked: &'static Arc = Box::leak(Box::new(proc_mesh)); - let actor_mesh: RootActorMesh<'static, TestActor> = - leaked.spawn("echo", &()).await.unwrap(); Ok(Self { - proc_mesh: Arc::clone(leaked), - actor_mesh, + proc_mesh: leaked, + actor_mesh: None, }) } + + async fn init(&mut self, this: &Instance) -> Result<(), anyhow::Error> { + self.actor_mesh = Some(self.proc_mesh.spawn(this, "echo", &()).await?); + Ok(()) + } } #[async_trait] @@ -679,7 +811,7 @@ pub(crate) mod test_util { if std::env::var("HYPERACTOR_MESH_ROUTER_NO_GLOBAL_FALLBACK").is_err() { // test_proxy_mesh - let actor = self.actor_mesh.get(0).unwrap(); + let actor = self.actor_mesh.as_ref().unwrap().get(0).unwrap(); // For now, we reply directly to the client. // We will support directly wiring up the meshes later. @@ -692,7 +824,7 @@ pub(crate) mod test_util { } else { // test_router_undeliverable_return - let actor: ActorRef<_> = self.actor_mesh.get(0).unwrap(); + let actor: ActorRef<_> = self.actor_mesh.as_ref().unwrap().get(0).unwrap(); let (tx, mut rx) = cx.open_port::(); actor.send(cx, Echo(message.0, tx.bind()))?; @@ -766,8 +898,9 @@ mod tests { }) .await .unwrap(); + let instance = $crate::v1::testing::instance().await; let proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); - let actor_mesh: RootActorMesh<'_, ProxyActor> = proc_mesh.spawn("proxy", &()).await.unwrap(); + let actor_mesh: RootActorMesh<'_, ProxyActor> = proc_mesh.spawn(&instance, "proxy", &()).await.unwrap(); let proxy_actor = actor_mesh.get(0).unwrap(); let (tx, mut rx) = actor_mesh.open_port::(); proxy_actor.send(proc_mesh.client(), Echo("hello!".to_owned(), tx.bind())).unwrap(); @@ -791,8 +924,9 @@ mod tests { .await .unwrap(); + let instance = $crate::v1::testing::instance().await; let proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); - let actor_mesh: RootActorMesh = proc_mesh.spawn("echo", &()).await.unwrap(); + let actor_mesh: RootActorMesh = proc_mesh.spawn(&instance, "echo", &()).await.unwrap(); let (reply_handle, mut reply_receiver) = actor_mesh.open_port(); actor_mesh .cast(proc_mesh.client(), sel!(*), Echo("Hello".to_string(), reply_handle.bind())) @@ -817,12 +951,13 @@ mod tests { }) .await .unwrap(); + let instance = crate::v1::testing::instance().await; let mesh = ProcMesh::allocate(alloc).await.unwrap(); let (undeliverable_msg_tx, _) = mesh.client().open_port(); let ping_pong_actor_params = PingPongActorParams::new(Some(undeliverable_msg_tx.bind()), None); let actor_mesh: RootActorMesh = mesh - .spawn::("ping-pong", &ping_pong_actor_params) + .spawn::(&instance, "ping-pong", &ping_pong_actor_params) .await .unwrap(); @@ -855,10 +990,11 @@ mod tests { .await .unwrap(); + let instance = crate::v1::testing::instance().await; let proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); let (undeliverable_tx, _undeliverable_rx) = proc_mesh.client().open_port(); let params = PingPongActorParams::new(Some(undeliverable_tx.bind()), None); - let actor_mesh = proc_mesh.spawn::("pingpong", ¶ms).await.unwrap(); + let actor_mesh = proc_mesh.spawn::(&instance, "pingpong", ¶ms).await.unwrap(); let slice = actor_mesh.shape().slice(); let mut futures = Vec::new(); @@ -900,8 +1036,9 @@ mod tests { .await .unwrap(); + let instance = crate::v1::testing::instance().await; let proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); - let actor_mesh: RootActorMesh = proc_mesh.spawn("echo", &()).await.unwrap(); + let actor_mesh: RootActorMesh = proc_mesh.spawn(&instance, "echo", &()).await.unwrap(); let dont_simulate_error = true; let (reply_handle, mut reply_receiver) = actor_mesh.open_port(); actor_mesh @@ -943,8 +1080,9 @@ mod tests { .await .unwrap(); + let instance = crate::v1::testing::instance().await; let proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); - let actor_mesh: RootActorMesh = proc_mesh.spawn("echo", &()).await.unwrap(); + let actor_mesh: RootActorMesh = proc_mesh.spawn(&instance, "echo", &()).await.unwrap(); // Bounce the message through all actors and return it to the sender (us). let mut hops: VecDeque<_> = actor_mesh.iter().map(|actor| actor.port()).collect(); @@ -964,6 +1102,7 @@ mod tests { #[tokio::test] async fn test_inter_proc_mesh_comms() { let mut meshes = Vec::new(); + let instance = crate::v1::testing::instance().await; for _ in 0..2 { let alloc = $allocator .allocate(AllocSpec { @@ -977,7 +1116,7 @@ mod tests { let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap()); let proc_mesh_clone = Arc::clone(&proc_mesh); - let actor_mesh : RootActorMesh = proc_mesh_clone.spawn("echo", &()).await.unwrap(); + let actor_mesh : RootActorMesh = proc_mesh_clone.spawn(&instance, "echo", &()).await.unwrap(); meshes.push((proc_mesh, actor_mesh)); } @@ -1024,11 +1163,12 @@ mod tests { .await .unwrap(); + let instance = crate::v1::testing::instance().await; let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); let (tx, mut rx) = hyperactor::mailbox::open_port(proc_mesh.client()); let params = CastTestActorParams{ forward_port: tx.bind() }; - let actor_mesh: RootActorMesh = proc_mesh.spawn("actor", ¶ms).await.unwrap(); + let actor_mesh: RootActorMesh = proc_mesh.spawn(&instance, "actor", ¶ms).await.unwrap(); actor_mesh.cast(proc_mesh.client(), sel!(*), CastTestMessage::Forward("abc".to_string())).unwrap(); @@ -1086,11 +1226,12 @@ mod tests { .await .unwrap(); + let instance = crate::v1::testing::instance().await; let mesh = ProcMesh::allocate(alloc).await.unwrap(); let (reply_port_handle, mut reply_port_receiver) = mesh.client().open_port::(); let reply_port = reply_port_handle.bind(); - let actor_mesh: RootActorMesh = mesh.spawn("test", &()).await.unwrap(); + let actor_mesh: RootActorMesh = mesh.spawn(&instance, "test", &()).await.unwrap(); let actor_ref = actor_mesh.get(0).unwrap(); let mut headers = Attrs::new(); set_cast_info_on_headers(&mut headers, extent.point_of_rank(0).unwrap(), mesh.client().self_id().clone()); @@ -1150,6 +1291,7 @@ mod tests { }) .await .unwrap(); + let instance = crate::v1::testing::instance().await; let monkey = alloc.chaos_monkey(); let mut mesh = ProcMesh::allocate(alloc).await.unwrap(); let mut events = mesh.events().unwrap(); @@ -1159,7 +1301,7 @@ mod tests { None, ); let actor_mesh: RootActorMesh = mesh - .spawn::("ping-pong", &ping_pong_actor_params) + .spawn::(&instance, "ping-pong", &ping_pong_actor_params) .await .unwrap(); @@ -1219,13 +1361,14 @@ mod tests { }) .await .unwrap(); + let instance = crate::v1::testing::instance().await; let stop = alloc.stopper(); let mut mesh = ProcMesh::allocate(alloc).await.unwrap(); let mut events = mesh.events().unwrap(); let actor_mesh = mesh - .spawn::("reply-then-fail", &()) + .spawn::(&instance, "reply-then-fail", &()) .await .unwrap(); @@ -1287,6 +1430,7 @@ mod tests { }) .await .unwrap(); + let instance = crate::v1::testing::instance().await; let mesh = ProcMesh::allocate(alloc).await.unwrap(); let ping_pong_actor_params = PingPongActorParams::new( @@ -1294,12 +1438,12 @@ mod tests { None, ); let mesh_one: RootActorMesh = mesh - .spawn::("mesh_one", &ping_pong_actor_params) + .spawn::(&instance, "mesh_one", &ping_pong_actor_params) .await .unwrap(); let mesh_two: RootActorMesh = mesh - .spawn::("mesh_two", &ping_pong_actor_params) + .spawn::(&instance, "mesh_two", &ping_pong_actor_params) .await .unwrap(); @@ -1399,10 +1543,11 @@ mod tests { }) .await .unwrap(); + let instance = crate::v1::testing::instance().await; let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); let mut proc_events = proc_mesh.events().unwrap(); let actor_mesh: RootActorMesh = - proc_mesh.spawn("ingest", &()).await.unwrap(); + proc_mesh.spawn(&instance, "ingest", &()).await.unwrap(); let (reply_handle, mut reply_receiver) = actor_mesh.open_port(); let dest = actor_mesh.get(0).unwrap(); @@ -1487,10 +1632,11 @@ mod tests { // SAFETY: Not multithread safe. unsafe { std::env::set_var("HYPERACTOR_MESH_ROUTER_NO_GLOBAL_FALLBACK", "1") }; + let instance = crate::v1::testing::instance().await; let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); let mut proc_events = proc_mesh.events().unwrap(); let mut actor_mesh: RootActorMesh<'_, ProxyActor> = - { proc_mesh.spawn("proxy", &()).await.unwrap() }; + { proc_mesh.spawn(&instance, "proxy", &()).await.unwrap() }; let mut actor_events = actor_mesh.events().unwrap(); let proxy_actor = actor_mesh.get(0).unwrap(); @@ -1509,7 +1655,7 @@ mod tests { ); assert_eq!( actor_events.next().await.unwrap().actor_id.name(), - &actor_mesh.name + actor_mesh.name(), ); } } @@ -1645,12 +1791,13 @@ mod tests { transport: ChannelTransport::Local })) .unwrap(); + let instance = runtime.block_on(crate::v1::testing::instance()); let proc_mesh = runtime.block_on(ProcMesh::allocate(alloc)).unwrap(); let addr = ChannelAddr::any(ChannelTransport::Unix); let actor_mesh: RootActorMesh = - runtime.block_on(proc_mesh.spawn("echo", &addr)).unwrap(); + runtime.block_on(proc_mesh.spawn(&instance, "echo", &addr)).unwrap(); let mut runner = TestRunner::default(); let selection = gen_selection(4, actor_mesh.shape().slice().sizes().to_vec(), 0) @@ -1677,12 +1824,13 @@ mod tests { transport: ChannelTransport::Local })) .unwrap(); + let instance = runtime.block_on(crate::v1::testing::instance()); let proc_mesh = runtime.block_on(ProcMesh::allocate(alloc)).unwrap(); let addr = ChannelAddr::any(ChannelTransport::Unix); let actor_mesh: RootActorMesh = - runtime.block_on(proc_mesh.spawn("echo", &addr)).unwrap(); + runtime.block_on(proc_mesh.spawn(&instance, "echo", &addr)).unwrap(); let first_label = extent.labels().first().unwrap(); @@ -1750,12 +1898,13 @@ mod tests { transport: ChannelTransport::Local })) .unwrap(); + let instance = runtime.block_on(crate::v1::testing::instance()); let proc_mesh = runtime.block_on(ProcMesh::allocate(alloc)).unwrap(); let addr = ChannelAddr::any(ChannelTransport::Unix); let actor_mesh: RootActorMesh = - runtime.block_on(proc_mesh.spawn("echo", &addr)).unwrap(); + runtime.block_on(proc_mesh.spawn(&instance, "echo", &addr)).unwrap(); let mut runner = TestRunner::default(); let selection = gen_selection(4, actor_mesh.shape().slice().sizes().to_vec(), 0) @@ -1772,4 +1921,54 @@ mod tests { } } } + + mod shim { + use std::collections::HashSet; + + use hyperactor::context::Mailbox; + use ndslice::Extent; + use ndslice::extent; + + use super::*; + use crate::sel; + + #[tokio::test] + async fn test_basic() { + let instance = v1::testing::instance().await; + let host_mesh = v1::testing::host_mesh(extent!(host = 4)).await; + let proc_mesh = host_mesh + .spawn(instance, "test", Extent::unity()) + .await + .unwrap(); + let actor_mesh = proc_mesh + .spawn::(instance, "test", &()) + .await + .unwrap(); + + let actor_mesh_v0: RootActorMesh<'_, _> = actor_mesh.clone().into(); + + let (cast_info, mut cast_info_rx) = instance.mailbox().open_port(); + actor_mesh_v0 + .cast( + instance, + sel!(*), + v1::testactor::GetCastInfo { + cast_info: cast_info.bind(), + }, + ) + .unwrap(); + + let mut point_to_actor: HashSet<_> = actor_mesh.iter().collect(); + while !point_to_actor.is_empty() { + let (point, origin_actor_ref, sender_actor_id) = cast_info_rx.recv().await.unwrap(); + let key = (point, origin_actor_ref); + assert!( + point_to_actor.remove(&key), + "key {:?} not present or removed twice", + key + ); + assert_eq!(&sender_actor_id, instance.self_id()); + } + } + } } diff --git a/hyperactor_mesh/src/alloc/sim.rs b/hyperactor_mesh/src/alloc/sim.rs index 652347331..983425f6f 100644 --- a/hyperactor_mesh/src/alloc/sim.rs +++ b/hyperactor_mesh/src/alloc/sim.rs @@ -199,11 +199,13 @@ mod tests { }) .await .unwrap(); + let instance = crate::v1::testing::instance().await; let proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); let handle = hyperactor::simnet::simnet_handle().unwrap(); - let actor_mesh: RootActorMesh = proc_mesh.spawn("echo", &()).await.unwrap(); + let actor_mesh: RootActorMesh = + proc_mesh.spawn(&instance, "echo", &()).await.unwrap(); let actors = actor_mesh.iter_actor_refs().collect::>(); assert_eq!( handle.sample_latency( diff --git a/hyperactor_mesh/src/comm.rs b/hyperactor_mesh/src/comm.rs index 701131b12..45b7968c8 100644 --- a/hyperactor_mesh/src/comm.rs +++ b/hyperactor_mesh/src/comm.rs @@ -782,9 +782,9 @@ mod tests { let params = TestActorParams { forward_port: tx.bind(), }; - let actor_mesh = proc_mesh - .clone() - .spawn::(dest_actor_name, ¶ms) + let instance = crate::v1::testing::instance().await; + let actor_mesh = Arc::clone(&proc_mesh) + .spawn::(&instance, dest_actor_name, ¶ms) .await .unwrap(); @@ -887,7 +887,7 @@ mod tests { // Reply from each dest actor. The replies should be received by client. { for (dest_actor, (reply_to1, reply_to2)) in - actor_mesh.ranks.iter().zip(reply_tos.iter()) + actor_mesh.ranks().iter().zip(reply_tos.iter()) { let rank = dest_actor.actor_id().rank() as u64; reply_to1.send(proc_mesh_client, rank).unwrap(); @@ -911,7 +911,7 @@ mod tests { let n = 100; let mut expected2: HashMap> = hashmap! {}; for (dest_actor, (_reply_to1, reply_to2)) in - actor_mesh.ranks.iter().zip(reply_tos.iter()) + actor_mesh.ranks().iter().zip(reply_tos.iter()) { let rank = dest_actor.actor_id().rank(); let mut sent2 = vec![]; @@ -932,7 +932,7 @@ mod tests { let mut received2: HashMap> = hashmap! {}; - for _ in 0..(n * actor_mesh.ranks.len()) { + for _ in 0..(n * actor_mesh.ranks().len()) { let my_reply = reply2_rx.recv().await.unwrap(); received2 .entry(my_reply.sender.rank()) @@ -983,7 +983,7 @@ mod tests { let mut sum = 0; let n = 100; for (dest_actor, (reply_to1, _reply_to2)) in - actor_mesh.ranks.iter().zip(reply_tos.iter()) + actor_mesh.ranks().iter().zip(reply_tos.iter()) { let rank = dest_actor.actor_id().rank(); for i in 0..n { diff --git a/hyperactor_mesh/src/proc_mesh.rs b/hyperactor_mesh/src/proc_mesh.rs index 6774353d1..3f7f7484e 100644 --- a/hyperactor_mesh/src/proc_mesh.rs +++ b/hyperactor_mesh/src/proc_mesh.rs @@ -51,6 +51,8 @@ use hyperactor::supervision::ActorSupervisionEvent; use ndslice::Range; use ndslice::Shape; use ndslice::ShapeError; +use ndslice::View; +use ndslice::ViewExt; use strum::AsRefStr; use tokio::sync::mpsc; use tracing::Instrument; @@ -77,6 +79,7 @@ use crate::reference::ActorMeshId; use crate::reference::ProcMeshId; use crate::router; use crate::shortuuid::ShortUuid; +use crate::v1; pub mod mesh_agent; @@ -205,20 +208,30 @@ pub fn global_root_client() -> &'static Instance<()> { } type ActorEventRouter = Arc>>; + /// A ProcMesh maintains a mesh of procs whose lifecycles are managed by /// an allocator. pub struct ProcMesh { - // The underlying set of events. It is None if it has been transferred to - // a proc event observer. - event_state: Option, - actor_event_router: ActorEventRouter, - shape: Shape, - ranks: Vec<(ShortUuid, ProcId, ChannelAddr, ActorRef)>, - #[allow(dead_code)] // will be used in subsequent diff - client_proc: Proc, - client: Instance<()>, - comm_actors: Vec>, - world_id: WorldId, + inner: ProcMeshKind, + shape: OnceLock, +} + +enum ProcMeshKind { + V0 { + // The underlying set of events. It is None if it has been transferred to + // a proc event observer. + event_state: Option, + actor_event_router: ActorEventRouter, + shape: Shape, + ranks: Vec<(ShortUuid, ProcId, ChannelAddr, ActorRef)>, + #[allow(dead_code)] // will be used in subsequent diff + client_proc: Proc, + client: Instance<()>, + comm_actors: Vec>, + world_id: WorldId, + }, + + V1(v1::ProcMeshRef), } struct EventState { @@ -226,6 +239,15 @@ struct EventState { supervision_events: PortReceiver, } +impl From for ProcMesh { + fn from(proc_mesh: v1::ProcMeshRef) -> Self { + ProcMesh { + inner: ProcMeshKind::V1(proc_mesh), + shape: OnceLock::new(), + } + } +} + impl ProcMesh { #[hyperactor::instrument(fields(name = "ProcMesh::allocate"))] pub async fn allocate( @@ -441,27 +463,30 @@ impl ProcMesh { ); Ok(Self { - event_state: Some(EventState { - alloc, - supervision_events, - }), - actor_event_router: Arc::new(DashMap::new()), - shape, - ranks: running - .into_iter() - .map( - |AllocatedProc { - create_key, - proc_id, - addr, - mesh_agent, - }| (create_key, proc_id, addr, mesh_agent), - ) - .collect(), - client_proc, - client, - comm_actors, - world_id, + inner: ProcMeshKind::V0 { + event_state: Some(EventState { + alloc, + supervision_events, + }), + actor_event_router: Arc::new(DashMap::new()), + shape, + ranks: running + .into_iter() + .map( + |AllocatedProc { + create_key, + proc_id, + addr, + mesh_agent, + }| (create_key, proc_id, addr, mesh_agent), + ) + .collect(), + client_proc, + client, + comm_actors, + world_id, + }, + shape: OnceLock::new(), }) } @@ -534,13 +559,35 @@ impl ProcMesh { .collect()) } - fn agents(&self) -> impl Iterator> + '_ { - self.ranks.iter().map(|(_, _, _, agent)| agent.clone()) + fn agents(&self) -> Box> + '_ + Send> { + match &self.inner { + ProcMeshKind::V0 { ranks, .. } => { + Box::new(ranks.iter().map(|(_, _, _, agent)| agent.clone())) + } + ProcMeshKind::V1(proc_mesh) => Box::new( + proc_mesh + .agent_mesh() + .iter() + .map(|(_point, agent)| agent.clone()) + // We need to collect here so that we can return an iterator + // that fully owns the data and does not reference temporary + // values. + // + // Because this is a shim that we expect to be short-lived, + // we'll leave this hack as is; a proper solution here would + // be to have implement an owning iterator (into_iter) for views. + .collect::>() + .into_iter(), + ), + } } /// Return the comm actor to which casts should be forwarded. pub(crate) fn comm_actor(&self) -> &ActorRef { - &self.comm_actors[0] + match &self.inner { + ProcMeshKind::V0 { comm_actors, .. } => &comm_actors[0], + ProcMeshKind::V1(proc_mesh) => proc_mesh.root_comm_actor().unwrap(), + } } /// Spawn an `ActorMesh` by launching the same actor type on all @@ -560,102 +607,145 @@ impl ProcMesh { /// cross proc boundaries when launching each actor. pub async fn spawn( &self, + cx: &impl context::Actor, actor_name: &str, params: &A::Params, ) -> Result, anyhow::Error> where A::Params: RemoteMessage, { - let (tx, rx) = mpsc::unbounded_channel::(); - { - // Instantiate supervision routing BEFORE spawning the actor mesh. - self.actor_event_router.insert(actor_name.to_string(), tx); - tracing::info!( - name = "router_insert", - actor_name = %actor_name, - "the length of the router is {}", self.actor_event_router.len(), - ); + match &self.inner { + ProcMeshKind::V0 { + actor_event_router, + client, + .. + } => { + let (tx, rx) = mpsc::unbounded_channel::(); + { + // Instantiate supervision routing BEFORE spawning the actor mesh. + actor_event_router.insert(actor_name.to_string(), tx); + tracing::info!( + name = "router_insert", + actor_name = %actor_name, + "the length of the router is {}", actor_event_router.len(), + ); + } + let root_mesh = RootActorMesh::new( + self, + actor_name.to_string(), + rx, + Self::spawn_on_procs::(client, self.agents(), actor_name, params).await?, + ); + Ok(root_mesh) + } + ProcMeshKind::V1(proc_mesh) => { + let actor_mesh = proc_mesh.spawn(cx, actor_name, params).await?; + Ok(RootActorMesh::new_v1(actor_mesh.detach())) + } } - let root_mesh = RootActorMesh::new( - self, - actor_name.to_string(), - rx, - Self::spawn_on_procs::(&self.client, self.agents(), actor_name, params).await?, - ); - Ok(root_mesh) } /// A client actor used to communicate with any member of this mesh. pub fn client(&self) -> &Instance<()> { - &self.client + match &self.inner { + ProcMeshKind::V0 { client, .. } => client, + ProcMeshKind::V1(_proc_mesh) => unimplemented!("no client for v1::ProcMesh"), + } } pub fn client_proc(&self) -> &Proc { - &self.client_proc + match &self.inner { + ProcMeshKind::V0 { client_proc, .. } => client_proc, + ProcMeshKind::V1(_proc_mesh) => unimplemented!("no client proc for v1::ProcMesh"), + } } pub fn proc_id(&self) -> &ProcId { - self.client_proc.proc_id() + self.client_proc().proc_id() } pub fn world_id(&self) -> &WorldId { - &self.world_id + match &self.inner { + ProcMeshKind::V0 { world_id, .. } => world_id, + ProcMeshKind::V1(_proc_mesh) => unimplemented!("no world_id for v1::ProcMesh"), + } } /// An event stream of proc events. Each ProcMesh can produce only one such /// stream, returning None after the first call. pub fn events(&mut self) -> Option { - self.event_state.take().map(|event_state| ProcEvents { - event_state, - ranks: self - .ranks - .iter() - .enumerate() - .map(|(rank, (create_key, proc_id, _addr, _mesh_agent))| { - (proc_id.clone(), (rank, create_key.clone())) - }) - .collect(), - actor_event_router: self.actor_event_router.clone(), - }) + match &mut self.inner { + ProcMeshKind::V0 { + event_state, + ranks, + actor_event_router, + .. + } => event_state.take().map(|event_state| ProcEvents { + event_state, + ranks: ranks + .iter() + .enumerate() + .map(|(rank, (create_key, proc_id, _addr, _mesh_agent))| { + (proc_id.clone(), (rank, create_key.clone())) + }) + .collect(), + actor_event_router: actor_event_router.clone(), + }), + ProcMeshKind::V1(_proc_mesh) => todo!(), + } } pub fn shape(&self) -> &Shape { - &self.shape + // We store the shape here, only because it isn't materialized in + // V1 meshes. + self.shape.get_or_init(|| match &self.inner { + ProcMeshKind::V0 { shape, .. } => shape.clone(), + ProcMeshKind::V1(proc_mesh) => proc_mesh.region().into(), + }) } /// Send stop actors message to all mesh agents for a specific mesh name #[hyperactor::observe_result("ProcMesh")] pub async fn stop_actor_by_name(&self, mesh_name: &str) -> Result<(), anyhow::Error> { - let timeout = hyperactor::config::global::get(hyperactor::config::STOP_ACTOR_TIMEOUT); - let results = join_all(self.agents().map(|agent| async move { - let actor_id = ActorId(agent.actor_id().proc_id().clone(), mesh_name.to_string(), 0); - ( - actor_id.clone(), - agent - .clone() - .stop_actor(&self.client, actor_id, timeout.as_millis() as u64) - .await, - ) - })) - .await; - - for (actor_id, result) in results { - match result { - Ok(StopActorResult::Timeout) => { - tracing::warn!("timed out while stopping actor {}", actor_id); - } - Ok(StopActorResult::NotFound) => { - tracing::warn!("no actor {} on proc {}", actor_id, actor_id.proc_id()); - } - Ok(StopActorResult::Success) => { - tracing::info!("stopped actor {}", actor_id); - } - Err(e) => { - tracing::warn!("error stopping actor {}: {}", actor_id, e); + match &self.inner { + ProcMeshKind::V0 { client, .. } => { + let timeout = + hyperactor::config::global::get(hyperactor::config::STOP_ACTOR_TIMEOUT); + let results = join_all(self.agents().map(|agent| async move { + let actor_id = + ActorId(agent.actor_id().proc_id().clone(), mesh_name.to_string(), 0); + ( + actor_id.clone(), + agent + .clone() + .stop_actor(client, actor_id, timeout.as_millis() as u64) + .await, + ) + })) + .await; + + for (actor_id, result) in results { + match result { + Ok(StopActorResult::Timeout) => { + tracing::warn!("timed out while stopping actor {}", actor_id); + } + Ok(StopActorResult::NotFound) => { + tracing::warn!("no actor {} on proc {}", actor_id, actor_id.proc_id()); + } + Ok(StopActorResult::Success) => { + tracing::info!("stopped actor {}", actor_id); + } + Err(e) => { + tracing::warn!("error stopping actor {}: {}", actor_id, e); + } + } } + Ok(()) + } + ProcMeshKind::V1(_proc_mesh) => { + anyhow::bail!("kill actor by name unsupported for v1::ProcMesh") } } - Ok(()) } } @@ -882,6 +972,7 @@ pub trait SharedSpawnable { // `RemoteActor`: so we can hand back ActorRef in RootActorMesh async fn spawn( self, + cx: &impl context::Actor, actor_name: &str, params: &A::Params, ) -> Result, anyhow::Error> @@ -895,30 +986,41 @@ impl + Send + Sync + 'static> SharedSpawnable for D // `RemoteActor`: so we can hand back ActorRef in RootActorMesh async fn spawn( self, + cx: &impl context::Actor, actor_name: &str, params: &A::Params, ) -> Result, anyhow::Error> where A::Params: RemoteMessage, { - let (tx, rx) = mpsc::unbounded_channel::(); - { - // Instantiate supervision routing BEFORE spawning the actor mesh. - self.actor_event_router.insert(actor_name.to_string(), tx); - tracing::info!( - name = "router_insert", - actor_name = %actor_name, - "the length of the router is {}", self.actor_event_router.len(), - ); + match &self.deref().inner { + ProcMeshKind::V0 { + actor_event_router, + client, + .. + } => { + let (tx, rx) = mpsc::unbounded_channel::(); + { + // Instantiate supervision routing BEFORE spawning the actor mesh. + actor_event_router.insert(actor_name.to_string(), tx); + tracing::info!( + name = "router_insert", + actor_name = %actor_name, + "the length of the router is {}", actor_event_router.len(), + ); + } + let ranks = + ProcMesh::spawn_on_procs::(client, self.agents(), actor_name, params) + .await?; + Ok(RootActorMesh::new_shared( + self, + actor_name.to_string(), + rx, + ranks, + )) + } + ProcMeshKind::V1(proc_mesh) => todo!(), } - let ranks = - ProcMesh::spawn_on_procs::(&self.client, self.agents(), actor_name, params).await?; - Ok(RootActorMesh::new_shared( - self, - actor_name.to_string(), - rx, - ranks, - )) } } @@ -929,7 +1031,7 @@ impl Mesh for ProcMesh { type Sliced<'a> = SlicedProcMesh<'a>; fn shape(&self) -> &Shape { - &self.shape + ProcMesh::shape(self) } fn select>( @@ -941,11 +1043,17 @@ impl Mesh for ProcMesh { } fn get(&self, rank: usize) -> Option { - Some(self.ranks[rank].1.clone()) + match &self.inner { + ProcMeshKind::V0 { ranks, .. } => Some(ranks[rank].1.clone()), + ProcMeshKind::V1(proc_mesh) => proc_mesh.get(rank).map(|proc| proc.proc_id().clone()), + } } fn id(&self) -> Self::Id { - ProcMeshId(self.world_id().name().to_string()) + match &self.inner { + ProcMeshKind::V0 { world_id, .. } => ProcMeshId(world_id.name().to_string()), + ProcMeshKind::V1(proc_mesh) => ProcMeshId(proc_mesh.name().to_string()), + } } } @@ -957,13 +1065,22 @@ impl fmt::Display for ProcMesh { impl fmt::Debug for ProcMesh { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("ProcMesh") - .field("shape", &self.shape()) - .field("ranks", &self.ranks) - .field("client_proc", &self.client_proc) - .field("client", &"") - // Skip the alloc field since it doesn't implement Debug - .finish() + match &self.inner { + ProcMeshKind::V0 { + shape, + ranks, + client_proc, + .. + } => f + .debug_struct("ProcMesh::V0") + .field("shape", shape) + .field("ranks", ranks) + .field("client_proc", client_proc) + .field("client", &"") + // Skip the alloc field since it doesn't implement Debug + .finish(), + ProcMeshKind::V1(proc_mesh) => fmt::Debug::fmt(proc_mesh, f), + } } } @@ -1083,7 +1200,12 @@ mod tests { let mut mesh = ProcMesh::allocate(alloc).await.unwrap(); let mut events = mesh.events().unwrap(); - let mut actors = mesh.spawn::("failing", &()).await.unwrap(); + let instance = crate::v1::testing::instance().await; + + let mut actors = mesh + .spawn::(&instance, "failing", &()) + .await + .unwrap(); let mut actor_events = actors.events().unwrap(); actors @@ -1133,8 +1255,66 @@ mod tests { .unwrap(); let mesh = ProcMesh::allocate(alloc).await.unwrap(); - mesh.spawn::("dup", &()).await.unwrap(); - let result = mesh.spawn::("dup", &()).await; + let instance = crate::v1::testing::instance().await; + mesh.spawn::(&instance, "dup", &()) + .await + .unwrap(); + let result = mesh.spawn::(&instance, "dup", &()).await; assert!(result.is_err()); } + + mod shim { + use std::collections::HashSet; + + use hyperactor::context::Mailbox; + use ndslice::Extent; + use ndslice::Selection; + + use super::*; + use crate::sel; + + #[tokio::test] + async fn test_basic() { + let instance = v1::testing::instance().await; + let ext = extent!(host = 4); + let host_mesh = v1::testing::host_mesh(ext.clone()).await; + let proc_mesh = host_mesh + .spawn(instance, "test", Extent::unity()) + .await + .unwrap(); + let proc_mesh_v0: ProcMesh = proc_mesh.detach().into(); + + let actor_mesh = proc_mesh_v0 + .spawn::(instance, "test", &()) + .await + .unwrap(); + + let (cast_info, mut cast_info_rx) = instance.mailbox().open_port(); + actor_mesh + .cast( + instance, + sel!(*), + v1::testactor::GetCastInfo { + cast_info: cast_info.bind(), + }, + ) + .unwrap(); + + let mut point_to_actor: HashSet<_> = actor_mesh + .iter_actor_refs() + .enumerate() + .map(|(rank, actor_ref)| (ext.point_of_rank(rank).unwrap(), actor_ref)) + .collect(); + while !point_to_actor.is_empty() { + let (point, origin_actor_ref, sender_actor_id) = cast_info_rx.recv().await.unwrap(); + let key = (point, origin_actor_ref); + assert!( + point_to_actor.remove(&key), + "key {:?} not present or removed twice", + key + ); + assert_eq!(&sender_actor_id, instance.self_id()); + } + } + } } diff --git a/hyperactor_mesh/src/reference.rs b/hyperactor_mesh/src/reference.rs index 813b2a882..411fafe9f 100644 --- a/hyperactor_mesh/src/reference.rs +++ b/hyperactor_mesh/src/reference.rs @@ -336,9 +336,11 @@ mod tests { }) .await .unwrap(); + let instance = crate::v1::testing::instance().await; let ping_proc_mesh = ProcMesh::allocate(alloc_ping).await.unwrap(); let ping_mesh: RootActorMesh = ping_proc_mesh .spawn( + &instance, "ping", &MeshPingPongActorParams { mesh_id: ActorMeshId::V0( @@ -356,6 +358,7 @@ mod tests { let pong_proc_mesh = ProcMesh::allocate(alloc_pong).await.unwrap(); let pong_mesh: RootActorMesh = pong_proc_mesh .spawn( + &instance, "pong", &MeshPingPongActorParams { mesh_id: ActorMeshId::V0( diff --git a/hyperactor_mesh/src/v1/actor_mesh.rs b/hyperactor_mesh/src/v1/actor_mesh.rs index f2b6d0318..089d54b4c 100644 --- a/hyperactor_mesh/src/v1/actor_mesh.rs +++ b/hyperactor_mesh/src/v1/actor_mesh.rs @@ -70,6 +70,11 @@ impl ActorMesh { current_ref, } } + + /// Detach this mesh from the lifetime of `self`, and return its reference. + pub(crate) fn detach(self) -> ActorMeshRef { + self.current_ref + } } impl Deref for ActorMesh { @@ -135,7 +140,7 @@ pub struct ActorMeshRef { _phantom: PhantomData, } -impl ActorMeshRef { +impl ActorMeshRef { /// Cast a message to all actors in this mesh. pub fn cast(&self, cx: &impl context::Actor, message: M) -> v1::Result<()> where @@ -182,7 +187,7 @@ impl ActorMeshRef { root_comm_actor: &ActorRef, ) -> v1::Result<()> where - A: RemoteHandles + RemoteHandles>, + A: RemoteHandles>, M: Castable + RemoteMessage + Clone, // Clone is required until we are fully onto comm actor { let cast_mesh_shape = view::Ranked::region(self).into(); @@ -227,6 +232,10 @@ impl ActorMeshRef { Self::with_page_size(name, proc_mesh, DEFAULT_PAGE) } + pub(crate) fn name(&self) -> &Name { + &self.name + } + pub(crate) fn with_page_size(name: Name, proc_mesh: ProcMeshRef, page_size: usize) -> Self { Self { proc_mesh, @@ -237,6 +246,10 @@ impl ActorMeshRef { } } + pub(crate) fn proc_mesh(&self) -> &ProcMeshRef { + &self.proc_mesh + } + #[inline] fn len(&self) -> usize { view::Ranked::region(&self.proc_mesh).num_ranks() diff --git a/hyperactor_mesh/src/v1/proc_mesh.rs b/hyperactor_mesh/src/v1/proc_mesh.rs index 953ca7b5d..4743bda6d 100644 --- a/hyperactor_mesh/src/v1/proc_mesh.rs +++ b/hyperactor_mesh/src/v1/proc_mesh.rs @@ -153,6 +153,10 @@ impl ProcRef { } } + pub(crate) fn proc_id(&self) -> &ProcId { + &self.proc_id + } + pub(crate) fn actor_id(&self, name: &Name) -> ActorId { self.proc_id.actor_id(name.to_string(), 0) } @@ -165,7 +169,6 @@ impl ProcRef { } /// A mesh of processes. -#[allow(dead_code)] #[derive(Debug)] pub struct ProcMesh { name: Name, @@ -370,6 +373,11 @@ impl ProcMesh { ) .await } + + /// Detach the proc mesh from the lifetime of `self`, and return its reference. + pub(crate) fn detach(self) -> ProcMeshRef { + self.current_ref + } } impl Deref for ProcMesh { @@ -491,6 +499,10 @@ impl ProcMeshRef { self.root_comm_actor.as_ref() } + pub(crate) fn name(&self) -> &Name { + &self.name + } + /// The current statuses of procs in this mesh. pub async fn status(&self, cx: &impl context::Actor) -> v1::Result> { let vm: ValueMesh<_> = self.map_into(|proc_ref| { @@ -500,7 +512,7 @@ impl ProcMeshRef { vm.join().await.transpose() } - fn agent_mesh(&self) -> ActorMeshRef { + pub(crate) fn agent_mesh(&self) -> ActorMeshRef { let agent_name = self.ranks.first().unwrap().agent.actor_id().name(); // This name must match the ProcMeshAgent name, which can change depending on the allocator. ActorMeshRef::new(Name::new_reserved(agent_name), self.clone()) diff --git a/monarch_hyperactor/src/code_sync/manager.rs b/monarch_hyperactor/src/code_sync/manager.rs index c15467ae4..f16b6034a 100644 --- a/monarch_hyperactor/src/code_sync/manager.rs +++ b/monarch_hyperactor/src/code_sync/manager.rs @@ -484,6 +484,7 @@ mod tests { use hyperactor_mesh::alloc::Allocator; use hyperactor_mesh::alloc::local::LocalAllocator; use hyperactor_mesh::proc_mesh::ProcMesh; + use hyperactor_mesh::proc_mesh::global_root_client; use ndslice::extent; use ndslice::shape; use tempfile::TempDir; @@ -585,9 +586,13 @@ mod tests { // Create CodeSyncManagerParams let params = CodeSyncManagerParams {}; + // TODO: thread through context, or access the actual python context; + // for now this is basically equivalent (arguably better) to using the proc mesh client. + let instance = global_root_client(); + // Spawn actor mesh with CodeSyncManager actors let actor_mesh = proc_mesh - .spawn::("code_sync_test", ¶ms) + .spawn::(&instance, "code_sync_test", ¶ms) .await?; // Create workspace configuration diff --git a/monarch_hyperactor/src/code_sync/rsync.rs b/monarch_hyperactor/src/code_sync/rsync.rs index 9cb78a4eb..d41c8c668 100644 --- a/monarch_hyperactor/src/code_sync/rsync.rs +++ b/monarch_hyperactor/src/code_sync/rsync.rs @@ -459,6 +459,7 @@ mod tests { use hyperactor_mesh::alloc::Allocator; use hyperactor_mesh::alloc::local::LocalAllocator; use hyperactor_mesh::proc_mesh::ProcMesh; + use hyperactor_mesh::proc_mesh::global_root_client; use ndslice::extent; use tempfile::TempDir; use tokio::fs; @@ -512,8 +513,14 @@ mod tests { // Create RsyncParams - all actors will use the same target workspace for this test let params = RsyncParams {}; + // TODO: thread through context, or access the actual python context; + // for now this is basically equivalent (arguably better) to using the proc mesh client. + let instance = global_root_client(); + // Spawn actor mesh with RsyncActors - let actor_mesh = proc_mesh.spawn::("rsync_test", ¶ms).await?; + let actor_mesh = proc_mesh + .spawn::(&instance, "rsync_test", ¶ms) + .await?; // Test rsync_mesh function - this coordinates rsync operations across the mesh let results = rsync_mesh( diff --git a/monarch_hyperactor/src/proc_mesh.rs b/monarch_hyperactor/src/proc_mesh.rs index 479b3dae6..febe387b6 100644 --- a/monarch_hyperactor/src/proc_mesh.rs +++ b/monarch_hyperactor/src/proc_mesh.rs @@ -15,6 +15,7 @@ use hyperactor::Actor; use hyperactor::RemoteMessage; use hyperactor::WorldId; use hyperactor::actor::RemoteActor; +use hyperactor::context; use hyperactor::context::Mailbox as _; use hyperactor::proc::Instance; use hyperactor::proc::Proc; @@ -24,6 +25,7 @@ use hyperactor_mesh::proc_mesh::ProcEvent; use hyperactor_mesh::proc_mesh::ProcEvents; use hyperactor_mesh::proc_mesh::ProcMesh; use hyperactor_mesh::proc_mesh::SharedSpawnable; +use hyperactor_mesh::proc_mesh::global_root_client; use hyperactor_mesh::shared_cell::SharedCell; use hyperactor_mesh::shared_cell::SharedCellPool; use hyperactor_mesh::shared_cell::SharedCellRef; @@ -89,6 +91,7 @@ impl From for TrackedProcMesh { impl TrackedProcMesh { pub async fn spawn( &self, + cx: &impl context::Actor, actor_name: &str, params: &A::Params, ) -> Result>, anyhow::Error> @@ -96,7 +99,7 @@ impl TrackedProcMesh { A::Params: RemoteMessage, { let mesh = self.cell.borrow()?; - let actor = mesh.spawn(actor_name, params).await?; + let actor = mesh.spawn(cx, actor_name, params).await?; Ok(self.children.insert(actor)) } @@ -313,13 +316,15 @@ impl PyProcMesh { actor: &Bound<'py, PyType>, ) -> PyResult { let unhealthy_event = Arc::clone(&self.unhealthy_event); - let pickled_type = PickledPyObject::pickle(actor.as_any())?; + let pickled_type: PickledPyObject = PickledPyObject::pickle(actor.as_any())?; let proc_mesh = self.try_inner()?; let keepalive = self.keepalive.clone(); let meshimpl = async move { ensure_mesh_healthy(&unhealthy_event).await?; - let instance = proc_mesh.client(); - let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?; + // TODO: thread through context, or access the actual python context; + // for now this is basically equivalent (arguably better) to using the proc mesh client. + let instance = global_root_client(); + let actor_mesh = proc_mesh.spawn(&instance, &name, &pickled_type).await?; let actor_events = actor_mesh.with_mut(|a| a.events()).await.unwrap().unwrap(); let im = PythonActorMeshImpl::new( actor_mesh, @@ -354,8 +359,11 @@ impl PyProcMesh { Ok((proc_mesh, pickled_type, unhealthy_event, keepalive)) })?; ensure_mesh_healthy(&unhealthy_event).await?; - let instance = proc_mesh.client(); - let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?; + // TODO: thread through context, or access the actual python context; + // for now this is basically equivalent (arguably better) to using the proc mesh client. + let instance = global_root_client(); + + let actor_mesh = proc_mesh.spawn(&instance, &name, &pickled_type).await?; let actor_events = actor_mesh.with_mut(|a| a.events()).await.unwrap().unwrap(); Ok::<_, PyErr>(Box::new(PythonActorMeshImpl::new( actor_mesh, diff --git a/ndslice/src/shape.rs b/ndslice/src/shape.rs index a3381b5a8..693a93699 100644 --- a/ndslice/src/shape.rs +++ b/ndslice/src/shape.rs @@ -214,6 +214,14 @@ impl Shape { } } +impl From for Shape { + fn from(region: Region) -> Self { + let (labels, slice) = region.into_inner(); + Shape::new(labels, slice) + .expect("Shape::new should not fail because a Region by definition is a valid Shape") + } +} + impl From<&Region> for Shape { fn from(region: &Region) -> Self { Shape::new(region.labels().to_vec(), region.slice().clone())