Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion hyper/src/commands/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::time::Duration;

use hyperactor::channel::ChannelAddr;
use hyperactor::channel::ChannelTransport;
use hyperactor::channel::TcpMode;
use hyperactor_multiprocess::system::System;

// The commands in the demo spawn temporary actors the join a system.
Expand All @@ -27,7 +28,9 @@ pub struct ServeCommand {

impl ServeCommand {
pub async fn run(self) -> anyhow::Result<()> {
let addr = self.addr.unwrap_or(ChannelAddr::any(ChannelTransport::Tcp));
let addr = self
.addr
.unwrap_or(ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)));
let handle = System::serve(addr, LONG_DURATION, LONG_DURATION).await?;
eprintln!("serve: {}", handle.local_addr());
handle.await;
Expand Down
5 changes: 3 additions & 2 deletions hyperactor/benches/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use hyperactor::channel;
use hyperactor::channel::ChannelAddr;
use hyperactor::channel::ChannelTransport;
use hyperactor::channel::Rx;
use hyperactor::channel::TcpMode;
use hyperactor::channel::Tx;
use hyperactor::channel::dial;
use hyperactor::channel::serve;
Expand Down Expand Up @@ -66,7 +67,7 @@ impl Message {
fn bench_message_sizes(c: &mut Criterion) {
let transports = vec![
("local", ChannelTransport::Local),
("tcp", ChannelTransport::Tcp),
("tcp", ChannelTransport::Tcp(TcpMode::Hostname)),
("unix", ChannelTransport::Unix),
];

Expand Down Expand Up @@ -108,7 +109,7 @@ fn bench_message_rates(c: &mut Criterion) {

let transports = vec![
("local", ChannelTransport::Local),
("tcp", ChannelTransport::Tcp),
("tcp", ChannelTransport::Tcp(TcpMode::Hostname)),
("unix", ChannelTransport::Unix),
//TODO Add TLS once it is able to run in Sandcastle
];
Expand Down
80 changes: 62 additions & 18 deletions hyperactor/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
use core::net::SocketAddr;
use std::fmt;
use std::net::IpAddr;
use std::net::Ipv4Addr;
use std::net::Ipv6Addr;
#[cfg(target_os = "linux")]
use std::os::linux::net::SocketAddrExt;
use std::str::FromStr;
Expand Down Expand Up @@ -236,6 +238,26 @@ impl<M: RemoteMessage> Rx<M> for MpscRx<M> {
}
}

/// The hostname to use for TLS connections.
#[derive(
Clone,
Debug,
PartialEq,
Eq,
Hash,
Serialize,
Deserialize,
strum::EnumIter,
strum::Display,
strum::EnumString
)]
pub enum TcpMode {
/// Use localhost/loopback for the connection.
Localhost,
/// Use host domain name for the connection.
Hostname,
}

/// The hostname to use for TLS connections.
#[derive(
Clone,
Expand Down Expand Up @@ -315,7 +337,7 @@ impl fmt::Display for MetaTlsAddr {
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Named)]
pub enum ChannelTransport {
/// Transport over a TCP connection.
Tcp,
Tcp(TcpMode),

/// Transport over a TCP connection with TLS support within Meta
MetaTls(TlsMode),
Expand All @@ -333,7 +355,7 @@ pub enum ChannelTransport {
impl fmt::Display for ChannelTransport {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Tcp => write!(f, "tcp"),
Self::Tcp(mode) => write!(f, "tcp({:?})", mode),
Self::MetaTls(mode) => write!(f, "metatls({:?})", mode),
Self::Local => write!(f, "local"),
Self::Sim(transport) => write!(f, "sim({})", transport),
Expand All @@ -358,7 +380,13 @@ impl FromStr for ChannelTransport {
}

match s {
"tcp" => Ok(ChannelTransport::Tcp),
// Default to TcpMode::Hostname, if the mode isn't set
"tcp" => Ok(ChannelTransport::Tcp(TcpMode::Hostname)),
s if s.starts_with("tcp(") => {
let inner = &s["tcp(".len()..s.len() - 1];
let mode = inner.parse()?;
Ok(ChannelTransport::Tcp(mode))
}
"local" => Ok(ChannelTransport::Local),
"unix" => Ok(ChannelTransport::Unix),
s if s.starts_with("metatls(") && s.ends_with(")") => {
Expand All @@ -373,9 +401,10 @@ impl FromStr for ChannelTransport {

impl ChannelTransport {
/// All known channel transports.
pub fn all() -> [ChannelTransport; 3] {
pub fn all() -> [ChannelTransport; 4] {
[
ChannelTransport::Tcp,
ChannelTransport::Tcp(TcpMode::Localhost),
ChannelTransport::Tcp(TcpMode::Hostname),
ChannelTransport::Local,
ChannelTransport::Unix,
// TODO add MetaTls (T208303369)
Expand All @@ -392,7 +421,7 @@ impl ChannelTransport {
/// Returns true if this transport type represents a remote channel.
pub fn is_remote(&self) -> bool {
match self {
ChannelTransport::Tcp => true,
ChannelTransport::Tcp(_) => true,
ChannelTransport::MetaTls(_) => true,
ChannelTransport::Local => false,
ChannelTransport::Sim(_) => false,
Expand Down Expand Up @@ -502,18 +531,27 @@ impl ChannelAddr {
/// servers to "any" address.
pub fn any(transport: ChannelTransport) -> Self {
match transport {
ChannelTransport::Tcp => {
let ip = hostname::get()
.ok()
.and_then(|hostname| {
// TODO: Avoid using DNS directly once we figure out a good extensibility story here
hostname.to_str().and_then(|hostname_str| {
dns_lookup::lookup_host(hostname_str)
.ok()
.and_then(|addresses| addresses.first().cloned())
ChannelTransport::Tcp(mode) => {
let ip = match mode {
TcpMode::Localhost => {
// Try IPv6 first, fall back to IPv4 if the system doesn't support IPv6
match std::net::TcpListener::bind((Ipv6Addr::UNSPECIFIED, 0)) {
Ok(_) => IpAddr::V6(Ipv6Addr::UNSPECIFIED),
Err(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
}
}
TcpMode::Hostname => hostname::get()
.ok()
.and_then(|hostname| {
// TODO: Avoid using DNS directly once we figure out a good extensibility story here
hostname.to_str().and_then(|hostname_str| {
dns_lookup::lookup_host(hostname_str)
.ok()
.and_then(|addresses| addresses.first().cloned())
})
})
})
.unwrap_or_else(|| IpAddr::from_str("::1").unwrap());
.expect("Failed to resolve hostname to IP address"),
};
Self::Tcp(SocketAddr::new(ip, 0))
}
ChannelTransport::MetaTls(mode) => {
Expand Down Expand Up @@ -542,7 +580,13 @@ impl ChannelAddr {
/// The transport used by this address.
pub fn transport(&self) -> ChannelTransport {
match self {
Self::Tcp(_) => ChannelTransport::Tcp,
Self::Tcp(addr) => {
if addr.ip().is_loopback() {
ChannelTransport::Tcp(TcpMode::Localhost)
} else {
ChannelTransport::Tcp(TcpMode::Hostname)
}
}
Self::MetaTls(addr) => match addr {
MetaTlsAddr::Host { hostname, .. } => match hostname.parse::<IpAddr>() {
Ok(IpAddr::V6(_)) => ChannelTransport::MetaTls(TlsMode::IpV6),
Expand Down
7 changes: 6 additions & 1 deletion hyperactor/src/channel/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1799,7 +1799,12 @@ async fn join_nonempty<T: 'static>(set: &mut JoinSet<T>) -> Result<T, JoinError>
/// Tells whether the address is a 'net' address. These currently have different semantics
/// from local transports.
pub fn is_net_addr(addr: &ChannelAddr) -> bool {
[ChannelTransport::Tcp, ChannelTransport::Unix].contains(&addr.transport())
match addr.transport() {
// TODO Metatls?
ChannelTransport::Tcp(_) => true,
ChannelTransport::Unix => true,
_ => false,
}
}

pub(crate) mod unix {
Expand Down
6 changes: 5 additions & 1 deletion hyperactor_mesh/src/alloc/remoteprocess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use hyperactor::channel::ChannelRx;
use hyperactor::channel::ChannelTransport;
use hyperactor::channel::ChannelTx;
use hyperactor::channel::Rx;
use hyperactor::channel::TcpMode;
use hyperactor::channel::Tx;
use hyperactor::channel::TxStatus;
use hyperactor::clock;
Expand Down Expand Up @@ -768,7 +769,10 @@ impl RemoteProcessAlloc {
ChannelTransport::MetaTls(_) => {
format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
}
ChannelTransport::Tcp => {
ChannelTransport::Tcp(TcpMode::Localhost) => {
format!("tcp![::1]:{}", self.remote_allocator_port)
}
ChannelTransport::Tcp(TcpMode::Hostname) => {
format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
}
// Used only for testing.
Expand Down
3 changes: 2 additions & 1 deletion hyperactor_mesh/src/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2043,6 +2043,7 @@ mod tests {
use hyperactor::WorldId;
use hyperactor::channel::ChannelAddr;
use hyperactor::channel::ChannelTransport;
use hyperactor::channel::TcpMode;
use hyperactor::clock::RealClock;
use hyperactor::context::Mailbox as _;
use hyperactor::host::ProcHandle;
Expand Down Expand Up @@ -2073,7 +2074,7 @@ mod tests {
Bootstrap::default(),
Bootstrap::Proc {
proc_id: id!(foo[0]),
backend_addr: ChannelAddr::any(ChannelTransport::Tcp),
backend_addr: ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
callback_addr: ChannelAddr::any(ChannelTransport::Unix),
config: None,
},
Expand Down
17 changes: 9 additions & 8 deletions hyperactor_multiprocess/src/proc_actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use hyperactor::actor::Referable;
use hyperactor::actor::remote::Remote;
use hyperactor::channel;
use hyperactor::channel::ChannelAddr;
use hyperactor::channel::TcpMode;
use hyperactor::clock::Clock;
use hyperactor::clock::ClockKind;
use hyperactor::context;
Expand Down Expand Up @@ -1376,7 +1377,7 @@ mod tests {

// Serve a system.
let server_handle = System::serve(
ChannelAddr::any(ChannelTransport::Tcp),
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
Duration::from_secs(120),
Duration::from_secs(120),
)
Expand All @@ -1395,7 +1396,7 @@ mod tests {
));

// Construct a proc forwarder in terms of the system sender.
let listen_addr = ChannelAddr::any(ChannelTransport::Tcp);
let listen_addr = ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname));
let proc_forwarder =
BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender));

Expand All @@ -1422,7 +1423,7 @@ mod tests {
let _proc_actor_1 = ProcActor::bootstrap_for_proc(
proc_1.clone(),
world_id.clone(),
ChannelAddr::any(ChannelTransport::Tcp),
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
server_handle.local_addr().clone(),
sup_ref.clone(),
Duration::from_secs(120),
Expand Down Expand Up @@ -1497,7 +1498,7 @@ mod tests {

// Serve a system.
let server_handle = System::serve(
ChannelAddr::any(ChannelTransport::Tcp),
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
Duration::from_secs(120),
Duration::from_secs(120),
)
Expand All @@ -1518,7 +1519,7 @@ mod tests {
));

// Construct a proc forwarder in terms of the system sender.
let listen_addr = ChannelAddr::any(ChannelTransport::Tcp);
let listen_addr = ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname));
let proc_forwarder =
BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender));

Expand All @@ -1545,7 +1546,7 @@ mod tests {
let _proc_actor_1 = ProcActor::bootstrap_for_proc(
proc_1.clone(),
world_id.clone(),
ChannelAddr::any(ChannelTransport::Tcp),
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
server_handle.local_addr().clone(),
sup_ref.clone(),
Duration::from_secs(120),
Expand Down Expand Up @@ -1651,7 +1652,7 @@ mod tests {
#[tokio::test]
async fn test_update_address_book_cache() {
let server_handle = System::serve(
ChannelAddr::any(ChannelTransport::Tcp),
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
Duration::from_secs(2), // supervision update timeout
Duration::from_secs(2), // duration to evict an unhealthy world
)
Expand Down Expand Up @@ -1705,7 +1706,7 @@ mod tests {
actor_id: &ActorId,
system_addr: &ChannelAddr,
) -> (ActorRef<PingPongActor>, ActorRef<ProcActor>) {
let listen_addr = ChannelAddr::any(ChannelTransport::Tcp);
let listen_addr = ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname));
let bootstrap = ProcActor::bootstrap(
actor_id.proc_id().clone(),
actor_id
Expand Down
3 changes: 2 additions & 1 deletion hyperactor_multiprocess/src/system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ mod tests {
use hyperactor::WorldId;
use hyperactor::channel::ChannelAddr;
use hyperactor::channel::ChannelTransport;
use hyperactor::channel::TcpMode;
use hyperactor::clock::Clock;
use hyperactor::clock::RealClock;
use hyperactor_telemetry::env::execution_id;
Expand Down Expand Up @@ -825,7 +826,7 @@ mod tests {
#[tokio::test]
async fn test_channel_dial_count() {
let system_handle = System::serve(
ChannelAddr::any(ChannelTransport::Tcp),
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
Duration::from_secs(10),
Duration::from_secs(10),
)
Expand Down
9 changes: 5 additions & 4 deletions hyperactor_multiprocess/src/system_actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1849,6 +1849,7 @@ mod tests {
use hyperactor::channel;
use hyperactor::channel::ChannelTransport;
use hyperactor::channel::Rx;
use hyperactor::channel::TcpMode;
use hyperactor::clock::Clock;
use hyperactor::clock::RealClock;
use hyperactor::data::Serialized;
Expand Down Expand Up @@ -2194,7 +2195,7 @@ mod tests {
// Serve a system. Undeliverable messages encountered by the
// mailbox server are returned to the system actor.
let server_handle = System::serve(
ChannelAddr::any(ChannelTransport::Tcp),
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
Duration::from_secs(2), // supervision update timeout
Duration::from_secs(2), // duration to evict an unhealthy world
)
Expand Down Expand Up @@ -2255,7 +2256,7 @@ mod tests {
let _proc_actor_0 = ProcActor::bootstrap_for_proc(
proc_0.clone(),
world_id.clone(),
ChannelAddr::any(ChannelTransport::Tcp),
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
server_handle.local_addr().clone(),
sup_ref.clone(),
Duration::from_millis(300), // supervision update interval
Expand All @@ -2272,7 +2273,7 @@ mod tests {
let proc_actor_1 = ProcActor::bootstrap_for_proc(
proc_1.clone(),
world_id.clone(),
ChannelAddr::any(ChannelTransport::Tcp),
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
server_handle.local_addr().clone(),
sup_ref.clone(),
Duration::from_millis(300), // supervision update interval
Expand Down Expand Up @@ -2348,7 +2349,7 @@ mod tests {
#[tokio::test]
async fn test_stop_fast() -> Result<()> {
let server_handle = System::serve(
ChannelAddr::any(ChannelTransport::Tcp),
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
Duration::from_secs(2), // supervision update timeout
Duration::from_secs(2), // duration to evict an unhealthy world
)
Expand Down
Loading
Loading