diff --git a/hyperactor_mesh/benches/main.rs b/hyperactor_mesh/benches/main.rs index 8fd4204b0..499f0da50 100644 --- a/hyperactor_mesh/benches/main.rs +++ b/hyperactor_mesh/benches/main.rs @@ -13,6 +13,7 @@ use criterion::Criterion; use criterion::Throughput; use criterion::criterion_group; use criterion::criterion_main; +use hyperactor::Proc; use hyperactor::channel::ChannelTransport; use hyperactor_mesh::ProcMesh; use hyperactor_mesh::actor_mesh::ActorMesh; @@ -51,9 +52,10 @@ fn bench_actor_scaling(c: &mut Criterion) { .await .unwrap(); + let (bootstrap_instance, _) = Proc::local().instance("bench").unwrap(); let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); let actor_mesh: RootActorMesh = proc_mesh - .spawn("bench", &(Duration::from_millis(0))) + .spawn(&bootstrap_instance, "bench", &(Duration::from_millis(0))) .await .unwrap(); let client = proc_mesh.client(); @@ -149,9 +151,10 @@ fn bench_actor_mesh_message_sizes(c: &mut Criterion) { .await .unwrap(); + let (bootstrap_instance, _) = Proc::local().instance("bench").unwrap(); let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); let actor_mesh: RootActorMesh = proc_mesh - .spawn("bench", &(Duration::from_millis(0))) + .spawn(&bootstrap_instance, "bench", &(Duration::from_millis(0))) .await .unwrap(); diff --git a/hyperactor_mesh/examples/dining_philosophers.rs b/hyperactor_mesh/examples/dining_philosophers.rs index 502e71176..ee1c74767 100644 --- a/hyperactor_mesh/examples/dining_philosophers.rs +++ b/hyperactor_mesh/examples/dining_philosophers.rs @@ -20,6 +20,7 @@ use hyperactor::Handler; use hyperactor::Instance; use hyperactor::Named; use hyperactor::PortRef; +use hyperactor::Proc; use hyperactor::Unbind; use hyperactor::channel::ChannelTransport; use hyperactor_mesh::ProcMesh; @@ -237,10 +238,12 @@ async fn main() -> Result { }) .await?; + let (instance, _) = Proc::local().instance("client").unwrap(); + let proc_mesh = ProcMesh::allocate(alloc).await?; let params = PhilosopherActorParams { size: group_size }; let actor_mesh = proc_mesh - .spawn::("philosopher", ¶ms) + .spawn::(&instance, "philosopher", ¶ms) .await?; let (dining_message_handle, mut dining_message_rx) = proc_mesh.client().open_port(); actor_mesh diff --git a/hyperactor_mesh/examples/sieve.rs b/hyperactor_mesh/examples/sieve.rs index 59982d428..c0f9c8742 100644 --- a/hyperactor_mesh/examples/sieve.rs +++ b/hyperactor_mesh/examples/sieve.rs @@ -22,6 +22,7 @@ use hyperactor::Context; use hyperactor::Handler; use hyperactor::Named; use hyperactor::PortRef; +use hyperactor::Proc; use hyperactor::channel::ChannelTransport; use hyperactor_mesh::Mesh; use hyperactor_mesh::ProcMesh; @@ -116,8 +117,12 @@ async fn main() -> Result { let mesh = ProcMesh::allocate(alloc).await?; + let (instance, _) = Proc::local().instance("client").unwrap(); + let sieve_params = SieveParams { prime: 2 }; - let sieve_mesh = mesh.spawn::("sieve", &sieve_params).await?; + let sieve_mesh = mesh + .spawn::(&instance, "sieve", &sieve_params) + .await?; let sieve_head = sieve_mesh.get(0).unwrap(); let mut primes = vec![2]; diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs index 8cbf4dbc7..48fef61c8 100644 --- a/hyperactor_mesh/src/actor_mesh.rs +++ b/hyperactor_mesh/src/actor_mesh.rs @@ -10,6 +10,7 @@ use std::collections::BTreeSet; use std::ops::Deref; +use std::sync::OnceLock; use async_trait::async_trait; use hyperactor::Actor; @@ -39,6 +40,7 @@ use ndslice::Selection; use ndslice::Shape; use ndslice::ShapeError; use ndslice::SliceError; +use ndslice::View; use ndslice::reshape::Limit; use ndslice::reshape::ReshapeError; use ndslice::reshape::ReshapeSliceExt; @@ -47,6 +49,7 @@ use ndslice::selection; use ndslice::selection::EvalOpts; use ndslice::selection::ReifySlice; use ndslice::selection::normal; +use ndslice::view::ViewExt; use serde::Deserialize; use serde::Serialize; use serde_multipart::Part; @@ -62,7 +65,8 @@ use crate::metrics; use crate::proc_mesh::ProcMesh; use crate::reference::ActorMeshId; use crate::reference::ActorMeshRef; -use crate::reference::ProcMeshId; +use crate::sel; +use crate::v1; declare_attrs! { /// Which mesh this message was cast to. Used for undeliverable message @@ -199,9 +203,21 @@ pub trait ActorMesh: Mesh { message: M, ) -> Result<(), CastError> where - Self::Actor: RemoteHandles>, - M: Castable + RemoteMessage, + Self::Actor: RemoteHandles + RemoteHandles>, + M: Castable + RemoteMessage + Clone, { + if let Some(v1) = self.v1() { + if !selection::structurally_equal(&selection, &sel!(*)) { + return Err(CastError::SelectionNotSupported(format!( + "ActorMesh::cast: selection {} not supported; for v1 meshes supports only universal selection", + selection + ))); + } + return v1 + .cast(cx, message) + .map_err(anyhow::Error::from) + .map_err(CastError::from); + } actor_mesh_cast::( cx, // actor context self.id(), // actor mesh id (destination mesh) @@ -224,10 +240,20 @@ pub trait ActorMesh: Mesh { } /// Iterate over all `ActorRef` in this mesh. - fn iter_actor_refs(&self) -> impl Iterator> { + fn iter_actor_refs(&self) -> Box>> { + if let Some(v1) = self.v1() { + // We collect() here to ensure that the data are owned. Since this is a short-lived + // shim, we'll live with it. + return Box::new( + v1.iter() + .map(|(_point, actor_ref)| actor_ref.clone()) + .collect::>() + .into_iter(), + ); + } let gang: GangRef = GangId(self.proc_mesh().world_id().clone(), self.name().to_string()).into(); - self.shape().slice().iter().map(move |rank| gang.rank(rank)) + Box::new(self.shape().slice().iter().map(move |rank| gang.rank(rank))) } async fn stop(&self) -> Result<(), anyhow::Error> { @@ -237,14 +263,14 @@ pub trait ActorMesh: Mesh { /// Get a serializeable reference to this mesh similar to ActorHandle::bind fn bind(&self) -> ActorMeshRef { ActorMeshRef::attest( - ActorMeshId::V0( - ProcMeshId(self.world_id().to_string()), - self.name().to_string(), - ), + self.id(), self.shape().clone(), self.proc_mesh().comm_actor().clone(), ) } + + /// Retrieves the v1 mesh for this v0 ActorMesh, if it is available. + fn v1(&self) -> Option>; } /// Abstracts over shared and borrowed references to a [`ProcMesh`]. @@ -276,12 +302,40 @@ impl Deref for ProcMeshRef<'_> { /// `ActorRef` handles (see `ranks`), and `ActorRef` is only /// defined for `A: Referable`. pub struct RootActorMesh<'a, A: Referable> { - proc_mesh: ProcMeshRef<'a>, - name: String, - pub(crate) ranks: Vec>, // temporary until we remove `ArcActorMesh`. - // The receiver of supervision events. It is None if it has been transferred to - // an actor event observer. - actor_supervision_rx: Option>, + inner: ActorMeshKind<'a, A>, + shape: OnceLock, + proc_mesh: OnceLock, + name: OnceLock, +} + +enum ActorMeshKind<'a, A: Referable> { + V0 { + proc_mesh: ProcMeshRef<'a>, + name: String, + ranks: Vec>, // temporary until we remove `ArcActorMesh`. + // The receiver of supervision events. It is None if it has been transferred to + // an actor event observer. + actor_supervision_rx: Option>, + }, + + V1(v1::ActorMeshRef), +} + +impl<'a, A: Referable> From> for RootActorMesh<'a, A> { + fn from(actor_mesh: v1::ActorMeshRef) -> Self { + Self { + inner: ActorMeshKind::V1(actor_mesh), + shape: OnceLock::new(), + proc_mesh: OnceLock::new(), + name: OnceLock::new(), + } + } +} + +impl<'a, A: Referable> From> for RootActorMesh<'a, A> { + fn from(actor_mesh: v1::ActorMesh) -> Self { + actor_mesh.detach().into() + } } impl<'a, A: Referable> RootActorMesh<'a, A> { @@ -292,10 +346,24 @@ impl<'a, A: Referable> RootActorMesh<'a, A> { ranks: Vec>, ) -> Self { Self { - proc_mesh: ProcMeshRef::Borrowed(proc_mesh), - name, - ranks, - actor_supervision_rx: Some(actor_supervision_rx), + inner: ActorMeshKind::V0 { + proc_mesh: ProcMeshRef::Borrowed(proc_mesh), + name, + ranks, + actor_supervision_rx: Some(actor_supervision_rx), + }, + shape: OnceLock::new(), + proc_mesh: OnceLock::new(), + name: OnceLock::new(), + } + } + + pub(crate) fn new_v1(actor_mesh: v1::ActorMeshRef) -> Self { + Self { + inner: ActorMeshKind::V1(actor_mesh), + shape: OnceLock::new(), + proc_mesh: OnceLock::new(), + name: OnceLock::new(), } } @@ -306,27 +374,50 @@ impl<'a, A: Referable> RootActorMesh<'a, A> { ranks: Vec>, ) -> Self { Self { - proc_mesh: ProcMeshRef::Shared(Box::new(proc_mesh)), - name, - ranks, - actor_supervision_rx: Some(actor_supervision_rx), + inner: ActorMeshKind::V0 { + proc_mesh: ProcMeshRef::Shared(Box::new(proc_mesh)), + name, + ranks, + actor_supervision_rx: Some(actor_supervision_rx), + }, + shape: OnceLock::new(), + proc_mesh: OnceLock::new(), + name: OnceLock::new(), } } /// Open a port on this ActorMesh. pub fn open_port(&self) -> (PortHandle, PortReceiver) { - self.proc_mesh.client().open_port() + match &self.inner { + ActorMeshKind::V0 { proc_mesh, .. } => proc_mesh.client().open_port(), + ActorMeshKind::V1(_actor_mesh) => unimplemented!("unsupported operation"), + } } /// An event stream of actor events. Each RootActorMesh can produce only one such /// stream, returning None after the first call. pub fn events(&mut self) -> Option { - self.actor_supervision_rx - .take() - .map(|actor_supervision_rx| ActorSupervisionEvents { + match &mut self.inner { + ActorMeshKind::V0 { actor_supervision_rx, - mesh_id: self.id(), - }) + .. + } => actor_supervision_rx + .take() + .map(|actor_supervision_rx| ActorSupervisionEvents { + actor_supervision_rx, + mesh_id: self.id(), + }), + ActorMeshKind::V1(_actor_mesh) => unimplemented!("unsupported operation"), + } + } + + /// Access the ranks field (temporary until we remove `ArcActorMesh`). + #[cfg(test)] + pub(crate) fn ranks(&self) -> &Vec> { + match &self.inner { + ActorMeshKind::V0 { ranks, .. } => ranks, + ActorMeshKind::V1(_actor_mesh) => unimplemented!("unsupported operation"), + } } } @@ -361,7 +452,10 @@ impl<'a, A: Referable> Mesh for RootActorMesh<'a, A> { 'a: 'b; fn shape(&self) -> &Shape { - self.proc_mesh.shape() + self.shape.get_or_init(|| match &self.inner { + ActorMeshKind::V0 { proc_mesh, .. } => proc_mesh.shape().clone(), + ActorMeshKind::V1(actor_mesh) => actor_mesh.region().into(), + }) } fn select>( @@ -373,11 +467,19 @@ impl<'a, A: Referable> Mesh for RootActorMesh<'a, A> { } fn get(&self, rank: usize) -> Option> { - self.ranks.get(rank).cloned() + match &self.inner { + ActorMeshKind::V0 { ranks, .. } => ranks.get(rank).cloned(), + ActorMeshKind::V1(actor_mesh) => actor_mesh.get(rank), + } } fn id(&self) -> Self::Id { - ActorMeshId::V0(self.proc_mesh.id(), self.name.clone()) + match &self.inner { + ActorMeshKind::V0 { + proc_mesh, name, .. + } => ActorMeshId::V0(proc_mesh.id(), name.clone()), + ActorMeshKind::V1(actor_mesh) => ActorMeshId::V1(actor_mesh.name().clone()), + } } } @@ -385,11 +487,28 @@ impl ActorMesh for RootActorMesh<'_, A> { type Actor = A; fn proc_mesh(&self) -> &ProcMesh { - &self.proc_mesh + match &self.inner { + ActorMeshKind::V0 { proc_mesh, .. } => proc_mesh, + ActorMeshKind::V1(actor_mesh) => self + .proc_mesh + .get_or_init(|| actor_mesh.proc_mesh().clone().into()), + } } fn name(&self) -> &str { - &self.name + match &self.inner { + ActorMeshKind::V0 { name, .. } => name, + ActorMeshKind::V1(actor_mesh) => { + self.name.get_or_init(|| actor_mesh.name().to_string()) + } + } + } + + fn v1(&self) -> Option> { + match &self.inner { + ActorMeshKind::V0 { .. } => None, + ActorMeshKind::V1(actor_mesh) => Some(actor_mesh.clone()), + } } } @@ -439,11 +558,11 @@ impl ActorMesh for SlicedActorMesh<'_, A> { type Actor = A; fn proc_mesh(&self) -> &ProcMesh { - &self.0.proc_mesh + self.0.proc_mesh() } fn name(&self) -> &str { - &self.0.name + self.0.name() } #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`. @@ -462,6 +581,12 @@ impl ActorMesh for SlicedActorMesh<'_, A> { /*root_mesh_shape=*/ self.0.shape(), ) } + + fn v1(&self) -> Option> { + self.0 + .v1() + .map(|actor_mesh| actor_mesh.subset(self.shape().into()).unwrap()) + } } /// The type of error of casting operations. @@ -473,6 +598,9 @@ pub enum CastError { #[error("send on rank {0}: {1}")] MailboxSenderError(usize, MailboxSenderError), + #[error("unsupported selection: {0}")] + SelectionNotSupported(String), + #[error(transparent)] RootMailboxSenderError(#[from] MailboxSenderError), @@ -503,6 +631,7 @@ pub(crate) mod test_util { use anyhow::ensure; use hyperactor::Context; use hyperactor::Handler; + use hyperactor::Instance; use hyperactor::PortRef; use ndslice::extent; @@ -625,8 +754,8 @@ pub(crate) mod test_util { ], )] pub struct ProxyActor { - proc_mesh: Arc, - actor_mesh: RootActorMesh<'static, TestActor>, + proc_mesh: &'static Arc, + actor_mesh: Option>, } impl fmt::Debug for ProxyActor { @@ -664,13 +793,16 @@ pub(crate) mod test_util { .unwrap(); let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap()); let leaked: &'static Arc = Box::leak(Box::new(proc_mesh)); - let actor_mesh: RootActorMesh<'static, TestActor> = - leaked.spawn("echo", &()).await.unwrap(); Ok(Self { - proc_mesh: Arc::clone(leaked), - actor_mesh, + proc_mesh: leaked, + actor_mesh: None, }) } + + async fn init(&mut self, this: &Instance) -> Result<(), anyhow::Error> { + self.actor_mesh = Some(self.proc_mesh.spawn(this, "echo", &()).await?); + Ok(()) + } } #[async_trait] @@ -679,7 +811,7 @@ pub(crate) mod test_util { if std::env::var("HYPERACTOR_MESH_ROUTER_NO_GLOBAL_FALLBACK").is_err() { // test_proxy_mesh - let actor = self.actor_mesh.get(0).unwrap(); + let actor = self.actor_mesh.as_ref().unwrap().get(0).unwrap(); // For now, we reply directly to the client. // We will support directly wiring up the meshes later. @@ -692,7 +824,7 @@ pub(crate) mod test_util { } else { // test_router_undeliverable_return - let actor: ActorRef<_> = self.actor_mesh.get(0).unwrap(); + let actor: ActorRef<_> = self.actor_mesh.as_ref().unwrap().get(0).unwrap(); let (tx, mut rx) = cx.open_port::(); actor.send(cx, Echo(message.0, tx.bind()))?; @@ -765,8 +897,9 @@ mod tests { }) .await .unwrap(); + let instance = $crate::v1::testing::instance().await; let proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); - let actor_mesh: RootActorMesh<'_, ProxyActor> = proc_mesh.spawn("proxy", &()).await.unwrap(); + let actor_mesh: RootActorMesh<'_, ProxyActor> = proc_mesh.spawn(&instance, "proxy", &()).await.unwrap(); let proxy_actor = actor_mesh.get(0).unwrap(); let (tx, mut rx) = actor_mesh.open_port::(); proxy_actor.send(proc_mesh.client(), Echo("hello!".to_owned(), tx.bind())).unwrap(); @@ -790,8 +923,9 @@ mod tests { .await .unwrap(); + let instance = $crate::v1::testing::instance().await; let proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); - let actor_mesh: RootActorMesh = proc_mesh.spawn("echo", &()).await.unwrap(); + let actor_mesh: RootActorMesh = proc_mesh.spawn(&instance, "echo", &()).await.unwrap(); let (reply_handle, mut reply_receiver) = actor_mesh.open_port(); actor_mesh .cast(proc_mesh.client(), sel!(*), Echo("Hello".to_string(), reply_handle.bind())) @@ -816,12 +950,13 @@ mod tests { }) .await .unwrap(); + let instance = $crate::v1::testing::instance().await; let mesh = ProcMesh::allocate(alloc).await.unwrap(); let (undeliverable_msg_tx, _) = mesh.client().open_port(); let ping_pong_actor_params = PingPongActorParams::new(Some(undeliverable_msg_tx.bind()), None); let actor_mesh: RootActorMesh = mesh - .spawn::("ping-pong", &ping_pong_actor_params) + .spawn::(&instance, "ping-pong", &ping_pong_actor_params) .await .unwrap(); @@ -854,10 +989,11 @@ mod tests { .await .unwrap(); + let instance = crate::v1::testing::instance().await; let proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); let (undeliverable_tx, _undeliverable_rx) = proc_mesh.client().open_port(); let params = PingPongActorParams::new(Some(undeliverable_tx.bind()), None); - let actor_mesh = proc_mesh.spawn::("pingpong", ¶ms).await.unwrap(); + let actor_mesh = proc_mesh.spawn::(&instance, "pingpong", ¶ms).await.unwrap(); let slice = actor_mesh.shape().slice(); let mut futures = Vec::new(); @@ -899,8 +1035,9 @@ mod tests { .await .unwrap(); + let instance = crate::v1::testing::instance().await; let proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); - let actor_mesh: RootActorMesh = proc_mesh.spawn("echo", &()).await.unwrap(); + let actor_mesh: RootActorMesh = proc_mesh.spawn(&instance, "echo", &()).await.unwrap(); let dont_simulate_error = true; let (reply_handle, mut reply_receiver) = actor_mesh.open_port(); actor_mesh @@ -942,8 +1079,9 @@ mod tests { .await .unwrap(); + let instance = crate::v1::testing::instance().await; let proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); - let actor_mesh: RootActorMesh = proc_mesh.spawn("echo", &()).await.unwrap(); + let actor_mesh: RootActorMesh = proc_mesh.spawn(&instance, "echo", &()).await.unwrap(); // Bounce the message through all actors and return it to the sender (us). let mut hops: VecDeque<_> = actor_mesh.iter().map(|actor| actor.port()).collect(); @@ -963,6 +1101,7 @@ mod tests { #[tokio::test] async fn test_inter_proc_mesh_comms() { let mut meshes = Vec::new(); + let instance = crate::v1::testing::instance().await; for _ in 0..2 { let alloc = $allocator .allocate(AllocSpec { @@ -976,7 +1115,7 @@ mod tests { let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap()); let proc_mesh_clone = Arc::clone(&proc_mesh); - let actor_mesh : RootActorMesh = proc_mesh_clone.spawn("echo", &()).await.unwrap(); + let actor_mesh : RootActorMesh = proc_mesh_clone.spawn(&instance, "echo", &()).await.unwrap(); meshes.push((proc_mesh, actor_mesh)); } @@ -1023,11 +1162,12 @@ mod tests { .await .unwrap(); + let instance = crate::v1::testing::instance().await; let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); let (tx, mut rx) = hyperactor::mailbox::open_port(proc_mesh.client()); let params = CastTestActorParams{ forward_port: tx.bind() }; - let actor_mesh: RootActorMesh = proc_mesh.spawn("actor", ¶ms).await.unwrap(); + let actor_mesh: RootActorMesh = proc_mesh.spawn(&instance, "actor", ¶ms).await.unwrap(); actor_mesh.cast(proc_mesh.client(), sel!(*), CastTestMessage::Forward("abc".to_string())).unwrap(); @@ -1085,11 +1225,12 @@ mod tests { .await .unwrap(); + let instance = crate::v1::testing::instance().await; let mesh = ProcMesh::allocate(alloc).await.unwrap(); let (reply_port_handle, mut reply_port_receiver) = mesh.client().open_port::(); let reply_port = reply_port_handle.bind(); - let actor_mesh: RootActorMesh = mesh.spawn("test", &()).await.unwrap(); + let actor_mesh: RootActorMesh = mesh.spawn(&instance, "test", &()).await.unwrap(); let actor_ref = actor_mesh.get(0).unwrap(); let mut headers = Attrs::new(); set_cast_info_on_headers(&mut headers, extent.point_of_rank(0).unwrap(), mesh.client().self_id().clone()); @@ -1149,6 +1290,7 @@ mod tests { }) .await .unwrap(); + let instance = crate::v1::testing::instance().await; let monkey = alloc.chaos_monkey(); let mut mesh = ProcMesh::allocate(alloc).await.unwrap(); let mut events = mesh.events().unwrap(); @@ -1158,7 +1300,7 @@ mod tests { None, ); let actor_mesh: RootActorMesh = mesh - .spawn::("ping-pong", &ping_pong_actor_params) + .spawn::(&instance, "ping-pong", &ping_pong_actor_params) .await .unwrap(); @@ -1218,13 +1360,14 @@ mod tests { }) .await .unwrap(); + let instance = crate::v1::testing::instance().await; let stop = alloc.stopper(); let mut mesh = ProcMesh::allocate(alloc).await.unwrap(); let mut events = mesh.events().unwrap(); let actor_mesh = mesh - .spawn::("reply-then-fail", &()) + .spawn::(&instance, "reply-then-fail", &()) .await .unwrap(); @@ -1286,6 +1429,7 @@ mod tests { }) .await .unwrap(); + let instance = crate::v1::testing::instance().await; let mesh = ProcMesh::allocate(alloc).await.unwrap(); let ping_pong_actor_params = PingPongActorParams::new( @@ -1293,12 +1437,12 @@ mod tests { None, ); let mesh_one: RootActorMesh = mesh - .spawn::("mesh_one", &ping_pong_actor_params) + .spawn::(&instance, "mesh_one", &ping_pong_actor_params) .await .unwrap(); let mesh_two: RootActorMesh = mesh - .spawn::("mesh_two", &ping_pong_actor_params) + .spawn::(&instance, "mesh_two", &ping_pong_actor_params) .await .unwrap(); @@ -1398,10 +1542,11 @@ mod tests { }) .await .unwrap(); + let instance = crate::v1::testing::instance().await; let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); let mut proc_events = proc_mesh.events().unwrap(); let actor_mesh: RootActorMesh = - proc_mesh.spawn("ingest", &()).await.unwrap(); + proc_mesh.spawn(&instance, "ingest", &()).await.unwrap(); let (reply_handle, mut reply_receiver) = actor_mesh.open_port(); let dest = actor_mesh.get(0).unwrap(); @@ -1486,10 +1631,11 @@ mod tests { // SAFETY: Not multithread safe. unsafe { std::env::set_var("HYPERACTOR_MESH_ROUTER_NO_GLOBAL_FALLBACK", "1") }; + let instance = crate::v1::testing::instance().await; let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); let mut proc_events = proc_mesh.events().unwrap(); let mut actor_mesh: RootActorMesh<'_, ProxyActor> = - { proc_mesh.spawn("proxy", &()).await.unwrap() }; + { proc_mesh.spawn(&instance, "proxy", &()).await.unwrap() }; let mut actor_events = actor_mesh.events().unwrap(); let proxy_actor = actor_mesh.get(0).unwrap(); @@ -1508,7 +1654,7 @@ mod tests { ); assert_eq!( actor_events.next().await.unwrap().actor_id.name(), - &actor_mesh.name + actor_mesh.name(), ); } } @@ -1644,12 +1790,13 @@ mod tests { transport: ChannelTransport::Local })) .unwrap(); + let instance = runtime.block_on(crate::v1::testing::instance()); let proc_mesh = runtime.block_on(ProcMesh::allocate(alloc)).unwrap(); let addr = ChannelAddr::any(ChannelTransport::Unix); let actor_mesh: RootActorMesh = - runtime.block_on(proc_mesh.spawn("echo", &addr)).unwrap(); + runtime.block_on(proc_mesh.spawn(&instance, "echo", &addr)).unwrap(); let mut runner = TestRunner::default(); let selection = gen_selection(4, actor_mesh.shape().slice().sizes().to_vec(), 0) @@ -1676,12 +1823,13 @@ mod tests { transport: ChannelTransport::Local })) .unwrap(); + let instance = runtime.block_on(crate::v1::testing::instance()); let proc_mesh = runtime.block_on(ProcMesh::allocate(alloc)).unwrap(); let addr = ChannelAddr::any(ChannelTransport::Unix); let actor_mesh: RootActorMesh = - runtime.block_on(proc_mesh.spawn("echo", &addr)).unwrap(); + runtime.block_on(proc_mesh.spawn(&instance, "echo", &addr)).unwrap(); let first_label = extent.labels().first().unwrap(); @@ -1749,12 +1897,13 @@ mod tests { transport: ChannelTransport::Local })) .unwrap(); + let instance = runtime.block_on(crate::v1::testing::instance()); let proc_mesh = runtime.block_on(ProcMesh::allocate(alloc)).unwrap(); let addr = ChannelAddr::any(ChannelTransport::Unix); let actor_mesh: RootActorMesh = - runtime.block_on(proc_mesh.spawn("echo", &addr)).unwrap(); + runtime.block_on(proc_mesh.spawn(&instance, "echo", &addr)).unwrap(); let mut runner = TestRunner::default(); let selection = gen_selection(4, actor_mesh.shape().slice().sizes().to_vec(), 0) @@ -1771,4 +1920,54 @@ mod tests { } } } + + mod shim { + use std::collections::HashSet; + + use hyperactor::context::Mailbox; + use ndslice::Extent; + use ndslice::extent; + + use super::*; + use crate::sel; + + #[tokio::test] + async fn test_basic() { + let instance = v1::testing::instance().await; + let host_mesh = v1::testing::host_mesh(extent!(host = 4)).await; + let proc_mesh = host_mesh + .spawn(instance, "test", Extent::unity()) + .await + .unwrap(); + let actor_mesh = proc_mesh + .spawn::(instance, "test", &()) + .await + .unwrap(); + + let actor_mesh_v0: RootActorMesh<'_, _> = actor_mesh.clone().into(); + + let (cast_info, mut cast_info_rx) = instance.mailbox().open_port(); + actor_mesh_v0 + .cast( + instance, + sel!(*), + v1::testactor::GetCastInfo { + cast_info: cast_info.bind(), + }, + ) + .unwrap(); + + let mut point_to_actor: HashSet<_> = actor_mesh.iter().collect(); + while !point_to_actor.is_empty() { + let (point, origin_actor_ref, sender_actor_id) = cast_info_rx.recv().await.unwrap(); + let key = (point, origin_actor_ref); + assert!( + point_to_actor.remove(&key), + "key {:?} not present or removed twice", + key + ); + assert_eq!(&sender_actor_id, instance.self_id()); + } + } + } } diff --git a/hyperactor_mesh/src/alloc/sim.rs b/hyperactor_mesh/src/alloc/sim.rs index 652347331..983425f6f 100644 --- a/hyperactor_mesh/src/alloc/sim.rs +++ b/hyperactor_mesh/src/alloc/sim.rs @@ -199,11 +199,13 @@ mod tests { }) .await .unwrap(); + let instance = crate::v1::testing::instance().await; let proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); let handle = hyperactor::simnet::simnet_handle().unwrap(); - let actor_mesh: RootActorMesh = proc_mesh.spawn("echo", &()).await.unwrap(); + let actor_mesh: RootActorMesh = + proc_mesh.spawn(&instance, "echo", &()).await.unwrap(); let actors = actor_mesh.iter_actor_refs().collect::>(); assert_eq!( handle.sample_latency( diff --git a/hyperactor_mesh/src/comm.rs b/hyperactor_mesh/src/comm.rs index 57cc2369f..3d1283d48 100644 --- a/hyperactor_mesh/src/comm.rs +++ b/hyperactor_mesh/src/comm.rs @@ -830,9 +830,9 @@ mod tests { let params = TestActorParams { forward_port: tx.bind(), }; - let actor_mesh = proc_mesh - .clone() - .spawn::(dest_actor_name, ¶ms) + let instance = crate::v1::testing::instance().await; + let actor_mesh = Arc::clone(&proc_mesh) + .spawn::(&instance, dest_actor_name, ¶ms) .await .unwrap(); @@ -968,7 +968,7 @@ mod tests { } = setup_mesh::(None).await; let proc_mesh_client = actor_mesh.proc_mesh().client(); - let ranks = actor_mesh.ranks.clone(); + let ranks = actor_mesh.ranks().clone(); execute_cast_and_reply(ranks, proc_mesh_client, reply1_rx, reply2_rx, reply_tos).await; } @@ -1032,7 +1032,7 @@ mod tests { .. } = setup_mesh(Some(accum::sum::())).await; let proc_mesh_client = actor_mesh.proc_mesh().client(); - let ranks = actor_mesh.ranks.clone(); + let ranks = actor_mesh.ranks().clone(); execute_cast_and_accum(ranks, proc_mesh_client, reply1_rx, reply_tos).await; } diff --git a/hyperactor_mesh/src/proc_mesh.rs b/hyperactor_mesh/src/proc_mesh.rs index 5ad6a136e..c06831014 100644 --- a/hyperactor_mesh/src/proc_mesh.rs +++ b/hyperactor_mesh/src/proc_mesh.rs @@ -51,6 +51,8 @@ use hyperactor::supervision::ActorSupervisionEvent; use ndslice::Range; use ndslice::Shape; use ndslice::ShapeError; +use ndslice::View; +use ndslice::ViewExt; use strum::AsRefStr; use tokio::sync::mpsc; use tracing::Instrument; @@ -76,6 +78,7 @@ use crate::proc_mesh::mesh_agent::update_event_actor_id; use crate::reference::ProcMeshId; use crate::router; use crate::shortuuid::ShortUuid; +use crate::v1; pub mod mesh_agent; @@ -205,20 +208,30 @@ pub fn global_root_client() -> &'static Instance<()> { } type ActorEventRouter = Arc>>; + /// A ProcMesh maintains a mesh of procs whose lifecycles are managed by /// an allocator. pub struct ProcMesh { - // The underlying set of events. It is None if it has been transferred to - // a proc event observer. - event_state: Option, - actor_event_router: ActorEventRouter, - shape: Shape, - ranks: Vec<(ShortUuid, ProcId, ChannelAddr, ActorRef)>, - #[allow(dead_code)] // will be used in subsequent diff - client_proc: Proc, - client: Instance<()>, - comm_actors: Vec>, - world_id: WorldId, + inner: ProcMeshKind, + shape: OnceLock, +} + +enum ProcMeshKind { + V0 { + // The underlying set of events. It is None if it has been transferred to + // a proc event observer. + event_state: Option, + actor_event_router: ActorEventRouter, + shape: Shape, + ranks: Vec<(ShortUuid, ProcId, ChannelAddr, ActorRef)>, + #[allow(dead_code)] // will be used in subsequent diff + client_proc: Proc, + client: Instance<()>, + comm_actors: Vec>, + world_id: WorldId, + }, + + V1(v1::ProcMeshRef), } struct EventState { @@ -226,6 +239,15 @@ struct EventState { supervision_events: PortReceiver, } +impl From for ProcMesh { + fn from(proc_mesh: v1::ProcMeshRef) -> Self { + ProcMesh { + inner: ProcMeshKind::V1(proc_mesh), + shape: OnceLock::new(), + } + } +} + impl ProcMesh { #[hyperactor::instrument(fields(name = "ProcMesh::allocate"))] pub async fn allocate( @@ -440,27 +462,30 @@ impl ProcMesh { ); Ok(Self { - event_state: Some(EventState { - alloc, - supervision_events, - }), - actor_event_router: Arc::new(DashMap::new()), - shape, - ranks: running - .into_iter() - .map( - |AllocatedProc { - create_key, - proc_id, - addr, - mesh_agent, - }| (create_key, proc_id, addr, mesh_agent), - ) - .collect(), - client_proc, - client, - comm_actors, - world_id, + inner: ProcMeshKind::V0 { + event_state: Some(EventState { + alloc, + supervision_events, + }), + actor_event_router: Arc::new(DashMap::new()), + shape, + ranks: running + .into_iter() + .map( + |AllocatedProc { + create_key, + proc_id, + addr, + mesh_agent, + }| (create_key, proc_id, addr, mesh_agent), + ) + .collect(), + client_proc, + client, + comm_actors, + world_id, + }, + shape: OnceLock::new(), }) } @@ -533,13 +558,35 @@ impl ProcMesh { .collect()) } - fn agents(&self) -> impl Iterator> + '_ { - self.ranks.iter().map(|(_, _, _, agent)| agent.clone()) + fn agents(&self) -> Box> + '_ + Send> { + match &self.inner { + ProcMeshKind::V0 { ranks, .. } => { + Box::new(ranks.iter().map(|(_, _, _, agent)| agent.clone())) + } + ProcMeshKind::V1(proc_mesh) => Box::new( + proc_mesh + .agent_mesh() + .iter() + .map(|(_point, agent)| agent.clone()) + // We need to collect here so that we can return an iterator + // that fully owns the data and does not reference temporary + // values. + // + // Because this is a shim that we expect to be short-lived, + // we'll leave this hack as is; a proper solution here would + // be to have implement an owning iterator (into_iter) for views. + .collect::>() + .into_iter(), + ), + } } /// Return the comm actor to which casts should be forwarded. pub(crate) fn comm_actor(&self) -> &ActorRef { - &self.comm_actors[0] + match &self.inner { + ProcMeshKind::V0 { comm_actors, .. } => &comm_actors[0], + ProcMeshKind::V1(proc_mesh) => proc_mesh.root_comm_actor().unwrap(), + } } /// Spawn an `ActorMesh` by launching the same actor type on all @@ -559,102 +606,145 @@ impl ProcMesh { /// cross proc boundaries when launching each actor. pub async fn spawn( &self, + cx: &impl context::Actor, actor_name: &str, params: &A::Params, ) -> Result, anyhow::Error> where A::Params: RemoteMessage, { - let (tx, rx) = mpsc::unbounded_channel::(); - { - // Instantiate supervision routing BEFORE spawning the actor mesh. - self.actor_event_router.insert(actor_name.to_string(), tx); - tracing::info!( - name = "router_insert", - actor_name = %actor_name, - "the length of the router is {}", self.actor_event_router.len(), - ); + match &self.inner { + ProcMeshKind::V0 { + actor_event_router, + client, + .. + } => { + let (tx, rx) = mpsc::unbounded_channel::(); + { + // Instantiate supervision routing BEFORE spawning the actor mesh. + actor_event_router.insert(actor_name.to_string(), tx); + tracing::info!( + name = "router_insert", + actor_name = %actor_name, + "the length of the router is {}", actor_event_router.len(), + ); + } + let root_mesh = RootActorMesh::new( + self, + actor_name.to_string(), + rx, + Self::spawn_on_procs::(client, self.agents(), actor_name, params).await?, + ); + Ok(root_mesh) + } + ProcMeshKind::V1(proc_mesh) => { + let actor_mesh = proc_mesh.spawn(cx, actor_name, params).await?; + Ok(RootActorMesh::new_v1(actor_mesh.detach())) + } } - let root_mesh = RootActorMesh::new( - self, - actor_name.to_string(), - rx, - Self::spawn_on_procs::(&self.client, self.agents(), actor_name, params).await?, - ); - Ok(root_mesh) } /// A client actor used to communicate with any member of this mesh. pub fn client(&self) -> &Instance<()> { - &self.client + match &self.inner { + ProcMeshKind::V0 { client, .. } => client, + ProcMeshKind::V1(_proc_mesh) => unimplemented!("no client for v1::ProcMesh"), + } } pub fn client_proc(&self) -> &Proc { - &self.client_proc + match &self.inner { + ProcMeshKind::V0 { client_proc, .. } => client_proc, + ProcMeshKind::V1(_proc_mesh) => unimplemented!("no client proc for v1::ProcMesh"), + } } pub fn proc_id(&self) -> &ProcId { - self.client_proc.proc_id() + self.client_proc().proc_id() } pub fn world_id(&self) -> &WorldId { - &self.world_id + match &self.inner { + ProcMeshKind::V0 { world_id, .. } => world_id, + ProcMeshKind::V1(_proc_mesh) => unimplemented!("no world_id for v1::ProcMesh"), + } } /// An event stream of proc events. Each ProcMesh can produce only one such /// stream, returning None after the first call. pub fn events(&mut self) -> Option { - self.event_state.take().map(|event_state| ProcEvents { - event_state, - ranks: self - .ranks - .iter() - .enumerate() - .map(|(rank, (create_key, proc_id, _addr, _mesh_agent))| { - (proc_id.clone(), (rank, create_key.clone())) - }) - .collect(), - actor_event_router: self.actor_event_router.clone(), - }) + match &mut self.inner { + ProcMeshKind::V0 { + event_state, + ranks, + actor_event_router, + .. + } => event_state.take().map(|event_state| ProcEvents { + event_state, + ranks: ranks + .iter() + .enumerate() + .map(|(rank, (create_key, proc_id, _addr, _mesh_agent))| { + (proc_id.clone(), (rank, create_key.clone())) + }) + .collect(), + actor_event_router: actor_event_router.clone(), + }), + ProcMeshKind::V1(_proc_mesh) => todo!(), + } } pub fn shape(&self) -> &Shape { - &self.shape + // We store the shape here, only because it isn't materialized in + // V1 meshes. + self.shape.get_or_init(|| match &self.inner { + ProcMeshKind::V0 { shape, .. } => shape.clone(), + ProcMeshKind::V1(proc_mesh) => proc_mesh.region().into(), + }) } /// Send stop actors message to all mesh agents for a specific mesh name #[hyperactor::observe_result("ProcMesh")] pub async fn stop_actor_by_name(&self, mesh_name: &str) -> Result<(), anyhow::Error> { - let timeout = hyperactor::config::global::get(hyperactor::config::STOP_ACTOR_TIMEOUT); - let results = join_all(self.agents().map(|agent| async move { - let actor_id = ActorId(agent.actor_id().proc_id().clone(), mesh_name.to_string(), 0); - ( - actor_id.clone(), - agent - .clone() - .stop_actor(&self.client, actor_id, timeout.as_millis() as u64) - .await, - ) - })) - .await; - - for (actor_id, result) in results { - match result { - Ok(StopActorResult::Timeout) => { - tracing::warn!("timed out while stopping actor {}", actor_id); - } - Ok(StopActorResult::NotFound) => { - tracing::warn!("no actor {} on proc {}", actor_id, actor_id.proc_id()); - } - Ok(StopActorResult::Success) => { - tracing::info!("stopped actor {}", actor_id); - } - Err(e) => { - tracing::warn!("error stopping actor {}: {}", actor_id, e); + match &self.inner { + ProcMeshKind::V0 { client, .. } => { + let timeout = + hyperactor::config::global::get(hyperactor::config::STOP_ACTOR_TIMEOUT); + let results = join_all(self.agents().map(|agent| async move { + let actor_id = + ActorId(agent.actor_id().proc_id().clone(), mesh_name.to_string(), 0); + ( + actor_id.clone(), + agent + .clone() + .stop_actor(client, actor_id, timeout.as_millis() as u64) + .await, + ) + })) + .await; + + for (actor_id, result) in results { + match result { + Ok(StopActorResult::Timeout) => { + tracing::warn!("timed out while stopping actor {}", actor_id); + } + Ok(StopActorResult::NotFound) => { + tracing::warn!("no actor {} on proc {}", actor_id, actor_id.proc_id()); + } + Ok(StopActorResult::Success) => { + tracing::info!("stopped actor {}", actor_id); + } + Err(e) => { + tracing::warn!("error stopping actor {}: {}", actor_id, e); + } + } } + Ok(()) + } + ProcMeshKind::V1(_proc_mesh) => { + anyhow::bail!("kill actor by name unsupported for v1::ProcMesh") } } - Ok(()) } } @@ -855,6 +945,7 @@ pub trait SharedSpawnable { // `Referable`: so we can hand back ActorRef in RootActorMesh async fn spawn( self, + cx: &impl context::Actor, actor_name: &str, params: &A::Params, ) -> Result, anyhow::Error> @@ -868,30 +959,41 @@ impl + Send + Sync + 'static> SharedSpawnable for D // `Referable`: so we can hand back ActorRef in RootActorMesh async fn spawn( self, + cx: &impl context::Actor, actor_name: &str, params: &A::Params, ) -> Result, anyhow::Error> where A::Params: RemoteMessage, { - let (tx, rx) = mpsc::unbounded_channel::(); - { - // Instantiate supervision routing BEFORE spawning the actor mesh. - self.actor_event_router.insert(actor_name.to_string(), tx); - tracing::info!( - name = "router_insert", - actor_name = %actor_name, - "the length of the router is {}", self.actor_event_router.len(), - ); + match &self.deref().inner { + ProcMeshKind::V0 { + actor_event_router, + client, + .. + } => { + let (tx, rx) = mpsc::unbounded_channel::(); + { + // Instantiate supervision routing BEFORE spawning the actor mesh. + actor_event_router.insert(actor_name.to_string(), tx); + tracing::info!( + name = "router_insert", + actor_name = %actor_name, + "the length of the router is {}", actor_event_router.len(), + ); + } + let ranks = + ProcMesh::spawn_on_procs::(client, self.agents(), actor_name, params) + .await?; + Ok(RootActorMesh::new_shared( + self, + actor_name.to_string(), + rx, + ranks, + )) + } + ProcMeshKind::V1(proc_mesh) => todo!(), } - let ranks = - ProcMesh::spawn_on_procs::(&self.client, self.agents(), actor_name, params).await?; - Ok(RootActorMesh::new_shared( - self, - actor_name.to_string(), - rx, - ranks, - )) } } @@ -902,7 +1004,7 @@ impl Mesh for ProcMesh { type Sliced<'a> = SlicedProcMesh<'a>; fn shape(&self) -> &Shape { - &self.shape + ProcMesh::shape(self) } fn select>( @@ -914,11 +1016,17 @@ impl Mesh for ProcMesh { } fn get(&self, rank: usize) -> Option { - Some(self.ranks[rank].1.clone()) + match &self.inner { + ProcMeshKind::V0 { ranks, .. } => Some(ranks[rank].1.clone()), + ProcMeshKind::V1(proc_mesh) => proc_mesh.get(rank).map(|proc| proc.proc_id().clone()), + } } fn id(&self) -> Self::Id { - ProcMeshId(self.world_id().name().to_string()) + match &self.inner { + ProcMeshKind::V0 { world_id, .. } => ProcMeshId(world_id.name().to_string()), + ProcMeshKind::V1(proc_mesh) => ProcMeshId(proc_mesh.name().to_string()), + } } } @@ -930,13 +1038,22 @@ impl fmt::Display for ProcMesh { impl fmt::Debug for ProcMesh { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("ProcMesh") - .field("shape", &self.shape()) - .field("ranks", &self.ranks) - .field("client_proc", &self.client_proc) - .field("client", &"") - // Skip the alloc field since it doesn't implement Debug - .finish() + match &self.inner { + ProcMeshKind::V0 { + shape, + ranks, + client_proc, + .. + } => f + .debug_struct("ProcMesh::V0") + .field("shape", shape) + .field("ranks", ranks) + .field("client_proc", client_proc) + .field("client", &"") + // Skip the alloc field since it doesn't implement Debug + .finish(), + ProcMeshKind::V1(proc_mesh) => fmt::Debug::fmt(proc_mesh, f), + } } } @@ -1056,7 +1173,12 @@ mod tests { let mut mesh = ProcMesh::allocate(alloc).await.unwrap(); let mut events = mesh.events().unwrap(); - let mut actors = mesh.spawn::("failing", &()).await.unwrap(); + let instance = crate::v1::testing::instance().await; + + let mut actors = mesh + .spawn::(&instance, "failing", &()) + .await + .unwrap(); let mut actor_events = actors.events().unwrap(); actors @@ -1106,8 +1228,66 @@ mod tests { .unwrap(); let mesh = ProcMesh::allocate(alloc).await.unwrap(); - mesh.spawn::("dup", &()).await.unwrap(); - let result = mesh.spawn::("dup", &()).await; + let instance = crate::v1::testing::instance().await; + mesh.spawn::(&instance, "dup", &()) + .await + .unwrap(); + let result = mesh.spawn::(&instance, "dup", &()).await; assert!(result.is_err()); } + + mod shim { + use std::collections::HashSet; + + use hyperactor::context::Mailbox; + use ndslice::Extent; + use ndslice::Selection; + + use super::*; + use crate::sel; + + #[tokio::test] + async fn test_basic() { + let instance = v1::testing::instance().await; + let ext = extent!(host = 4); + let host_mesh = v1::testing::host_mesh(ext.clone()).await; + let proc_mesh = host_mesh + .spawn(instance, "test", Extent::unity()) + .await + .unwrap(); + let proc_mesh_v0: ProcMesh = proc_mesh.detach().into(); + + let actor_mesh = proc_mesh_v0 + .spawn::(instance, "test", &()) + .await + .unwrap(); + + let (cast_info, mut cast_info_rx) = instance.mailbox().open_port(); + actor_mesh + .cast( + instance, + sel!(*), + v1::testactor::GetCastInfo { + cast_info: cast_info.bind(), + }, + ) + .unwrap(); + + let mut point_to_actor: HashSet<_> = actor_mesh + .iter_actor_refs() + .enumerate() + .map(|(rank, actor_ref)| (ext.point_of_rank(rank).unwrap(), actor_ref)) + .collect(); + while !point_to_actor.is_empty() { + let (point, origin_actor_ref, sender_actor_id) = cast_info_rx.recv().await.unwrap(); + let key = (point, origin_actor_ref); + assert!( + point_to_actor.remove(&key), + "key {:?} not present or removed twice", + key + ); + assert_eq!(&sender_actor_id, instance.self_id()); + } + } + } } diff --git a/hyperactor_mesh/src/reference.rs b/hyperactor_mesh/src/reference.rs index 222edd38a..7768825fe 100644 --- a/hyperactor_mesh/src/reference.rs +++ b/hyperactor_mesh/src/reference.rs @@ -336,9 +336,11 @@ mod tests { }) .await .unwrap(); + let instance = crate::v1::testing::instance().await; let ping_proc_mesh = ProcMesh::allocate(alloc_ping).await.unwrap(); let ping_mesh: RootActorMesh = ping_proc_mesh .spawn( + &instance, "ping", &MeshPingPongActorParams { mesh_id: ActorMeshId::V0( @@ -356,6 +358,7 @@ mod tests { let pong_proc_mesh = ProcMesh::allocate(alloc_pong).await.unwrap(); let pong_mesh: RootActorMesh = pong_proc_mesh .spawn( + &instance, "pong", &MeshPingPongActorParams { mesh_id: ActorMeshId::V0( diff --git a/hyperactor_mesh/src/v1/actor_mesh.rs b/hyperactor_mesh/src/v1/actor_mesh.rs index e06c68204..0710883a3 100644 --- a/hyperactor_mesh/src/v1/actor_mesh.rs +++ b/hyperactor_mesh/src/v1/actor_mesh.rs @@ -74,6 +74,11 @@ impl ActorMesh { pub fn name(&self) -> &Name { &self.name } + + /// Detach this mesh from the lifetime of `self`, and return its reference. + pub(crate) fn detach(self) -> ActorMeshRef { + self.current_ref + } } impl Deref for ActorMesh { @@ -139,7 +144,7 @@ pub struct ActorMeshRef { _phantom: PhantomData, } -impl ActorMeshRef { +impl ActorMeshRef { /// Cast a message to all actors in this mesh. pub fn cast(&self, cx: &impl context::Actor, message: M) -> v1::Result<()> where @@ -186,7 +191,7 @@ impl ActorMeshRef { root_comm_actor: &ActorRef, ) -> v1::Result<()> where - A: RemoteHandles + RemoteHandles>, + A: RemoteHandles>, M: Castable + RemoteMessage + Clone, // Clone is required until we are fully onto comm actor { let cast_mesh_shape = view::Ranked::region(self).into(); @@ -231,6 +236,10 @@ impl ActorMeshRef { Self::with_page_size(name, proc_mesh, DEFAULT_PAGE) } + pub(crate) fn name(&self) -> &Name { + &self.name + } + pub(crate) fn with_page_size(name: Name, proc_mesh: ProcMeshRef, page_size: usize) -> Self { Self { proc_mesh, @@ -241,6 +250,10 @@ impl ActorMeshRef { } } + pub(crate) fn proc_mesh(&self) -> &ProcMeshRef { + &self.proc_mesh + } + #[inline] fn len(&self) -> usize { view::Ranked::region(&self.proc_mesh).num_ranks() diff --git a/hyperactor_mesh/src/v1/proc_mesh.rs b/hyperactor_mesh/src/v1/proc_mesh.rs index 333202e58..b27f97e7a 100644 --- a/hyperactor_mesh/src/v1/proc_mesh.rs +++ b/hyperactor_mesh/src/v1/proc_mesh.rs @@ -156,6 +156,10 @@ impl ProcRef { } } + pub(crate) fn proc_id(&self) -> &ProcId { + &self.proc_id + } + pub(crate) fn actor_id(&self, name: &Name) -> ActorId { self.proc_id.actor_id(name.to_string(), 0) } @@ -168,7 +172,6 @@ impl ProcRef { } /// A mesh of processes. -#[allow(dead_code)] #[derive(Debug)] pub struct ProcMesh { name: Name, @@ -372,6 +375,11 @@ impl ProcMesh { ) .await } + + /// Detach the proc mesh from the lifetime of `self`, and return its reference. + pub(crate) fn detach(self) -> ProcMeshRef { + self.current_ref + } } impl Deref for ProcMesh { @@ -493,6 +501,10 @@ impl ProcMeshRef { self.root_comm_actor.as_ref() } + pub(crate) fn name(&self) -> &Name { + &self.name + } + /// The current statuses of procs in this mesh. pub async fn status(&self, cx: &impl context::Actor) -> v1::Result> { let vm: ValueMesh<_> = self.map_into(|proc_ref| { @@ -502,7 +514,7 @@ impl ProcMeshRef { vm.join().await.transpose() } - fn agent_mesh(&self) -> ActorMeshRef { + pub(crate) fn agent_mesh(&self) -> ActorMeshRef { let agent_name = self.ranks.first().unwrap().agent.actor_id().name(); // This name must match the ProcMeshAgent name, which can change depending on the allocator. ActorMeshRef::new(Name::new_reserved(agent_name), self.clone()) diff --git a/hyperactor_mesh/test/hyperactor_mesh_proxy_test.rs b/hyperactor_mesh/test/hyperactor_mesh_proxy_test.rs index f72a90cc2..4d9a78a3d 100644 --- a/hyperactor_mesh/test/hyperactor_mesh_proxy_test.rs +++ b/hyperactor_mesh/test/hyperactor_mesh_proxy_test.rs @@ -17,8 +17,10 @@ use clap::Parser; use hyperactor::Actor; use hyperactor::Context; use hyperactor::Handler; +use hyperactor::Instance; use hyperactor::Named; use hyperactor::PortRef; +use hyperactor::Proc; use hyperactor::channel::ChannelTransport; use hyperactor_mesh::Mesh; use hyperactor_mesh::ProcMesh; @@ -91,9 +93,8 @@ impl Handler for TestActor { ], )] pub struct ProxyActor { - #[allow(dead_code)] - proc_mesh: Arc, - actor_mesh: RootActorMesh<'static, TestActor>, + proc_mesh: &'static Arc, + actor_mesh: Option>, } impl fmt::Debug for ProxyActor { @@ -126,19 +127,22 @@ impl Actor for ProxyActor { .unwrap(); let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap()); let leaked: &'static Arc = Box::leak(Box::new(proc_mesh)); - let actor_mesh: RootActorMesh<'static, TestActor> = - leaked.spawn("echo", &()).await.unwrap(); Ok(Self { - proc_mesh: Arc::clone(leaked), - actor_mesh, + proc_mesh: leaked, + actor_mesh: None, }) } + + async fn init(&mut self, this: &Instance) -> Result<(), anyhow::Error> { + self.actor_mesh = Some(self.proc_mesh.spawn(this, "echo", &()).await.unwrap()); + Ok(()) + } } #[async_trait] impl Handler for ProxyActor { async fn handle(&mut self, cx: &Context, message: Echo) -> Result<(), anyhow::Error> { - let actor = self.actor_mesh.get(0).unwrap(); + let actor = self.actor_mesh.as_ref().unwrap().get(0).unwrap(); let (tx, mut rx) = cx.open_port(); actor.send(cx, Echo(message.0, tx.bind()))?; @@ -163,9 +167,11 @@ async fn run_client(exe_path: PathBuf, keep_alive: bool) -> Result<(), anyhow::E .await .unwrap(); + let (instance, _) = Proc::local().instance("client").unwrap(); + let mut proc_mesh = ProcMesh::allocate(alloc).await?; let actor_mesh: RootActorMesh<'_, ProxyActor> = proc_mesh - .spawn("proxy", &exe_path.to_str().unwrap().to_string()) + .spawn(&instance, "proxy", &exe_path.to_str().unwrap().to_string()) .await?; let proxy_actor = actor_mesh.get(0).unwrap(); let (tx, mut rx) = actor_mesh.open_port::(); diff --git a/monarch_extension/src/code_sync.rs b/monarch_extension/src/code_sync.rs index c1d52f734..11a556d3d 100644 --- a/monarch_extension/src/code_sync.rs +++ b/monarch_extension/src/code_sync.rs @@ -23,6 +23,8 @@ use monarch_hyperactor::code_sync::manager::CodeSyncMethod; use monarch_hyperactor::code_sync::manager::WorkspaceConfig; use monarch_hyperactor::code_sync::manager::WorkspaceShape; use monarch_hyperactor::code_sync::manager::code_sync_mesh; +use monarch_hyperactor::context::PyInstance; +use monarch_hyperactor::instance_dispatch; use monarch_hyperactor::proc_mesh::PyProcMesh; use monarch_hyperactor::runtime::signal_safe_block_on; use pyo3::Bound; @@ -261,13 +263,15 @@ impl CodeSyncMeshClient { #[pymethods] impl CodeSyncMeshClient { #[staticmethod] - #[pyo3(signature = (*, proc_mesh))] - fn spawn_blocking(py: Python, proc_mesh: &PyProcMesh) -> PyResult { + #[pyo3(signature = (*, client, proc_mesh))] + fn spawn_blocking(py: Python, client: PyInstance, proc_mesh: &PyProcMesh) -> PyResult { let proc_mesh = proc_mesh.try_inner()?; signal_safe_block_on(py, async move { - let actor_mesh = proc_mesh - .spawn("code_sync_manager", &CodeSyncManagerParams {}) - .await?; + let actor_mesh = instance_dispatch!(client, |cx| { + proc_mesh + .spawn(cx, "code_sync_manager", &CodeSyncManagerParams {}) + .await? + }); Ok(Self { actor_mesh }) })? } diff --git a/monarch_extension/src/logging.rs b/monarch_extension/src/logging.rs index 758658874..11c5ab587 100644 --- a/monarch_extension/src/logging.rs +++ b/monarch_extension/src/logging.rs @@ -22,6 +22,7 @@ use hyperactor_mesh::logging::LogForwardMessage; use hyperactor_mesh::selection::Selection; use hyperactor_mesh::shared_cell::SharedCell; use monarch_hyperactor::context::PyInstance; +use monarch_hyperactor::instance_dispatch; use monarch_hyperactor::logging::LoggerRuntimeActor; use monarch_hyperactor::logging::LoggerRuntimeMessage; use monarch_hyperactor::proc_mesh::PyProcMesh; @@ -86,13 +87,18 @@ impl LoggingMeshClient { #[pymethods] impl LoggingMeshClient { #[staticmethod] - fn spawn(_instance: &PyInstance, proc_mesh: &PyProcMesh) -> PyResult { + fn spawn(instance: PyInstance, proc_mesh: &PyProcMesh) -> PyResult { let proc_mesh = proc_mesh.try_inner()?; PyPythonTask::new(async move { let client_actor = proc_mesh.client_proc().spawn("log_client", ()).await?; let client_actor_ref = client_actor.bind(); - let forwarder_mesh = proc_mesh.spawn("log_forwarder", &client_actor_ref).await?; - let logger_mesh = proc_mesh.spawn("logger", &()).await?; + let forwarder_mesh = instance_dispatch!(instance, |cx| { + proc_mesh + .spawn(cx, "log_forwarder", &client_actor_ref) + .await? + }); + let logger_mesh = + instance_dispatch!(instance, |cx| { proc_mesh.spawn(cx, "logger", &()).await? }); // Register flush_internal as a on-stop callback let client_actor_for_callback = client_actor.clone(); diff --git a/monarch_extension/src/mesh_controller.rs b/monarch_extension/src/mesh_controller.rs index 7ebdab54d..214219f88 100644 --- a/monarch_extension/src/mesh_controller.rs +++ b/monarch_extension/src/mesh_controller.rs @@ -760,7 +760,7 @@ impl Actor for MeshControllerActor { }; let workers = proc_mesh - .spawn(&format!("tensor_engine_workers_{}", self.id), ¶m) + .spawn(this, &format!("tensor_engine_workers_{}", self.id), ¶m) .await?; workers.borrow().unwrap().cast( this, @@ -770,7 +770,7 @@ impl Actor for MeshControllerActor { self.workers = Some(workers); let brokers = proc_mesh - .spawn(&format!("tensor_engine_brokers_{}", self.id), &()) + .spawn(this, &format!("tensor_engine_brokers_{}", self.id), &()) .await?; self.brokers = Some(brokers); Ok(()) diff --git a/monarch_hyperactor/src/code_sync/manager.rs b/monarch_hyperactor/src/code_sync/manager.rs index c15467ae4..f16b6034a 100644 --- a/monarch_hyperactor/src/code_sync/manager.rs +++ b/monarch_hyperactor/src/code_sync/manager.rs @@ -484,6 +484,7 @@ mod tests { use hyperactor_mesh::alloc::Allocator; use hyperactor_mesh::alloc::local::LocalAllocator; use hyperactor_mesh::proc_mesh::ProcMesh; + use hyperactor_mesh::proc_mesh::global_root_client; use ndslice::extent; use ndslice::shape; use tempfile::TempDir; @@ -585,9 +586,13 @@ mod tests { // Create CodeSyncManagerParams let params = CodeSyncManagerParams {}; + // TODO: thread through context, or access the actual python context; + // for now this is basically equivalent (arguably better) to using the proc mesh client. + let instance = global_root_client(); + // Spawn actor mesh with CodeSyncManager actors let actor_mesh = proc_mesh - .spawn::("code_sync_test", ¶ms) + .spawn::(&instance, "code_sync_test", ¶ms) .await?; // Create workspace configuration diff --git a/monarch_hyperactor/src/code_sync/rsync.rs b/monarch_hyperactor/src/code_sync/rsync.rs index 9cb78a4eb..d41c8c668 100644 --- a/monarch_hyperactor/src/code_sync/rsync.rs +++ b/monarch_hyperactor/src/code_sync/rsync.rs @@ -459,6 +459,7 @@ mod tests { use hyperactor_mesh::alloc::Allocator; use hyperactor_mesh::alloc::local::LocalAllocator; use hyperactor_mesh::proc_mesh::ProcMesh; + use hyperactor_mesh::proc_mesh::global_root_client; use ndslice::extent; use tempfile::TempDir; use tokio::fs; @@ -512,8 +513,14 @@ mod tests { // Create RsyncParams - all actors will use the same target workspace for this test let params = RsyncParams {}; + // TODO: thread through context, or access the actual python context; + // for now this is basically equivalent (arguably better) to using the proc mesh client. + let instance = global_root_client(); + // Spawn actor mesh with RsyncActors - let actor_mesh = proc_mesh.spawn::("rsync_test", ¶ms).await?; + let actor_mesh = proc_mesh + .spawn::(&instance, "rsync_test", ¶ms) + .await?; // Test rsync_mesh function - this coordinates rsync operations across the mesh let results = rsync_mesh( diff --git a/monarch_hyperactor/src/proc_mesh.rs b/monarch_hyperactor/src/proc_mesh.rs index 3b80da45c..c1ce40b99 100644 --- a/monarch_hyperactor/src/proc_mesh.rs +++ b/monarch_hyperactor/src/proc_mesh.rs @@ -15,6 +15,7 @@ use hyperactor::Actor; use hyperactor::RemoteMessage; use hyperactor::WorldId; use hyperactor::actor::Referable; +use hyperactor::context; use hyperactor::context::Mailbox as _; use hyperactor::proc::Instance; use hyperactor::proc::Proc; @@ -24,6 +25,7 @@ use hyperactor_mesh::proc_mesh::ProcEvent; use hyperactor_mesh::proc_mesh::ProcEvents; use hyperactor_mesh::proc_mesh::ProcMesh; use hyperactor_mesh::proc_mesh::SharedSpawnable; +use hyperactor_mesh::proc_mesh::global_root_client; use hyperactor_mesh::shared_cell::SharedCell; use hyperactor_mesh::shared_cell::SharedCellPool; use hyperactor_mesh::shared_cell::SharedCellRef; @@ -89,6 +91,7 @@ impl From for TrackedProcMesh { impl TrackedProcMesh { pub async fn spawn( &self, + cx: &impl context::Actor, actor_name: &str, params: &A::Params, ) -> Result>, anyhow::Error> @@ -96,7 +99,7 @@ impl TrackedProcMesh { A::Params: RemoteMessage, { let mesh = self.cell.borrow()?; - let actor = mesh.spawn(actor_name, params).await?; + let actor = mesh.spawn(cx, actor_name, params).await?; Ok(self.children.insert(actor)) } @@ -313,13 +316,15 @@ impl PyProcMesh { actor: &Bound<'py, PyType>, ) -> PyResult { let unhealthy_event = Arc::clone(&self.unhealthy_event); - let pickled_type = PickledPyObject::pickle(actor.as_any())?; + let pickled_type: PickledPyObject = PickledPyObject::pickle(actor.as_any())?; let proc_mesh = self.try_inner()?; let keepalive = self.keepalive.clone(); let meshimpl = async move { ensure_mesh_healthy(&unhealthy_event).await?; - let instance = proc_mesh.client(); - let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?; + // TODO: thread through context, or access the actual python context; + // for now this is basically equivalent (arguably better) to using the proc mesh client. + let instance = global_root_client(); + let actor_mesh = proc_mesh.spawn(&instance, &name, &pickled_type).await?; let actor_events = actor_mesh.with_mut(|a| a.events()).await.unwrap().unwrap(); let im = PythonActorMeshImpl::new( actor_mesh, @@ -354,8 +359,11 @@ impl PyProcMesh { Ok((proc_mesh, pickled_type, unhealthy_event, keepalive)) })?; ensure_mesh_healthy(&unhealthy_event).await?; - let instance = proc_mesh.client(); - let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?; + // TODO: thread through context, or access the actual python context; + // for now this is basically equivalent (arguably better) to using the proc mesh client. + let instance = global_root_client(); + + let actor_mesh = proc_mesh.spawn(&instance, &name, &pickled_type).await?; let actor_events = actor_mesh.with_mut(|a| a.events()).await.unwrap().unwrap(); Ok::<_, PyErr>(Box::new(PythonActorMeshImpl::new( actor_mesh, diff --git a/monarch_hyperactor/tests/code_sync/auto_reload.rs b/monarch_hyperactor/tests/code_sync/auto_reload.rs index a958cbb0f..689cb45b6 100644 --- a/monarch_hyperactor/tests/code_sync/auto_reload.rs +++ b/monarch_hyperactor/tests/code_sync/auto_reload.rs @@ -8,6 +8,7 @@ use anyhow::Result; use anyhow::anyhow; +use hyperactor::Proc; use hyperactor::channel::ChannelTransport; use hyperactor_mesh::actor_mesh::ActorMesh; use hyperactor_mesh::alloc::AllocSpec; @@ -53,10 +54,12 @@ CONSTANT = "initial_constant" }) .await?; + let (instance, _) = Proc::local().instance("client").unwrap(); + let proc_mesh = ProcMesh::allocate(alloc).await?; let params = AutoReloadParams {}; let actor_mesh = proc_mesh - .spawn::("auto_reload_test", ¶ms) + .spawn::(&instance, "auto_reload_test", ¶ms) .await?; // Get a reference to the single actor diff --git a/monarch_rdma/examples/cuda_ping_pong/src/cuda_ping_pong.rs b/monarch_rdma/examples/cuda_ping_pong/src/cuda_ping_pong.rs index ef0b7c753..2a65d02a1 100644 --- a/monarch_rdma/examples/cuda_ping_pong/src/cuda_ping_pong.rs +++ b/monarch_rdma/examples/cuda_ping_pong/src/cuda_ping_pong.rs @@ -66,6 +66,7 @@ use hyperactor::Handler; use hyperactor::Instance; use hyperactor::Named; use hyperactor::OncePortRef; +use hyperactor::Proc; use hyperactor::Unbind; use hyperactor::channel::ChannelTransport; use hyperactor::supervision::ActorSupervisionEvent; @@ -681,6 +682,8 @@ pub async fn run() -> Result<(), anyhow::Error> { device_2_ibv_config = IbverbsConfig::default(); } + let (instance, _) = Proc::local().instance("test").unwrap(); + // Create process allocator for spawning actors let mut alloc = ProcessAllocator::new(Command::new( buck_resources::get("monarch/monarch_rdma/examples/cuda_ping_pong/bootstrap").unwrap(), @@ -713,12 +716,20 @@ pub async fn run() -> Result<(), anyhow::Error> { // Create RDMA manager for the first device let device_1_rdma_manager: RootActorMesh<'_, RdmaManagerActor> = device_1_proc_mesh - .spawn("device_1_rdma_manager", &Some(device_1_ibv_config)) + .spawn( + &instance, + "device_1_rdma_manager", + &Some(device_1_ibv_config), + ) .await?; // Create RDMA manager for the second device let device_2_rdma_manager: RootActorMesh<'_, RdmaManagerActor> = device_2_proc_mesh - .spawn("device_2_rdma_manager", &Some(device_2_ibv_config)) + .spawn( + &instance, + "device_2_rdma_manager", + &Some(device_2_ibv_config), + ) .await?; // Get the RDMA manager actor references @@ -728,6 +739,7 @@ pub async fn run() -> Result<(), anyhow::Error> { // Create the CUDA RDMA actors let device_1_actor_mesh: RootActorMesh<'_, CudaRdmaActor> = device_1_proc_mesh .spawn( + &instance, "device_1_actor", &(device_1_rdma_manager_ref.clone(), 0, config.buffer_size), ) @@ -735,6 +747,7 @@ pub async fn run() -> Result<(), anyhow::Error> { let device_2_actor_mesh: RootActorMesh<'_, CudaRdmaActor> = device_2_proc_mesh .spawn( + &instance, "device_2_actor", &(device_2_rdma_manager_ref.clone(), 1, config.buffer_size), ) diff --git a/monarch_rdma/examples/parameter_server/src/parameter_server.rs b/monarch_rdma/examples/parameter_server/src/parameter_server.rs index 39ff992f0..659e32fc3 100644 --- a/monarch_rdma/examples/parameter_server/src/parameter_server.rs +++ b/monarch_rdma/examples/parameter_server/src/parameter_server.rs @@ -65,6 +65,7 @@ use hyperactor::Instance; use hyperactor::Named; use hyperactor::OncePortRef; use hyperactor::PortRef; +use hyperactor::Proc; use hyperactor::Unbind; use hyperactor::channel::ChannelTransport; use hyperactor::context::Mailbox as _; @@ -467,6 +468,8 @@ pub async fn run(num_workers: usize, num_steps: usize) -> Result<(), anyhow::Err // As normal, create a proc mesh for the parameter server. tracing::info!("creating parameter server proc mesh..."); + let (instance, _) = Proc::local().instance("client").unwrap(); + let mut alloc = ProcessAllocator::new(Command::new( buck_resources::get("monarch/monarch_rdma/examples/parameter_server/bootstrap").unwrap(), )); @@ -493,7 +496,7 @@ pub async fn run(num_workers: usize, num_steps: usize) -> Result<(), anyhow::Err // We spin this up manually here, but in Python-land we assume this will // be spun up with the PyProcMesh. let ps_rdma_manager: RootActorMesh<'_, RdmaManagerActor> = ps_proc_mesh - .spawn("ps_rdma_manager", &Some(ps_ibv_config)) + .spawn(&instance, "ps_rdma_manager", &Some(ps_ibv_config)) .await .unwrap(); @@ -517,13 +520,14 @@ pub async fn run(num_workers: usize, num_steps: usize) -> Result<(), anyhow::Err ); // Similarly, create an RdmaManagerActor corresponding to each worker. let worker_rdma_manager_mesh: RootActorMesh<'_, RdmaManagerActor> = worker_proc_mesh - .spawn("ps_rdma_manager", &Some(worker_ibv_config)) + .spawn(&instance, "ps_rdma_manager", &Some(worker_ibv_config)) .await .unwrap(); tracing::info!("spawning parameter server"); let ps_actor_mesh: RootActorMesh<'_, ParameterServerActor> = ps_proc_mesh .spawn( + &instance, "parameter_server", &(ps_rdma_manager.iter().next().unwrap(), num_workers), ) @@ -534,8 +538,10 @@ pub async fn run(num_workers: usize, num_steps: usize) -> Result<(), anyhow::Err let ps_actor = ps_actor_mesh.iter().next().unwrap(); tracing::info!("spawning worker actors"); - let worker_actor_mesh: RootActorMesh<'_, WorkerActor> = - worker_proc_mesh.spawn("worker_actors", &()).await.unwrap(); + let worker_actor_mesh: RootActorMesh<'_, WorkerActor> = worker_proc_mesh + .spawn(&instance, "worker_actors", &()) + .await + .unwrap(); let worker_rdma_managers: Vec> = worker_rdma_manager_mesh.iter().collect(); diff --git a/monarch_rdma/extension/lib.rs b/monarch_rdma/extension/lib.rs index f2d330f81..f2c1f747c 100644 --- a/monarch_rdma/extension/lib.rs +++ b/monarch_rdma/extension/lib.rs @@ -287,9 +287,10 @@ impl PyRdmaManager { /// Creates an RDMA manager actor on the given ProcMesh (async version). /// Returns the actor mesh if RDMA is supported, None otherwise. #[classmethod] - fn create_rdma_manager_nonblocking<'py>( + fn create_rdma_manager_nonblocking( _cls: &Bound<'_, PyType>, proc_mesh: &PyProcMesh, + client: PyInstance, ) -> PyResult { tracing::debug!("spawning RDMA manager on target proc_mesh nodes"); @@ -298,12 +299,14 @@ impl PyRdmaManager { PyPythonTask::new(async move { // Spawns the `RdmaManagerActor` on the target proc_mesh. // This allows the `RdmaController` to run on any node while real RDMA operations occur on appropriate hardware. - let actor_mesh = tracked_proc_mesh - // Pass None to use default config - RdmaManagerActor will use default IbverbsConfig - // TODO - make IbverbsConfig configurable - .spawn::("rdma_manager", &None) - .await - .map_err(|err| PyException::new_err(err.to_string()))?; + let actor_mesh = instance_dispatch!(client, |cx| { + tracked_proc_mesh + // Pass None to use default config - RdmaManagerActor will use default IbverbsConfig + // TODO - make IbverbsConfig configurable + .spawn::(cx, "rdma_manager", &None) + .await + .map_err(|err| PyException::new_err(err.to_string()))? + }); // Use placeholder device name since actual device is determined on remote node Ok(Some(PyRdmaManager { diff --git a/monarch_rdma/src/test_utils.rs b/monarch_rdma/src/test_utils.rs index e09348205..630bd9605 100644 --- a/monarch_rdma/src/test_utils.rs +++ b/monarch_rdma/src/test_utils.rs @@ -79,6 +79,7 @@ pub mod test_utils { use hyperactor::ActorRef; use hyperactor::Instance; + use hyperactor::Proc; use hyperactor::channel::ChannelTransport; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; @@ -326,9 +327,11 @@ pub mod test_utils { .await .unwrap(); + let (instance, _) = Proc::local().instance("test").unwrap(); + let proc_mesh_1 = Box::leak(Box::new(ProcMesh::allocate(alloc_1).await.unwrap())); let actor_mesh_1: RootActorMesh<'_, RdmaManagerActor> = proc_mesh_1 - .spawn("rdma_manager", &Some(config1)) + .spawn(&instance, "rdma_manager", &Some(config1)) .await .unwrap(); @@ -344,7 +347,7 @@ pub mod test_utils { let proc_mesh_2 = Box::leak(Box::new(ProcMesh::allocate(alloc_2).await.unwrap())); let actor_mesh_2: RootActorMesh<'_, RdmaManagerActor> = proc_mesh_2 - .spawn("rdma_manager", &Some(config2)) + .spawn(&instance, "rdma_manager", &Some(config2)) .await .unwrap(); @@ -409,8 +412,8 @@ pub mod test_utils { 0, 0, )); - assert!(dptr as usize % granularity == 0); - assert!(padded_size % granularity == 0); + assert!((dptr as usize).is_multiple_of(granularity)); + assert!(padded_size.is_multiple_of(granularity)); // fails if a add cu_check macro; but passes if we don't let err = cuda_sys::cuMemMap( diff --git a/ndslice/src/shape.rs b/ndslice/src/shape.rs index a3381b5a8..693a93699 100644 --- a/ndslice/src/shape.rs +++ b/ndslice/src/shape.rs @@ -214,6 +214,14 @@ impl Shape { } } +impl From for Shape { + fn from(region: Region) -> Self { + let (labels, slice) = region.into_inner(); + Shape::new(labels, slice) + .expect("Shape::new should not fail because a Region by definition is a valid Shape") + } +} + impl From<&Region> for Shape { fn from(region: &Region) -> Self { Shape::new(region.labels().to_vec(), region.slice().clone()) diff --git a/python/monarch/_rust_bindings/monarch_extension/code_sync.pyi b/python/monarch/_rust_bindings/monarch_extension/code_sync.pyi index 255067fc4..4b91e3492 100644 --- a/python/monarch/_rust_bindings/monarch_extension/code_sync.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/code_sync.pyi @@ -7,7 +7,7 @@ # pyre-unsafe from pathlib import Path -from typing import Dict, final +from typing import Any, Dict, final from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh @@ -83,6 +83,7 @@ class CodeSyncMeshClient: """ @staticmethod def spawn_blocking( + client: Any, proc_mesh: ProcMesh, ) -> CodeSyncMeshClient: ... async def sync_workspace( diff --git a/python/monarch/_rust_bindings/rdma.pyi b/python/monarch/_rust_bindings/rdma.pyi index f7e6402af..0f1f4b57f 100644 --- a/python/monarch/_rust_bindings/rdma.pyi +++ b/python/monarch/_rust_bindings/rdma.pyi @@ -21,6 +21,7 @@ class _RdmaManager: def create_rdma_manager_nonblocking( self, proc_mesh: Any, + client: Any, ) -> PythonTask[_RdmaManager | None]: ... @final diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py index 3ec5941a1..277f2e57f 100644 --- a/python/monarch/_src/actor/proc_mesh.py +++ b/python/monarch/_src/actor/proc_mesh.py @@ -466,6 +466,7 @@ async def sync_workspace( """ if self._code_sync_client is None: self._code_sync_client = CodeSyncMeshClient.spawn_blocking( + client=context().actor_instance, proc_mesh=await self._proc_mesh_for_asyncio_fixme, ) diff --git a/python/monarch/_src/rdma/rdma.py b/python/monarch/_src/rdma/rdma.py index bc0577fbc..15c797e58 100644 --- a/python/monarch/_src/rdma/rdma.py +++ b/python/monarch/_src/rdma/rdma.py @@ -132,7 +132,9 @@ async def init_rdma_on_mesh(self, proc_mesh: ProcMesh) -> None: async def create_manager() -> _RdmaManager: proc_mesh_result = await Future(coro=proc_mesh._proc_mesh.task()) return none_throws( - await _RdmaManager.create_rdma_manager_nonblocking(proc_mesh_result) + await _RdmaManager.create_rdma_manager_nonblocking( + proc_mesh_result, context().actor_instance + ) ) self._manager_futures[proc_mesh] = Future(coro=create_manager()) diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index 600948a90..a98c37eb9 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -66,7 +66,7 @@ from monarch.actor import ProcMesh from monarch._rust_bindings.monarch_hyperactor.shape import Point -from monarch._src.actor.actor_mesh import context, Instance +from monarch._src.actor.actor_mesh import context, Context, Instance from monarch._src.actor.device_utils import _local_device_count from monarch.common.client import Client diff --git a/python/tests/test_rdma_unsupported.py b/python/tests/test_rdma_unsupported.py index cd6092bfb..236b3337f 100644 --- a/python/tests/test_rdma_unsupported.py +++ b/python/tests/test_rdma_unsupported.py @@ -35,6 +35,7 @@ async def test_rdma_manager_creation_fails_when_unsupported(): ibverbs_supported() function that calls ibv_get_device_list() in the C library. """ from monarch._rust_bindings.rdma import _RdmaManager + from monarch._src.actor.actor_mesh import context from monarch._src.actor.future import Future from monarch.actor import this_host @@ -43,7 +44,8 @@ async def test_rdma_manager_creation_fails_when_unsupported(): with pytest.raises(Exception) as exc_info: await Future( coro=_RdmaManager.create_rdma_manager_nonblocking( - await Future(coro=proc_mesh._proc_mesh.task()) + await Future(coro=proc_mesh._proc_mesh.task()), + context().actor_instance, ) )