diff --git a/hyper/src/commands/serve.rs b/hyper/src/commands/serve.rs index 37123f23a..3521ddc6b 100644 --- a/hyper/src/commands/serve.rs +++ b/hyper/src/commands/serve.rs @@ -10,6 +10,7 @@ use std::time::Duration; use hyperactor::channel::ChannelAddr; use hyperactor::channel::ChannelTransport; +use hyperactor::channel::TcpMode; use hyperactor_multiprocess::system::System; // The commands in the demo spawn temporary actors the join a system. @@ -27,7 +28,9 @@ pub struct ServeCommand { impl ServeCommand { pub async fn run(self) -> anyhow::Result<()> { - let addr = self.addr.unwrap_or(ChannelAddr::any(ChannelTransport::Tcp)); + let addr = self + .addr + .unwrap_or(ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname))); let handle = System::serve(addr, LONG_DURATION, LONG_DURATION).await?; eprintln!("serve: {}", handle.local_addr()); handle.await; diff --git a/hyperactor/benches/main.rs b/hyperactor/benches/main.rs index 1d85b6861..bae362d65 100644 --- a/hyperactor/benches/main.rs +++ b/hyperactor/benches/main.rs @@ -22,6 +22,7 @@ use hyperactor::channel; use hyperactor::channel::ChannelAddr; use hyperactor::channel::ChannelTransport; use hyperactor::channel::Rx; +use hyperactor::channel::TcpMode; use hyperactor::channel::Tx; use hyperactor::channel::dial; use hyperactor::channel::serve; @@ -66,7 +67,7 @@ impl Message { fn bench_message_sizes(c: &mut Criterion) { let transports = vec![ ("local", ChannelTransport::Local), - ("tcp", ChannelTransport::Tcp), + ("tcp", ChannelTransport::Tcp(TcpMode::Hostname)), ("unix", ChannelTransport::Unix), ]; @@ -108,7 +109,7 @@ fn bench_message_rates(c: &mut Criterion) { let transports = vec![ ("local", ChannelTransport::Local), - ("tcp", ChannelTransport::Tcp), + ("tcp", ChannelTransport::Tcp(TcpMode::Hostname)), ("unix", ChannelTransport::Unix), //TODO Add TLS once it is able to run in Sandcastle ]; diff --git a/hyperactor/src/channel.rs b/hyperactor/src/channel.rs index 7afd84ebc..e9b7b7ccd 100644 --- 
a/hyperactor/src/channel.rs +++ b/hyperactor/src/channel.rs @@ -14,6 +14,8 @@ use core::net::SocketAddr; use std::fmt; use std::net::IpAddr; +use std::net::Ipv4Addr; +use std::net::Ipv6Addr; #[cfg(target_os = "linux")] use std::os::linux::net::SocketAddrExt; use std::str::FromStr; @@ -236,6 +238,26 @@ impl Rx for MpscRx { } } +/// The address mode (loopback or host name) to use for TCP connections. +#[derive( + Clone, + Debug, + PartialEq, + Eq, + Hash, + Serialize, + Deserialize, + strum::EnumIter, + strum::Display, + strum::EnumString +)] +pub enum TcpMode { + /// Use localhost/loopback for the connection. + Localhost, + /// Use host domain name for the connection. + Hostname, +} + /// The hostname to use for TLS connections. #[derive( Clone, @@ -315,7 +337,7 @@ impl fmt::Display for MetaTlsAddr { #[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Named)] pub enum ChannelTransport { /// Transport over a TCP connection. - Tcp, + Tcp(TcpMode), /// Transport over a TCP connection with TLS support within Meta MetaTls(TlsMode), @@ -333,7 +355,7 @@ pub enum ChannelTransport { impl fmt::Display for ChannelTransport { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Tcp => write!(f, "tcp"), + Self::Tcp(mode) => write!(f, "tcp({:?})", mode), Self::MetaTls(mode) => write!(f, "metatls({:?})", mode), Self::Local => write!(f, "local"), Self::Sim(transport) => write!(f, "sim({})", transport), @@ -358,7 +380,13 @@ impl FromStr for ChannelTransport { } match s { - "tcp" => Ok(ChannelTransport::Tcp), + // Default to TcpMode::Hostname, if the mode isn't set + "tcp" => Ok(ChannelTransport::Tcp(TcpMode::Hostname)), + s if s.starts_with("tcp(") => { + let inner = &s["tcp(".len()..s.len() - 1]; + let mode = inner.parse()?; + Ok(ChannelTransport::Tcp(mode)) + } "local" => Ok(ChannelTransport::Local), "unix" => Ok(ChannelTransport::Unix), s if s.starts_with("metatls(") && s.ends_with(")") => { @@ -373,9 +401,10 @@ impl FromStr for ChannelTransport { impl 
ChannelTransport { /// All known channel transports. - pub fn all() -> [ChannelTransport; 3] { + pub fn all() -> [ChannelTransport; 4] { [ - ChannelTransport::Tcp, + ChannelTransport::Tcp(TcpMode::Localhost), + ChannelTransport::Tcp(TcpMode::Hostname), ChannelTransport::Local, ChannelTransport::Unix, // TODO add MetaTls (T208303369) @@ -392,7 +421,7 @@ impl ChannelTransport { /// Returns true if this transport type represents a remote channel. pub fn is_remote(&self) -> bool { match self { - ChannelTransport::Tcp => true, + ChannelTransport::Tcp(_) => true, ChannelTransport::MetaTls(_) => true, ChannelTransport::Local => false, ChannelTransport::Sim(_) => false, @@ -502,18 +531,28 @@ impl ChannelAddr { /// servers to "any" address. pub fn any(transport: ChannelTransport) -> Self { match transport { - ChannelTransport::Tcp => { - let ip = hostname::get() - .ok() - .and_then(|hostname| { - // TODO: Avoid using DNS directly once we figure out a good extensibility story here - hostname.to_str().and_then(|hostname_str| { - dns_lookup::lookup_host(hostname_str) - .ok() - .and_then(|addresses| addresses.first().cloned()) + ChannelTransport::Tcp(mode) => { + let ip = match mode { + TcpMode::Localhost => { + // Fall back to 0.0.0.0 if the IPv6 loopback (::1) cannot be bound, primarily for Docker + // TODO(@rusch): figure out a better way to choose a bindable IP and port + match std::net::TcpListener::bind((Ipv6Addr::LOCALHOST, 0)) { + Ok(_) => IpAddr::V6(Ipv6Addr::LOCALHOST), + Err(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED), + } + } + TcpMode::Hostname => hostname::get() + .ok() + .and_then(|hostname| { + // TODO: Avoid using DNS directly once we figure out a good extensibility story here + hostname.to_str().and_then(|hostname_str| { + dns_lookup::lookup_host(hostname_str) + .ok() + .and_then(|addresses| addresses.first().cloned()) + }) }) - }) - .unwrap_or_else(|| IpAddr::from_str("::1").unwrap()); + .expect("Failed to resolve hostname to IP address"), + }; 
Self::Tcp(SocketAddr::new(ip, 0)) } ChannelTransport::MetaTls(mode) => { @@ -542,7 +581,13 @@ impl ChannelAddr { /// The transport used by this address. pub fn transport(&self) -> ChannelTransport { match self { - Self::Tcp(_) => ChannelTransport::Tcp, + Self::Tcp(addr) => { + if addr.ip().is_loopback() { + ChannelTransport::Tcp(TcpMode::Localhost) + } else { + ChannelTransport::Tcp(TcpMode::Hostname) + } + } Self::MetaTls(addr) => match addr { MetaTlsAddr::Host { hostname, .. } => match hostname.parse::() { Ok(IpAddr::V6(_)) => ChannelTransport::MetaTls(TlsMode::IpV6), diff --git a/hyperactor/src/channel/net.rs b/hyperactor/src/channel/net.rs index f38f64965..e9051487e 100644 --- a/hyperactor/src/channel/net.rs +++ b/hyperactor/src/channel/net.rs @@ -1799,7 +1799,12 @@ async fn join_nonempty(set: &mut JoinSet) -> Result /// Tells whether the address is a 'net' address. These currently have different semantics /// from local transports. pub fn is_net_addr(addr: &ChannelAddr) -> bool { - [ChannelTransport::Tcp, ChannelTransport::Unix].contains(&addr.transport()) + match addr.transport() { + // TODO Metatls? 
+ ChannelTransport::Tcp(_) => true, + ChannelTransport::Unix => true, + _ => false, + } } pub(crate) mod unix { diff --git a/hyperactor_mesh/src/alloc/remoteprocess.rs b/hyperactor_mesh/src/alloc/remoteprocess.rs index 621ed3de5..056ea539b 100644 --- a/hyperactor_mesh/src/alloc/remoteprocess.rs +++ b/hyperactor_mesh/src/alloc/remoteprocess.rs @@ -26,6 +26,7 @@ use hyperactor::channel::ChannelRx; use hyperactor::channel::ChannelTransport; use hyperactor::channel::ChannelTx; use hyperactor::channel::Rx; +use hyperactor::channel::TcpMode; use hyperactor::channel::Tx; use hyperactor::channel::TxStatus; use hyperactor::clock; @@ -768,7 +769,10 @@ impl RemoteProcessAlloc { ChannelTransport::MetaTls(_) => { format!("metatls!{}:{}", host.hostname, self.remote_allocator_port) } - ChannelTransport::Tcp => { + ChannelTransport::Tcp(TcpMode::Localhost) => { + format!("tcp![::1]:{}", self.remote_allocator_port) + } + ChannelTransport::Tcp(TcpMode::Hostname) => { format!("tcp!{}:{}", host.hostname, self.remote_allocator_port) } // Used only for testing. 
diff --git a/hyperactor_mesh/src/bootstrap.rs b/hyperactor_mesh/src/bootstrap.rs index 656180c52..bcbbdbe9e 100644 --- a/hyperactor_mesh/src/bootstrap.rs +++ b/hyperactor_mesh/src/bootstrap.rs @@ -2043,6 +2043,7 @@ mod tests { use hyperactor::WorldId; use hyperactor::channel::ChannelAddr; use hyperactor::channel::ChannelTransport; + use hyperactor::channel::TcpMode; use hyperactor::clock::RealClock; use hyperactor::context::Mailbox as _; use hyperactor::host::ProcHandle; @@ -2073,7 +2074,7 @@ mod tests { Bootstrap::default(), Bootstrap::Proc { proc_id: id!(foo[0]), - backend_addr: ChannelAddr::any(ChannelTransport::Tcp), + backend_addr: ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)), callback_addr: ChannelAddr::any(ChannelTransport::Unix), config: None, }, diff --git a/hyperactor_multiprocess/src/proc_actor.rs b/hyperactor_multiprocess/src/proc_actor.rs index 03cc0da4e..3685484e3 100644 --- a/hyperactor_multiprocess/src/proc_actor.rs +++ b/hyperactor_multiprocess/src/proc_actor.rs @@ -34,6 +34,7 @@ use hyperactor::actor::Referable; use hyperactor::actor::remote::Remote; use hyperactor::channel; use hyperactor::channel::ChannelAddr; +use hyperactor::channel::TcpMode; use hyperactor::clock::Clock; use hyperactor::clock::ClockKind; use hyperactor::context; @@ -1376,7 +1377,7 @@ mod tests { // Serve a system. let server_handle = System::serve( - ChannelAddr::any(ChannelTransport::Tcp), + ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)), Duration::from_secs(120), Duration::from_secs(120), ) @@ -1395,7 +1396,7 @@ mod tests { )); // Construct a proc forwarder in terms of the system sender. 
- let listen_addr = ChannelAddr::any(ChannelTransport::Tcp); + let listen_addr = ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)); let proc_forwarder = BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender)); @@ -1422,7 +1423,7 @@ mod tests { let _proc_actor_1 = ProcActor::bootstrap_for_proc( proc_1.clone(), world_id.clone(), - ChannelAddr::any(ChannelTransport::Tcp), + ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)), server_handle.local_addr().clone(), sup_ref.clone(), Duration::from_secs(120), @@ -1497,7 +1498,7 @@ mod tests { // Serve a system. let server_handle = System::serve( - ChannelAddr::any(ChannelTransport::Tcp), + ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)), Duration::from_secs(120), Duration::from_secs(120), ) @@ -1518,7 +1519,7 @@ mod tests { )); // Construct a proc forwarder in terms of the system sender. - let listen_addr = ChannelAddr::any(ChannelTransport::Tcp); + let listen_addr = ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)); let proc_forwarder = BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender)); @@ -1545,7 +1546,7 @@ mod tests { let _proc_actor_1 = ProcActor::bootstrap_for_proc( proc_1.clone(), world_id.clone(), - ChannelAddr::any(ChannelTransport::Tcp), + ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)), server_handle.local_addr().clone(), sup_ref.clone(), Duration::from_secs(120), @@ -1651,7 +1652,7 @@ mod tests { #[tokio::test] async fn test_update_address_book_cache() { let server_handle = System::serve( - ChannelAddr::any(ChannelTransport::Tcp), + ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)), Duration::from_secs(2), // supervision update timeout Duration::from_secs(2), // duration to evict an unhealthy world ) @@ -1705,7 +1706,7 @@ mod tests { actor_id: &ActorId, system_addr: &ChannelAddr, ) -> (ActorRef, ActorRef) { - let listen_addr = ChannelAddr::any(ChannelTransport::Tcp); + let listen_addr = 
ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)); let bootstrap = ProcActor::bootstrap( actor_id.proc_id().clone(), actor_id diff --git a/hyperactor_multiprocess/src/system.rs b/hyperactor_multiprocess/src/system.rs index f8e708a5c..79654662b 100644 --- a/hyperactor_multiprocess/src/system.rs +++ b/hyperactor_multiprocess/src/system.rs @@ -176,6 +176,7 @@ mod tests { use hyperactor::WorldId; use hyperactor::channel::ChannelAddr; use hyperactor::channel::ChannelTransport; + use hyperactor::channel::TcpMode; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; use hyperactor_telemetry::env::execution_id; @@ -825,7 +826,7 @@ mod tests { #[tokio::test] async fn test_channel_dial_count() { let system_handle = System::serve( - ChannelAddr::any(ChannelTransport::Tcp), + ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)), Duration::from_secs(10), Duration::from_secs(10), ) diff --git a/hyperactor_multiprocess/src/system_actor.rs b/hyperactor_multiprocess/src/system_actor.rs index 05465ff22..51bdfc05f 100644 --- a/hyperactor_multiprocess/src/system_actor.rs +++ b/hyperactor_multiprocess/src/system_actor.rs @@ -1849,6 +1849,7 @@ mod tests { use hyperactor::channel; use hyperactor::channel::ChannelTransport; use hyperactor::channel::Rx; + use hyperactor::channel::TcpMode; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; use hyperactor::data::Serialized; @@ -2194,7 +2195,7 @@ mod tests { // Serve a system. Undeliverable messages encountered by the // mailbox server are returned to the system actor. 
let server_handle = System::serve( - ChannelAddr::any(ChannelTransport::Tcp), + ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)), Duration::from_secs(2), // supervision update timeout Duration::from_secs(2), // duration to evict an unhealthy world ) @@ -2255,7 +2256,7 @@ mod tests { let _proc_actor_0 = ProcActor::bootstrap_for_proc( proc_0.clone(), world_id.clone(), - ChannelAddr::any(ChannelTransport::Tcp), + ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)), server_handle.local_addr().clone(), sup_ref.clone(), Duration::from_millis(300), // supervision update interval @@ -2272,7 +2273,7 @@ mod tests { let proc_actor_1 = ProcActor::bootstrap_for_proc( proc_1.clone(), world_id.clone(), - ChannelAddr::any(ChannelTransport::Tcp), + ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)), server_handle.local_addr().clone(), sup_ref.clone(), Duration::from_millis(300), // supervision update interval @@ -2348,7 +2349,7 @@ mod tests { #[tokio::test] async fn test_stop_fast() -> Result<()> { let server_handle = System::serve( - ChannelAddr::any(ChannelTransport::Tcp), + ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)), Duration::from_secs(2), // supervision update timeout Duration::from_secs(2), // duration to evict an unhealthy world ) diff --git a/monarch_hyperactor/src/channel.rs b/monarch_hyperactor/src/channel.rs index 4064f7959..27d3e38a2 100644 --- a/monarch_hyperactor/src/channel.rs +++ b/monarch_hyperactor/src/channel.rs @@ -11,6 +11,7 @@ use std::str::FromStr; use hyperactor::channel::ChannelAddr; use hyperactor::channel::ChannelTransport; use hyperactor::channel::MetaTlsAddr; +use hyperactor::channel::TcpMode; use hyperactor::channel::TlsMode; use pyo3::exceptions::PyRuntimeError; use pyo3::exceptions::PyValueError; @@ -24,7 +25,8 @@ use pyo3::prelude::*; )] #[derive(PartialEq, Clone, Copy, Debug)] pub enum PyChannelTransport { - Tcp, + TcpWithLocalhost, + TcpWithHostname, MetaTlsWithHostname, MetaTlsWithIpV6, Local, @@ -44,7 
+46,8 @@ impl TryFrom for PyChannelTransport { fn try_from(transport: ChannelTransport) -> PyResult { match transport { - ChannelTransport::Tcp => Ok(PyChannelTransport::Tcp), + ChannelTransport::Tcp(TcpMode::Localhost) => Ok(PyChannelTransport::TcpWithLocalhost), + ChannelTransport::Tcp(TcpMode::Hostname) => Ok(PyChannelTransport::TcpWithHostname), ChannelTransport::MetaTls(TlsMode::Hostname) => { Ok(PyChannelTransport::MetaTlsWithHostname) } @@ -111,7 +114,10 @@ impl PyChannelAddr { pub fn get_transport(&self) -> PyResult { let transport = self.inner.transport(); match transport { - ChannelTransport::Tcp => Ok(PyChannelTransport::Tcp), + ChannelTransport::Tcp(mode) => match mode { + TcpMode::Localhost => Ok(PyChannelTransport::TcpWithLocalhost), + TcpMode::Hostname => Ok(PyChannelTransport::TcpWithHostname), + }, ChannelTransport::MetaTls(mode) => match mode { TlsMode::Hostname => Ok(PyChannelTransport::MetaTlsWithHostname), TlsMode::IpV6 => Ok(PyChannelTransport::MetaTlsWithIpV6), @@ -130,7 +136,8 @@ impl PyChannelAddr { impl From for ChannelTransport { fn from(val: PyChannelTransport) -> Self { match val { - PyChannelTransport::Tcp => ChannelTransport::Tcp, + PyChannelTransport::TcpWithLocalhost => ChannelTransport::Tcp(TcpMode::Localhost), + PyChannelTransport::TcpWithHostname => ChannelTransport::Tcp(TcpMode::Hostname), PyChannelTransport::MetaTlsWithHostname => ChannelTransport::MetaTls(TlsMode::Hostname), PyChannelTransport::MetaTlsWithIpV6 => ChannelTransport::MetaTls(TlsMode::IpV6), PyChannelTransport::Local => ChannelTransport::Local, @@ -154,7 +161,8 @@ mod tests { fn test_channel_any_and_parse() -> PyResult<()> { // just make sure any() and parse() calls work for all transports for transport in [ - PyChannelTransport::Tcp, + PyChannelTransport::TcpWithLocalhost, + PyChannelTransport::TcpWithHostname, PyChannelTransport::Unix, PyChannelTransport::MetaTlsWithHostname, PyChannelTransport::MetaTlsWithIpV6, @@ -190,9 +198,13 @@ mod tests { #[test] fn 
test_channel_addr_get_transport() -> PyResult<()> { + assert_eq!( + PyChannelAddr::parse("tcp![::1]:26600")?.get_transport()?, + PyChannelTransport::TcpWithLocalhost, + ); assert_eq!( PyChannelAddr::parse("tcp![::]:26600")?.get_transport()?, - PyChannelTransport::Tcp + PyChannelTransport::TcpWithHostname, ); assert_eq!( PyChannelAddr::parse("metatls!devgpu001.pci.facebook.com:26600")?.get_transport()?, diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi index bfec3fa53..fea7a1f5f 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi @@ -9,7 +9,8 @@ from enum import Enum class ChannelTransport(Enum): - Tcp = "tcp" + TcpWithLocalhost = "tcp(localhost)" + TcpWithHostname = "tcp(hostname)" MetaTlsWithHostname = "metatls(hostname)" MetaTlsWithIpV6 = "metatls(ipv6)" Local = "local" diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 9481e9e02..3df1f09da 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -248,7 +248,7 @@ def enable_transport(transport: "ChannelTransport | str") -> None: """ if isinstance(transport, str): transport = { - "tcp": ChannelTransport.Tcp, + "tcp": ChannelTransport.TcpWithHostname, "ipc": ChannelTransport.Unix, "metatls": ChannelTransport.MetaTlsWithIpV6, }.get(transport) diff --git a/python/tests/test_allocator.py b/python/tests/test_allocator.py index ef650b4dc..1d342868a 100644 --- a/python/tests/test_allocator.py +++ b/python/tests/test_allocator.py @@ -595,8 +595,6 @@ async def test_torchx_remote_alloc_initializer_no_match_label_gt_1_meshes( AllocSpec(AllocConstraints(), host=1, gpu=1) ).initialized - # Skipping test temporarily due to blocking OSS CI TODO: @rusch T232884876 - @pytest.mark.oss_skip # pyre-ignore[56]: Pyre cannot infer the type of this pytest marker async def 
test_torchx_remote_alloc_initializer_no_match_label_1_mesh(self) -> None: server = ServerSpec( name=UNUSED, @@ -627,8 +625,6 @@ async def test_torchx_remote_alloc_initializer_no_match_label_1_mesh(self) -> No ) self.assert_computed_world_size(results, 4) # 1x4 mesh - # Skipping test temporarily due to blocking OSS CI TODO: @rusch T232884876 - @pytest.mark.oss_skip # pyre-ignore[56]: Pyre cannot infer the type of this pytest marker async def test_torchx_remote_alloc_initializer_with_match_label(self) -> None: server = ServerSpec( name=UNUSED, diff --git a/python/tests/test_config.py b/python/tests/test_config.py index b20f3a4d1..8d80c5de6 100644 --- a/python/tests/test_config.py +++ b/python/tests/test_config.py @@ -17,7 +17,8 @@ def test_get_set_transport() -> None: for transport in ( ChannelTransport.Unix, - ChannelTransport.Tcp, + ChannelTransport.TcpWithLocalhost, + ChannelTransport.TcpWithHostname, ChannelTransport.MetaTlsWithHostname, ): configure(default_transport=transport)