Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion hyperactor/src/mailbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ pub type Data = Vec<u8>;
Deserialize,
Named,
Clone,
PartialEq
PartialEq,
Eq
)]
pub enum DeliveryError {
/// The destination address is not reachable.
Expand Down
7 changes: 7 additions & 0 deletions hyperactor/src/mailbox/undeliverable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ use crate::supervision::ActorSupervisionEvent;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Named)]
pub struct Undeliverable<M: Message>(pub M);

impl<M: Message> Undeliverable<M> {
/// Return the inner M-typed message.
pub fn into_inner(self) -> M {
self.0
}
}

// Port handle and receiver for undeliverable messages.
pub(crate) fn new_undeliverable_port() -> (
PortHandle<Undeliverable<MessageEnvelope>>,
Expand Down
1 change: 1 addition & 0 deletions hyperactor_mesh/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ serde_bytes = "0.11"
serde_json = { version = "1.0.140", features = ["alloc", "float_roundtrip", "raw_value", "unbounded_depth"] }
serde_multipart = { version = "0.0.0", path = "../serde_multipart" }
strum = { version = "0.27.1", features = ["derive"] }
tempfile = "3.22"
thiserror = "2.0.12"
tokio = { version = "1.47.1", features = ["full", "test-util", "tracing"] }
tokio-stream = { version = "0.1.17", features = ["fs", "io-util", "net", "signal", "sync", "time"] }
Expand Down
101 changes: 76 additions & 25 deletions hyperactor_mesh/src/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use std::future;
use std::io;
use std::io::Write;
use std::os::unix::process::ExitStatusExt;
use std::path::Path;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::Arc;
Expand All @@ -34,6 +35,7 @@ use hyperactor::ProcId;
use hyperactor::attrs::Attrs;
use hyperactor::channel;
use hyperactor::channel::ChannelAddr;
use hyperactor::channel::ChannelError;
use hyperactor::channel::ChannelTransport;
use hyperactor::channel::Rx;
use hyperactor::channel::Tx;
Expand All @@ -48,10 +50,13 @@ use hyperactor::host::HostError;
use hyperactor::host::ProcHandle;
use hyperactor::host::ProcManager;
use hyperactor::host::TerminateSummary;
use hyperactor::mailbox::IntoBoxedMailboxSender;
use hyperactor::mailbox::MailboxClient;
use hyperactor::mailbox::MailboxServer;
use hyperactor::proc::Proc;
use serde::Deserialize;
use serde::Serialize;
use tempfile::TempDir;
use tokio::process::Child;
use tokio::process::Command;
use tokio::sync::oneshot;
Expand All @@ -64,6 +69,8 @@ use crate::v1;
use crate::v1::host_mesh::mesh_agent::HostAgentMode;
use crate::v1::host_mesh::mesh_agent::HostMeshAgent;

mod mailbox;

declare_attrs! {
/// If enabled (default), bootstrap child processes install
/// `PR_SET_PDEATHSIG(SIGKILL)` so the kernel reaps them if the
Expand Down Expand Up @@ -212,6 +219,10 @@ pub enum Bootstrap {
backend_addr: ChannelAddr,
/// The callback address used to indicate successful spawning.
callback_addr: ChannelAddr,
/// Directory for storing proc socket files. Procs place their sockets
/// in this directory, so that they can be looked up by other procs
/// for direct transfer.
socket_dir_path: PathBuf,
/// Optional config snapshot (`hyperactor::config::Attrs`)
/// captured by the parent. If present, the child installs it
/// as the `Runtime` layer so the parent's effective config
Expand Down Expand Up @@ -324,6 +335,7 @@ impl Bootstrap {
proc_id,
backend_addr,
callback_addr,
socket_dir_path,
config,
} => {
if let Some(attrs) = config {
Expand All @@ -343,15 +355,39 @@ impl Bootstrap {
eprintln!("(bootstrap) PDEATHSIG disabled via config");
}

let result =
host::spawn_proc(proc_id, backend_addr, callback_addr, |proc| async move {
ProcMeshAgent::boot_v1(proc).await
})
.await;
match result {
Ok(_proc) => halt().await,
Err(e) => e.into(),
}
let (local_addr, name) = ok!(proc_id
.as_direct()
.ok_or_else(|| anyhow::anyhow!("invalid proc id type: {}", proc_id)));
// TODO provide a direct way to construct these
let serve_addr = format!("unix:{}", socket_dir_path.join(name).display());
let serve_addr = serve_addr.parse().unwrap();

// The following is a modified host::spawn_proc to support direct
// dialing between local procs: 1) we bind each proc to a deterministic
// address in socket_dir_path; 2) we use LocalProcDialer to dial these
// addresses for local procs.
let proc_sender = mailbox::LocalProcDialer::new(
local_addr.clone(),
socket_dir_path,
ok!(MailboxClient::dial(backend_addr)),
);

let proc = Proc::new(proc_id.clone(), proc_sender.into_boxed());

let agent_handle = ok!(ProcMeshAgent::boot_v1(proc.clone())
.await
.map_err(|e| HostError::AgentSpawnFailure(proc_id, e)));

// Finally serve the proc on the same transport as the backend address,
// and call back.
let (proc_addr, proc_rx) = ok!(channel::serve(serve_addr));
proc.clone().serve(proc_rx);
ok!(ok!(channel::dial(callback_addr))
.send((proc_addr, agent_handle.bind::<ProcMeshAgent>()))
.await
.map_err(ChannelError::from));

halt().await
}
Bootstrap::Host {
addr,
Expand All @@ -369,7 +405,7 @@ impl Bootstrap {
Some(command) => command,
None => ok!(BootstrapCommand::current()),
};
let manager = BootstrapProcManager::new(command);
let manager = BootstrapProcManager::new(command).unwrap();
let (host, _handle) = ok!(Host::serve(manager, addr).await);
let addr = host.addr().clone();
let host_mesh_agent = ok!(host
Expand Down Expand Up @@ -1402,6 +1438,11 @@ pub struct BootstrapProcManager {
/// exclusively in the [`Drop`] impl to send `SIGKILL` without
/// needing async context.
pid_table: Arc<std::sync::Mutex<HashMap<ProcId, u32>>>,

/// Directory for storing proc socket files. Procs place their sockets
/// in this directory, so that they can be looked up by other procs
/// for direct transfer.
socket_dir: TempDir,
}

impl Drop for BootstrapProcManager {
Expand Down Expand Up @@ -1451,12 +1492,13 @@ impl BootstrapProcManager {
/// This is the general entry point when you want to manage procs
/// backed by a specific binary path (e.g. a bootstrap
/// trampoline).
pub(crate) fn new(command: BootstrapCommand) -> Self {
Self {
pub(crate) fn new(command: BootstrapCommand) -> Result<Self, io::Error> {
Ok(Self {
command,
children: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
pid_table: Arc::new(std::sync::Mutex::new(HashMap::new())),
}
socket_dir: tempfile::tempdir()?,
})
}

/// The bootstrap command used to launch processes.
Expand Down Expand Up @@ -1628,6 +1670,7 @@ impl ProcManager for BootstrapProcManager {
proc_id: proc_id.clone(),
backend_addr,
callback_addr,
socket_dir_path: self.socket_dir.path().to_owned(),
config: Some(cfg),
};
let mut cmd = Command::new(&self.command.program);
Expand Down Expand Up @@ -2062,6 +2105,7 @@ mod tests {
proc_id: id!(foo[0]),
backend_addr: ChannelAddr::any(ChannelTransport::Tcp),
callback_addr: ChannelAddr::any(ChannelTransport::Unix),
socket_dir_path: PathBuf::from("notexist"),
config: None,
},
];
Expand Down Expand Up @@ -2119,13 +2163,16 @@ mod tests {
attrs[MESH_TAIL_LOG_LINES] = 123;
attrs[MESH_BOOTSTRAP_ENABLE_PDEATHSIG] = false;

let socket_dir = tempfile::tempdir().unwrap();

// Proc case
{
let original = Bootstrap::Proc {
proc_id: id!(foo[42]),
backend_addr: ChannelAddr::any(ChannelTransport::Unix),
callback_addr: ChannelAddr::any(ChannelTransport::Unix),
config: Some(attrs.clone()),
socket_dir_path: socket_dir.path().to_owned(),
};
let env_str = original.to_env_safe_string().expect("encode bootstrap");
let decoded = Bootstrap::from_env_safe_string(&env_str).expect("decode bootstrap");
Expand Down Expand Up @@ -2165,14 +2212,13 @@ mod tests {
use std::process::Stdio;

use tokio::process::Command;
use tokio::time::Duration;

// Manager; program path is irrelevant for this test.
let command = BootstrapCommand {
program: PathBuf::from("/bin/true"),
..Default::default()
};
let manager = BootstrapProcManager::new(command);
let manager = BootstrapProcManager::new(command).unwrap();

// Spawn a long-running child process (sleep 30) with
// kill_on_drop(true).
Expand Down Expand Up @@ -2552,7 +2598,7 @@ mod tests {
program: PathBuf::from("/bin/true"),
..Default::default()
};
let manager = BootstrapProcManager::new(command);
let manager = BootstrapProcManager::new(command).unwrap();

// Spawn a fast-exiting child.
let mut cmd = Command::new("true");
Expand Down Expand Up @@ -2586,7 +2632,7 @@ mod tests {
program: PathBuf::from("/bin/sleep"),
..Default::default()
};
let manager = BootstrapProcManager::new(command);
let manager = BootstrapProcManager::new(command).unwrap();

// Spawn a process that will live long enough to kill.
let mut cmd = Command::new("/bin/sleep");
Expand Down Expand Up @@ -2703,7 +2749,8 @@ mod tests {
let manager = BootstrapProcManager::new(BootstrapCommand {
program: PathBuf::from("/bin/true"),
..Default::default()
});
})
.unwrap();
let unknown = ProcId::Direct(ChannelAddr::any(ChannelTransport::Unix), "nope".into());
assert!(manager.status(&unknown).await.is_none());
}
Expand All @@ -2713,7 +2760,8 @@ mod tests {
let manager = BootstrapProcManager::new(BootstrapCommand {
program: PathBuf::from("/bin/sleep"),
..Default::default()
});
})
.unwrap();

// Long-ish child so it's alive while we "steal" it.
let mut cmd = Command::new("/bin/sleep");
Expand Down Expand Up @@ -2752,7 +2800,8 @@ mod tests {
let manager = BootstrapProcManager::new(BootstrapCommand {
program: PathBuf::from("/bin/sleep"),
..Default::default()
});
})
.unwrap();

let mut cmd = Command::new("/bin/sleep");
cmd.arg("5").stdout(Stdio::null()).stderr(Stdio::null());
Expand Down Expand Up @@ -3105,8 +3154,6 @@ mod tests {
instance: &hyperactor::Instance<()>,
_tag: &str,
) -> (ProcId, ChannelAddr) {
let proc_id = id!(bootstrap_child[0]);

// Serve a Unix channel as the "backend_addr" and hook it into
// this test proc.
let (backend_addr, rx) = channel::serve(ChannelAddr::any(ChannelTransport::Unix)).unwrap();
Expand All @@ -3116,6 +3163,9 @@ mod tests {
// router.
instance.proc().clone().serve(rx);

// We return an arbitrary (but unbound!) unix direct proc id here;
// it is okay, as we're not testing connectivity.
let proc_id = ProcId::Direct(ChannelTransport::Unix.any(), "test".to_string());
(proc_id, backend_addr)
}

Expand All @@ -3127,7 +3177,7 @@ mod tests {
.unwrap();
let (instance, _handle) = root.instance("client").unwrap();

let mgr = BootstrapProcManager::new(BootstrapCommand::test());
let mgr = BootstrapProcManager::new(BootstrapCommand::test()).unwrap();
let (proc_id, backend_addr) = make_proc_id_and_backend_addr(&instance, "t_term").await;
let handle = mgr
.spawn(proc_id.clone(), backend_addr.clone())
Expand Down Expand Up @@ -3183,7 +3233,7 @@ mod tests {
.unwrap();
let (instance, _handle) = root.instance("client").unwrap();

let mgr = BootstrapProcManager::new(BootstrapCommand::test());
let mgr = BootstrapProcManager::new(BootstrapCommand::test()).unwrap();

// Proc identity + host backend channel the child will dial.
let (proc_id, backend_addr) = make_proc_id_and_backend_addr(&instance, "t_kill").await;
Expand Down Expand Up @@ -3382,7 +3432,8 @@ mod tests {
let manager = BootstrapProcManager::new(BootstrapCommand {
program: std::path::PathBuf::from("/bin/true"), // unused in this test
..Default::default()
});
})
.unwrap();
manager.spawn_exit_monitor(proc_id.clone(), handle.clone());

// Await terminal status and assert on exit code and stderr
Expand Down
Loading
Loading