diff --git a/hyperactor/src/mailbox.rs b/hyperactor/src/mailbox.rs index c2756401b..45cbcb86c 100644 --- a/hyperactor/src/mailbox.rs +++ b/hyperactor/src/mailbox.rs @@ -164,7 +164,8 @@ pub type Data = Vec; Deserialize, Named, Clone, - PartialEq + PartialEq, + Eq )] pub enum DeliveryError { /// The destination address is not reachable. diff --git a/hyperactor/src/mailbox/undeliverable.rs b/hyperactor/src/mailbox/undeliverable.rs index cff25d173..9bd46ffcc 100644 --- a/hyperactor/src/mailbox/undeliverable.rs +++ b/hyperactor/src/mailbox/undeliverable.rs @@ -32,6 +32,13 @@ use crate::supervision::ActorSupervisionEvent; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Named)] pub struct Undeliverable(pub M); +impl Undeliverable { + /// Return the inner M-typed message. + pub fn into_inner(self) -> M { + self.0 + } +} + // Port handle and receiver for undeliverable messages. pub(crate) fn new_undeliverable_port() -> ( PortHandle>, diff --git a/hyperactor_mesh/Cargo.toml b/hyperactor_mesh/Cargo.toml index 775fe7636..ceef42c18 100644 --- a/hyperactor_mesh/Cargo.toml +++ b/hyperactor_mesh/Cargo.toml @@ -75,6 +75,7 @@ serde_bytes = "0.11" serde_json = { version = "1.0.140", features = ["alloc", "float_roundtrip", "raw_value", "unbounded_depth"] } serde_multipart = { version = "0.0.0", path = "../serde_multipart" } strum = { version = "0.27.1", features = ["derive"] } +tempfile = "3.22" thiserror = "2.0.12" tokio = { version = "1.47.1", features = ["full", "test-util", "tracing"] } tokio-stream = { version = "0.1.17", features = ["fs", "io-util", "net", "signal", "sync", "time"] } diff --git a/hyperactor_mesh/src/bootstrap.rs b/hyperactor_mesh/src/bootstrap.rs index c6231d1a9..1892e35fe 100644 --- a/hyperactor_mesh/src/bootstrap.rs +++ b/hyperactor_mesh/src/bootstrap.rs @@ -15,6 +15,7 @@ use std::future; use std::io; use std::io::Write; use std::os::unix::process::ExitStatusExt; +use std::path::Path; use std::path::PathBuf; use std::process::Stdio; use 
std::sync::Arc; @@ -34,6 +35,7 @@ use hyperactor::ProcId; use hyperactor::attrs::Attrs; use hyperactor::channel; use hyperactor::channel::ChannelAddr; +use hyperactor::channel::ChannelError; use hyperactor::channel::ChannelTransport; use hyperactor::channel::Rx; use hyperactor::channel::Tx; @@ -48,10 +50,13 @@ use hyperactor::host::HostError; use hyperactor::host::ProcHandle; use hyperactor::host::ProcManager; use hyperactor::host::TerminateSummary; +use hyperactor::mailbox::IntoBoxedMailboxSender; +use hyperactor::mailbox::MailboxClient; use hyperactor::mailbox::MailboxServer; use hyperactor::proc::Proc; use serde::Deserialize; use serde::Serialize; +use tempfile::TempDir; use tokio::process::Child; use tokio::process::Command; use tokio::sync::oneshot; @@ -64,6 +69,8 @@ use crate::v1; use crate::v1::host_mesh::mesh_agent::HostAgentMode; use crate::v1::host_mesh::mesh_agent::HostMeshAgent; +mod mailbox; + declare_attrs! { /// If enabled (default), bootstrap child processes install /// `PR_SET_PDEATHSIG(SIGKILL)` so the kernel reaps them if the @@ -212,6 +219,10 @@ pub enum Bootstrap { backend_addr: ChannelAddr, /// The callback address used to indicate successful spawning. callback_addr: ChannelAddr, + /// Directory for storing proc socket files. Procs place their sockets + /// in this directory, so that they can be looked up by other procs + /// for direct transfer. + socket_dir_path: PathBuf, /// Optional config snapshot (`hyperactor::config::Attrs`) /// captured by the parent. 
If present, the child installs it /// as the `Runtime` layer so the parent's effective config @@ -324,6 +335,7 @@ impl Bootstrap { proc_id, backend_addr, callback_addr, + socket_dir_path, config, } => { if let Some(attrs) = config { @@ -343,15 +355,39 @@ impl Bootstrap { eprintln!("(bootstrap) PDEATHSIG disabled via config"); } - let result = - host::spawn_proc(proc_id, backend_addr, callback_addr, |proc| async move { - ProcMeshAgent::boot_v1(proc).await - }) - .await; - match result { - Ok(_proc) => halt().await, - Err(e) => e.into(), - } + let (local_addr, name) = ok!(proc_id + .as_direct() + .ok_or_else(|| anyhow::anyhow!("invalid proc id type: {}", proc_id))); + // TODO provide a direct way to construct these + let serve_addr = format!("unix:{}", socket_dir_path.join(name).display()); + let serve_addr = serve_addr.parse().unwrap(); + + // The following is a modified host::spawn_proc to support direct + // dialing between local procs: 1) we bind each proc to a deterministic + // address in socket_dir_path; 2) we use LocalProcDialer to dial these + // addresses for local procs. + let proc_sender = mailbox::LocalProcDialer::new( + local_addr.clone(), + socket_dir_path, + ok!(MailboxClient::dial(backend_addr)), + ); + + let proc = Proc::new(proc_id.clone(), proc_sender.into_boxed()); + + let agent_handle = ok!(ProcMeshAgent::boot_v1(proc.clone()) + .await + .map_err(|e| HostError::AgentSpawnFailure(proc_id, e))); + + // Finally serve the proc on its deterministic unix socket address + // (serve_addr, inside socket_dir_path), and call back. 
+ let (proc_addr, proc_rx) = ok!(channel::serve(serve_addr)); + proc.clone().serve(proc_rx); + ok!(ok!(channel::dial(callback_addr)) + .send((proc_addr, agent_handle.bind::())) + .await + .map_err(ChannelError::from)); + + halt().await } Bootstrap::Host { addr, @@ -369,7 +405,7 @@ impl Bootstrap { Some(command) => command, None => ok!(BootstrapCommand::current()), }; - let manager = BootstrapProcManager::new(command); + let manager = ok!(BootstrapProcManager::new(command)); let (host, _handle) = ok!(Host::serve(manager, addr).await); let addr = host.addr().clone(); let host_mesh_agent = ok!(host @@ -1402,6 +1438,11 @@ pub struct BootstrapProcManager { /// exclusively in the [`Drop`] impl to send `SIGKILL` without /// needing async context. pid_table: Arc>>, + + /// Directory for storing proc socket files. Procs place their sockets + /// in this directory, so that they can be looked up by other procs + /// for direct transfer. + socket_dir: TempDir, } impl Drop for BootstrapProcManager { @@ -1451,12 +1492,13 @@ impl BootstrapProcManager { /// This is the general entry point when you want to manage procs /// backed by a specific binary path (e.g. a bootstrap /// trampoline). - pub(crate) fn new(command: BootstrapCommand) -> Self { - Self { + pub(crate) fn new(command: BootstrapCommand) -> Result { + Ok(Self { command, children: Arc::new(tokio::sync::Mutex::new(HashMap::new())), pid_table: Arc::new(std::sync::Mutex::new(HashMap::new())), - } + socket_dir: tempfile::tempdir()?, + }) } /// The bootstrap command used to launch processes. 
@@ -1628,6 +1670,7 @@ impl ProcManager for BootstrapProcManager { proc_id: proc_id.clone(), backend_addr, callback_addr, + socket_dir_path: self.socket_dir.path().to_owned(), config: Some(cfg), }; let mut cmd = Command::new(&self.command.program); @@ -2062,6 +2105,7 @@ mod tests { proc_id: id!(foo[0]), backend_addr: ChannelAddr::any(ChannelTransport::Tcp), callback_addr: ChannelAddr::any(ChannelTransport::Unix), + socket_dir_path: PathBuf::from("notexist"), config: None, }, ]; @@ -2119,6 +2163,8 @@ mod tests { attrs[MESH_TAIL_LOG_LINES] = 123; attrs[MESH_BOOTSTRAP_ENABLE_PDEATHSIG] = false; + let socket_dir = tempfile::tempdir().unwrap(); + // Proc case { let original = Bootstrap::Proc { @@ -2126,6 +2172,7 @@ mod tests { backend_addr: ChannelAddr::any(ChannelTransport::Unix), callback_addr: ChannelAddr::any(ChannelTransport::Unix), config: Some(attrs.clone()), + socket_dir_path: socket_dir.path().to_owned(), }; let env_str = original.to_env_safe_string().expect("encode bootstrap"); let decoded = Bootstrap::from_env_safe_string(&env_str).expect("decode bootstrap"); @@ -2165,14 +2212,13 @@ mod tests { use std::process::Stdio; use tokio::process::Command; - use tokio::time::Duration; // Manager; program path is irrelevant for this test. let command = BootstrapCommand { program: PathBuf::from("/bin/true"), ..Default::default() }; - let manager = BootstrapProcManager::new(command); + let manager = BootstrapProcManager::new(command).unwrap(); // Spawn a long-running child process (sleep 30) with // kill_on_drop(true). @@ -2552,7 +2598,7 @@ mod tests { program: PathBuf::from("/bin/true"), ..Default::default() }; - let manager = BootstrapProcManager::new(command); + let manager = BootstrapProcManager::new(command).unwrap(); // Spawn a fast-exiting child. 
let mut cmd = Command::new("true"); @@ -2586,7 +2632,7 @@ mod tests { program: PathBuf::from("/bin/sleep"), ..Default::default() }; - let manager = BootstrapProcManager::new(command); + let manager = BootstrapProcManager::new(command).unwrap(); // Spawn a process that will live long enough to kill. let mut cmd = Command::new("/bin/sleep"); @@ -2703,7 +2749,8 @@ mod tests { let manager = BootstrapProcManager::new(BootstrapCommand { program: PathBuf::from("/bin/true"), ..Default::default() - }); + }) + .unwrap(); let unknown = ProcId::Direct(ChannelAddr::any(ChannelTransport::Unix), "nope".into()); assert!(manager.status(&unknown).await.is_none()); } @@ -2713,7 +2760,8 @@ mod tests { let manager = BootstrapProcManager::new(BootstrapCommand { program: PathBuf::from("/bin/sleep"), ..Default::default() - }); + }) + .unwrap(); // Long-ish child so it's alive while we "steal" it. let mut cmd = Command::new("/bin/sleep"); @@ -2752,7 +2800,8 @@ mod tests { let manager = BootstrapProcManager::new(BootstrapCommand { program: PathBuf::from("/bin/sleep"), ..Default::default() - }); + }) + .unwrap(); let mut cmd = Command::new("/bin/sleep"); cmd.arg("5").stdout(Stdio::null()).stderr(Stdio::null()); @@ -3105,8 +3154,6 @@ mod tests { instance: &hyperactor::Instance<()>, _tag: &str, ) -> (ProcId, ChannelAddr) { - let proc_id = id!(bootstrap_child[0]); - // Serve a Unix channel as the "backend_addr" and hook it into // this test proc. let (backend_addr, rx) = channel::serve(ChannelAddr::any(ChannelTransport::Unix)).unwrap(); @@ -3116,6 +3163,9 @@ mod tests { // router. instance.proc().clone().serve(rx); + // We return an arbitrary (but unbound!) unix direct proc id here; + // it is okay, as we're not testing connectivity. 
+ let proc_id = ProcId::Direct(ChannelTransport::Unix.any(), "test".to_string()); (proc_id, backend_addr) } @@ -3127,7 +3177,7 @@ mod tests { .unwrap(); let (instance, _handle) = root.instance("client").unwrap(); - let mgr = BootstrapProcManager::new(BootstrapCommand::test()); + let mgr = BootstrapProcManager::new(BootstrapCommand::test()).unwrap(); let (proc_id, backend_addr) = make_proc_id_and_backend_addr(&instance, "t_term").await; let handle = mgr .spawn(proc_id.clone(), backend_addr.clone()) @@ -3183,7 +3233,7 @@ mod tests { .unwrap(); let (instance, _handle) = root.instance("client").unwrap(); - let mgr = BootstrapProcManager::new(BootstrapCommand::test()); + let mgr = BootstrapProcManager::new(BootstrapCommand::test()).unwrap(); // Proc identity + host backend channel the child will dial. let (proc_id, backend_addr) = make_proc_id_and_backend_addr(&instance, "t_kill").await; @@ -3382,7 +3432,8 @@ mod tests { let manager = BootstrapProcManager::new(BootstrapCommand { program: std::path::PathBuf::from("/bin/true"), // unused in this test ..Default::default() - }); + }) + .unwrap(); manager.spawn_exit_monitor(proc_id.clone(), handle.clone()); // Await terminal status and assert on exit code and stderr diff --git a/hyperactor_mesh/src/bootstrap/mailbox.rs b/hyperactor_mesh/src/bootstrap/mailbox.rs new file mode 100644 index 000000000..6d1fecc0b --- /dev/null +++ b/hyperactor_mesh/src/bootstrap/mailbox.rs @@ -0,0 +1,195 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +//! This module implements mailbox support for local proc management. 
+ +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::RwLock; + +use hyperactor::PortHandle; +use hyperactor::ProcId; +use hyperactor::channel::ChannelAddr; +use hyperactor::channel::ChannelError; +use hyperactor::mailbox::DeliveryError; +use hyperactor::mailbox::MailboxClient; +use hyperactor::mailbox::MailboxSender; +use hyperactor::mailbox::MessageEnvelope; +use hyperactor::mailbox::Undeliverable; + +/// LocalProcDialer dials local procs directly through a configured socket +/// directory. +#[derive(Debug)] +pub(crate) struct LocalProcDialer { + local_addr: ChannelAddr, + socket_dir: PathBuf, + backend_sender: MailboxClient, + local_senders: RwLock>>, +} + +impl LocalProcDialer { + /// Create a new local proc dialer. Any direct-addressed procs with a destination + /// address of `local_addr`, will instead be dialed through the direct sockets + /// present in `socket_dir`. Messages to other procs are forwarded through the + /// backend sender. + pub(crate) fn new( + local_addr: ChannelAddr, + socket_dir: PathBuf, + backend_sender: MailboxClient, + ) -> Self { + Self { + local_addr, + socket_dir, + backend_sender, + local_senders: RwLock::new(HashMap::new()), + } + } +} + +impl MailboxSender for LocalProcDialer { + fn post_unchecked( + &self, + envelope: MessageEnvelope, + return_handle: PortHandle>, + ) { + if let ProcId::Direct(addr, name) = envelope.dest().actor_id().proc_id() + && addr == &self.local_addr + { + let senders = self.local_senders.read().unwrap(); + let senders = if senders.contains_key(name) { + senders + } else { + drop(senders); + let mut senders = self.local_senders.write().unwrap(); + senders.entry(name.clone()).or_insert_with(|| { + let socket_path = self.socket_dir.join(name); + if socket_path.exists() { + let addr = format!("unix:{}", socket_path.display()); + let addr = addr.parse().unwrap(); + MailboxClient::dial(addr) + } else { + Err(ChannelError::InvalidAddress(format!( + "unix socket path '{}' does 
not exist", + socket_path.display() + ))) + } + }); + drop(senders); + self.local_senders.read().unwrap() + }; + + match senders.get(name).unwrap() { + Ok(sender) => sender.post_unchecked(envelope, return_handle), + Err(e) => { + let err = DeliveryError::BrokenLink(format!("failed to dial proc: {}", e)); + envelope.undeliverable(err, return_handle); + } + } + } else { + self.backend_sender.post_unchecked(envelope, return_handle); + } + } +} + +#[cfg(test)] +mod tests { + + use std::assert_matches::assert_matches; + + use hyperactor::ActorId; + use hyperactor::Mailbox; + use hyperactor::PortId; + use hyperactor::attrs::Attrs; + use hyperactor::channel::ChannelTransport; + use hyperactor::channel::Rx; + use hyperactor::channel::{self}; + use hyperactor::data::Serialized; + use hyperactor::id; + + use super::*; + + #[tokio::test] + async fn test_proc_dialer() { + let dir = tempfile::tempdir().unwrap(); + let (first_addr, mut first_rx) = channel::serve::( + format!("unix:{}/first", dir.path().display()) + .parse() + .unwrap(), + ) + .unwrap(); + let (second_addr, mut second_rx) = channel::serve::( + format!("unix:{}/second", dir.path().display()) + .parse() + .unwrap(), + ) + .unwrap(); + let (backend_addr, mut backend_rx) = + channel::serve::(ChannelTransport::Unix.any()).unwrap(); + + let local_addr: ChannelAddr = "tcp:3.4.5.6:123".parse().unwrap(); + let first_actor_id = ActorId( + ProcId::Direct(local_addr.clone(), "first".to_string()), + "actor".to_string(), + 0, + ); + let second_actor_id = ActorId( + ProcId::Direct(local_addr.clone(), "second".to_string()), + "actor".to_string(), + 0, + ); + let third_notexist_actor_id = ActorId( + ProcId::Direct(local_addr.clone(), "third".to_string()), + "actor".to_string(), + 0, + ); + let proc_dialer = LocalProcDialer::new( + local_addr.clone(), + dir.path().to_owned(), + MailboxClient::dial(backend_addr).unwrap(), + ); + + let (return_handle, mut return_rx) = + Mailbox::new_detached(id!(world[0].proc)).open_port::>(); + + 
// Existing address on the host: + let envelope = MessageEnvelope::new( + third_notexist_actor_id.clone(), + PortId(first_actor_id.clone(), 0), + Serialized::serialize(&()).unwrap(), + Attrs::new(), + ); + proc_dialer.post(envelope.clone(), return_handle.clone()); + assert_eq!( + first_rx.recv().await.unwrap().sender(), + &third_notexist_actor_id + ); + + // Nonexistent address on the host: + let envelope = MessageEnvelope::new( + second_actor_id.clone(), + PortId(third_notexist_actor_id.clone(), 0), + Serialized::serialize(&()).unwrap(), + Attrs::new(), + ); + proc_dialer.post(envelope.clone(), return_handle.clone()); + assert_matches!( + &return_rx.recv().await.unwrap().into_inner().errors()[..], + &[DeliveryError::BrokenLink(_)] + ); + + // Outside the host: + let envelope = MessageEnvelope::new( + second_actor_id.clone(), + PortId(id!(external[0].actor), 0), + Serialized::serialize(&()).unwrap(), + Attrs::new(), + ); + proc_dialer.post(envelope.clone(), return_handle.clone()); + assert_eq!(backend_rx.recv().await.unwrap().sender(), &second_actor_id); + } +} diff --git a/hyperactor_mesh/src/v1/host_mesh/mesh_agent.rs b/hyperactor_mesh/src/v1/host_mesh/mesh_agent.rs index bf0875a4a..1c3ffe914 100644 --- a/hyperactor_mesh/src/v1/host_mesh/mesh_agent.rs +++ b/hyperactor_mesh/src/v1/host_mesh/mesh_agent.rs @@ -293,7 +293,7 @@ impl Actor for HostMeshAgentProcMeshTrampoline { None => BootstrapCommand::current()?, }; tracing::info!("booting host with proc command {:?}", command); - let manager = BootstrapProcManager::new(command); + let manager = BootstrapProcManager::new(command)?; let (host, _) = Host::serve(manager, transport.any()).await?; HostAgentMode::Process(host) }; @@ -351,7 +351,7 @@ mod tests { #[tokio::test] async fn test_basic() { let (host, _handle) = Host::serve( - BootstrapProcManager::new(BootstrapCommand::test()), + BootstrapProcManager::new(BootstrapCommand::test()).unwrap(), ChannelTransport::Unix.any(), ) .await