Skip to content

Commit ca51ac0

Browse files
mariusaemeta-codesync[bot]
authored andcommitted
support socket addresses in metatls (#1397)
Summary: Pull Request resolved: #1397 Fix up address handling in general. This will also let us support wildcard addresses. ghstack-source-id: 313884234 exported-using-ghexport Reviewed By: zdevito Differential Revision: D83689820 fbshipit-source-id: bef02792f3207acb2311a5c795512e5687bd9cd2
1 parent 7601419 commit ca51ac0

File tree

5 files changed

+211
-55
lines changed

5 files changed

+211
-55
lines changed

hyper/src/utils/system_address.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use std::str::FromStr;
1010

1111
use anyhow;
1212
use hyperactor::channel::ChannelAddr;
13+
use hyperactor::channel::MetaTlsAddr;
1314

1415
/// Extended type to represent a system address which can be a ChannelAdd or a MAST job name.
1516
#[derive(Clone, Debug)]
@@ -62,7 +63,10 @@ async fn parse_system_address_or_mast_job(address: &str) -> Result<ChannelAddr,
6263
let (host, port) = SMCClient::new(fbinit::expect_init(), smc_tier)?
6364
.get_system_address()
6465
.await?;
65-
let channel_address = ChannelAddr::MetaTls(canonicalize_hostname(&host), port);
66+
let channel_address = ChannelAddr::MetaTls(MetaTlsAddr::Host {
67+
hostname: canonicalize_hostname(&host),
68+
port,
69+
});
6670
Ok(channel_address)
6771
}
6872
}

hyperactor/src/channel.rs

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use std::os::linux::net::SocketAddrExt;
1919
use std::str::FromStr;
2020

2121
use async_trait::async_trait;
22+
use enum_as_inner::EnumAsInner;
2223
use lazy_static::lazy_static;
2324
use local_ip_address::local_ipv6;
2425
use serde::Deserialize;
@@ -256,6 +257,60 @@ pub enum TlsMode {
256257
// TODO: consider adding IpV4 support.
257258
}
258259

260+
/// Address format for MetaTls channels. Supports both hostname/port pairs
261+
/// (required for clients for host identity) and direct socket addresses
262+
/// (allowed for servers).
263+
#[derive(
264+
Clone,
265+
Debug,
266+
PartialEq,
267+
Eq,
268+
Hash,
269+
Serialize,
270+
Deserialize,
271+
Ord,
272+
PartialOrd,
273+
EnumAsInner
274+
)]
275+
pub enum MetaTlsAddr {
276+
/// Hostname and port pair. Required for clients to establish host identity.
277+
Host {
278+
/// The hostname to connect to.
279+
hostname: Hostname,
280+
/// The port to connect to.
281+
port: Port,
282+
},
283+
/// Direct socket address. Allowed for servers.
284+
Socket(SocketAddr),
285+
}
286+
287+
impl MetaTlsAddr {
288+
/// Returns the port number for this address.
289+
pub fn port(&self) -> Port {
290+
match self {
291+
Self::Host { port, .. } => *port,
292+
Self::Socket(addr) => addr.port(),
293+
}
294+
}
295+
296+
/// Returns the hostname if this is a Host variant, None otherwise.
297+
pub fn hostname(&self) -> Option<&str> {
298+
match self {
299+
Self::Host { hostname, .. } => Some(hostname),
300+
Self::Socket(_) => None,
301+
}
302+
}
303+
}
304+
305+
impl fmt::Display for MetaTlsAddr {
306+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
307+
match self {
308+
Self::Host { hostname, port } => write!(f, "{}:{}", hostname, port),
309+
Self::Socket(addr) => write!(f, "{}", addr),
310+
}
311+
}
312+
}
313+
259314
/// Types of channel transports.
260315
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Named)]
261316
pub enum ChannelTransport {
@@ -402,8 +457,9 @@ pub enum ChannelAddr {
402457
Tcp(SocketAddr),
403458

404459
/// An address to establish TCP channels with TLS support within Meta.
405-
/// Composed of hostname and port.
406-
MetaTls(Hostname, Port),
460+
/// Supports both hostname/port pairs (required for clients) and
461+
/// socket addresses (allowed for servers).
462+
MetaTls(MetaTlsAddr),
407463

408464
/// Local addresses are registered in-process and given an integral
409465
/// index.
@@ -471,7 +527,10 @@ impl ChannelAddr {
471527
.and_then(|addr| addr.to_string().parse().ok())
472528
.expect("failed to retrieve ipv6 address"),
473529
};
474-
Self::MetaTls(host_address, 0)
530+
Self::MetaTls(MetaTlsAddr::Host {
531+
hostname: host_address,
532+
port: 0,
533+
})
475534
}
476535
ChannelTransport::Local => Self::Local(0),
477536
ChannelTransport::Sim(transport) => sim::any(*transport),
@@ -484,12 +543,16 @@ impl ChannelAddr {
484543
pub fn transport(&self) -> ChannelTransport {
485544
match self {
486545
Self::Tcp(_) => ChannelTransport::Tcp,
487-
Self::MetaTls(address, _) => match address.parse::<IpAddr>() {
488-
Ok(ip) => match ip {
546+
Self::MetaTls(addr) => match addr {
547+
MetaTlsAddr::Host { hostname, .. } => match hostname.parse::<IpAddr>() {
548+
Ok(IpAddr::V6(_)) => ChannelTransport::MetaTls(TlsMode::IpV6),
549+
Ok(IpAddr::V4(_)) => ChannelTransport::MetaTls(TlsMode::Hostname),
550+
Err(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
551+
},
552+
MetaTlsAddr::Socket(socket_addr) => match socket_addr.ip() {
489553
IpAddr::V6(_) => ChannelTransport::MetaTls(TlsMode::IpV6),
490554
IpAddr::V4(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
491555
},
492-
Err(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
493556
},
494557
Self::Local(_) => ChannelTransport::Local,
495558
Self::Sim(addr) => ChannelTransport::Sim(Box::new(addr.transport())),
@@ -502,7 +565,7 @@ impl fmt::Display for ChannelAddr {
502565
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
503566
match self {
504567
Self::Tcp(addr) => write!(f, "tcp:{}", addr),
505-
Self::MetaTls(hostname, port) => write!(f, "metatls:{}:{}", hostname, port),
568+
Self::MetaTls(addr) => write!(f, "metatls:{}", addr),
506569
Self::Local(index) => write!(f, "local:{}", index),
507570
Self::Sim(sim_addr) => write!(f, "sim:{}", sim_addr),
508571
Self::Unix(addr) => write!(f, "unix:{}", addr),
@@ -630,7 +693,7 @@ pub fn dial<M: RemoteMessage>(addr: ChannelAddr) -> Result<ChannelTx<M>, Channel
630693
let inner = match addr {
631694
ChannelAddr::Local(port) => ChannelTxKind::Local(local::dial(port)?),
632695
ChannelAddr::Tcp(addr) => ChannelTxKind::Tcp(net::tcp::dial(addr)),
633-
ChannelAddr::MetaTls(host, port) => ChannelTxKind::MetaTls(net::meta::dial(host, port)),
696+
ChannelAddr::MetaTls(meta_addr) => ChannelTxKind::MetaTls(net::meta::dial(meta_addr)?),
634697
ChannelAddr::Sim(sim_addr) => ChannelTxKind::Sim(sim::dial::<M>(sim_addr)?),
635698
ChannelAddr::Unix(path) => ChannelTxKind::Unix(net::unix::dial(path)),
636699
};
@@ -649,8 +712,8 @@ pub fn serve<M: RemoteMessage>(
649712
let (addr, rx) = net::tcp::serve::<M>(addr)?;
650713
Ok((addr, ChannelRxKind::Tcp(rx)))
651714
}
652-
ChannelAddr::MetaTls(hostname, port) => {
653-
let (addr, rx) = net::meta::serve::<M>(hostname, port)?;
715+
ChannelAddr::MetaTls(meta_addr) => {
716+
let (addr, rx) = net::meta::serve::<M>(meta_addr)?;
654717
Ok((addr, ChannelRxKind::MetaTls(rx)))
655718
}
656719
ChannelAddr::Unix(path) => {

hyperactor/src/channel/net.rs

Lines changed: 123 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,6 +1175,8 @@ pub enum ClientError {
11751175
Io(ChannelAddr, std::io::Error),
11761176
#[error("send {0}: serialize: {1}")]
11771177
Serialize(ChannelAddr, bincode::ErrorKind),
1178+
#[error("invalid address: {0}")]
1179+
InvalidAddress(String),
11781180
}
11791181

11801182
#[derive(EnumAsInner)]
@@ -2175,14 +2177,23 @@ pub(crate) mod meta {
21752177

21762178
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`.
21772179
pub(crate) fn parse(addr_string: &str) -> Result<ChannelAddr, ChannelError> {
2180+
// Try to parse as a socket address first
2181+
if let Ok(socket_addr) = addr_string.parse::<SocketAddr>() {
2182+
return Ok(ChannelAddr::MetaTls(MetaTlsAddr::Socket(socket_addr)));
2183+
}
2184+
2185+
// Otherwise, parse as hostname:port
21782186
// use right split to allow for ipv6 addresses where ":" is expected.
21792187
let parts = addr_string.rsplit_once(":");
21802188
match parts {
21812189
Some((hostname, port_str)) => {
21822190
let Ok(port) = port_str.parse() else {
21832191
return Err(ChannelError::InvalidAddress(addr_string.to_string()));
21842192
};
2185-
Ok(ChannelAddr::MetaTls(hostname.to_string(), port))
2193+
Ok(ChannelAddr::MetaTls(MetaTlsAddr::Host {
2194+
hostname: hostname.to_string(),
2195+
port,
2196+
}))
21862197
}
21872198
_ => Err(ChannelError::InvalidAddress(addr_string.to_string())),
21882199
}
@@ -2314,7 +2325,10 @@ pub(crate) mod meta {
23142325
type Stream = TlsStream<TcpStream>;
23152326

23162327
fn dest(&self) -> ChannelAddr {
2317-
ChannelAddr::MetaTls(self.hostname.clone(), self.port)
2328+
ChannelAddr::MetaTls(MetaTlsAddr::Host {
2329+
hostname: self.hostname.clone(),
2330+
port: self.port,
2331+
})
23182332
}
23192333

23202334
async fn connect(&self) -> Result<Self::Stream, ClientError> {
@@ -2353,42 +2367,93 @@ pub(crate) mod meta {
23532367
}
23542368
}
23552369

2356-
pub fn dial<M: RemoteMessage>(hostname: Hostname, port: Port) -> NetTx<M> {
2357-
NetTx::new(MetaLink { hostname, port })
2370+
pub fn dial<M: RemoteMessage>(addr: MetaTlsAddr) -> Result<NetTx<M>, ClientError> {
2371+
match addr {
2372+
MetaTlsAddr::Host { hostname, port } => Ok(NetTx::new(MetaLink { hostname, port })),
2373+
MetaTlsAddr::Socket(_) => Err(ClientError::InvalidAddress(
2374+
"MetaTls clients require hostname/port for host identity, not socket addresses"
2375+
.to_string(),
2376+
)),
2377+
}
23582378
}
23592379

2360-
/// Serve the given address with hostname and port. If port 0 is provided,
2361-
/// dynamic port will be resolved and is available on the returned ServerHandle.
2380+
/// Serve the given address. If port 0 is provided in a Host address,
2381+
/// a dynamic port will be resolved and is available in the returned ChannelAddr.
2382+
/// For Host addresses, binds to all resolved socket addresses.
23622383
pub fn serve<M: RemoteMessage>(
2363-
hostname: Hostname,
2364-
port: Port,
2384+
addr: MetaTlsAddr,
23652385
) -> Result<(ChannelAddr, NetRx<M>), ServerError> {
2366-
let mut addrs = (hostname.as_ref(), port).to_socket_addrs().map_err(|err| {
2367-
ServerError::Resolve(ChannelAddr::MetaTls(hostname.clone(), port), err)
2368-
})?;
2369-
let addr = addrs.next().ok_or(ServerError::Resolve(
2370-
ChannelAddr::MetaTls(hostname.clone(), port),
2371-
io::Error::other("no available socket addr"),
2372-
))?;
2373-
let channel_addr = ChannelAddr::MetaTls(hostname.clone(), port);
2386+
match addr {
2387+
MetaTlsAddr::Host { hostname, port } => {
2388+
// Resolve all addresses for the hostname
2389+
let addrs: Vec<SocketAddr> = (hostname.as_ref(), port)
2390+
.to_socket_addrs()
2391+
.map_err(|err| {
2392+
ServerError::Resolve(
2393+
ChannelAddr::MetaTls(MetaTlsAddr::Host {
2394+
hostname: hostname.clone(),
2395+
port,
2396+
}),
2397+
err,
2398+
)
2399+
})?
2400+
.collect();
2401+
2402+
if addrs.is_empty() {
2403+
return Err(ServerError::Resolve(
2404+
ChannelAddr::MetaTls(MetaTlsAddr::Host { hostname, port }),
2405+
io::Error::other("no available socket addr"),
2406+
));
2407+
}
23742408

2375-
// Go by way of a std listener to avoid making this function async.
2376-
let std_listener = std::net::TcpListener::bind(addr)
2377-
.map_err(|err| ServerError::Listen(channel_addr.clone(), err))?;
2378-
std_listener
2379-
.set_nonblocking(true)
2380-
.map_err(|e| ServerError::Listen(channel_addr.clone(), e))?;
2381-
let listener = TcpListener::from_std(std_listener)
2382-
.map_err(|e| ServerError::Listen(channel_addr.clone(), e))?;
2409+
let channel_addr = ChannelAddr::MetaTls(MetaTlsAddr::Host {
2410+
hostname: hostname.clone(),
2411+
port,
2412+
});
23832413

2384-
let local_addr = listener
2385-
.local_addr()
2386-
.map_err(|err| ServerError::Resolve(channel_addr, err))?;
2387-
super::serve(
2388-
listener,
2389-
ChannelAddr::MetaTls(hostname, local_addr.port()),
2390-
true,
2391-
)
2414+
// Bind to all resolved addresses
2415+
let std_listener = std::net::TcpListener::bind(&addrs[..])
2416+
.map_err(|err| ServerError::Listen(channel_addr.clone(), err))?;
2417+
std_listener
2418+
.set_nonblocking(true)
2419+
.map_err(|e| ServerError::Listen(channel_addr.clone(), e))?;
2420+
let listener = TcpListener::from_std(std_listener)
2421+
.map_err(|e| ServerError::Listen(channel_addr.clone(), e))?;
2422+
2423+
let local_addr = listener
2424+
.local_addr()
2425+
.map_err(|err| ServerError::Resolve(channel_addr, err))?;
2426+
super::serve(
2427+
listener,
2428+
ChannelAddr::MetaTls(MetaTlsAddr::Host {
2429+
hostname,
2430+
port: local_addr.port(),
2431+
}),
2432+
true,
2433+
)
2434+
}
2435+
MetaTlsAddr::Socket(socket_addr) => {
2436+
let channel_addr = ChannelAddr::MetaTls(MetaTlsAddr::Socket(socket_addr));
2437+
2438+
// Bind directly to the socket address
2439+
let std_listener = std::net::TcpListener::bind(socket_addr)
2440+
.map_err(|err| ServerError::Listen(channel_addr.clone(), err))?;
2441+
std_listener
2442+
.set_nonblocking(true)
2443+
.map_err(|e| ServerError::Listen(channel_addr.clone(), e))?;
2444+
let listener = TcpListener::from_std(std_listener)
2445+
.map_err(|e| ServerError::Listen(channel_addr.clone(), e))?;
2446+
2447+
let local_addr = listener
2448+
.local_addr()
2449+
.map_err(|err| ServerError::Resolve(channel_addr, err))?;
2450+
super::serve(
2451+
listener,
2452+
ChannelAddr::MetaTls(MetaTlsAddr::Socket(local_addr)),
2453+
true,
2454+
)
2455+
}
2456+
}
23922457
}
23932458
}
23942459

@@ -2569,11 +2634,11 @@ mod tests {
25692634
#[tokio::test]
25702635
async fn test_meta_tls_basic() {
25712636
let addr = ChannelAddr::any(ChannelTransport::MetaTls(TlsMode::IpV6));
2572-
let (hostname, port) = match addr {
2573-
ChannelAddr::MetaTls(hostname, port) => (hostname, port),
2574-
_ => ("".to_string(), 0),
2637+
let meta_addr = match addr {
2638+
ChannelAddr::MetaTls(meta_addr) => meta_addr,
2639+
_ => panic!("expected MetaTls address"),
25752640
};
2576-
let (local_addr, mut rx) = net::meta::serve::<u64>(hostname, port).unwrap();
2641+
let (local_addr, mut rx) = net::meta::serve::<u64>(meta_addr).unwrap();
25772642
{
25782643
let tx = dial::<u64>(local_addr.clone()).unwrap();
25792644
tx.try_post(123, unused_return_channel()).unwrap();
@@ -3736,17 +3801,35 @@ mod tests {
37363801
fn test_metatls_parsing() {
37373802
// host:port
37383803
let channel: ChannelAddr = "metatls!localhost:1234".parse().unwrap();
3739-
assert_eq!(channel, ChannelAddr::MetaTls("localhost".to_string(), 1234));
3740-
// ipv4:port
3804+
assert_eq!(
3805+
channel,
3806+
ChannelAddr::MetaTls(MetaTlsAddr::Host {
3807+
hostname: "localhost".to_string(),
3808+
port: 1234
3809+
})
3810+
);
3811+
// ipv4:port - can be parsed as hostname or socket address
37413812
let channel: ChannelAddr = "metatls!1.2.3.4:1234".parse().unwrap();
3742-
assert_eq!(channel, ChannelAddr::MetaTls("1.2.3.4".to_string(), 1234));
3813+
assert_eq!(
3814+
channel,
3815+
ChannelAddr::MetaTls(MetaTlsAddr::Socket("1.2.3.4:1234".parse().unwrap()))
3816+
);
37433817
// ipv6:port
37443818
let channel: ChannelAddr = "metatls!2401:db00:33c:6902:face:0:2a2:0:1234"
37453819
.parse()
37463820
.unwrap();
37473821
assert_eq!(
37483822
channel,
3749-
ChannelAddr::MetaTls("2401:db00:33c:6902:face:0:2a2:0".to_string(), 1234)
3823+
ChannelAddr::MetaTls(MetaTlsAddr::Host {
3824+
hostname: "2401:db00:33c:6902:face:0:2a2:0".to_string(),
3825+
port: 1234
3826+
})
3827+
);
3828+
3829+
let channel: ChannelAddr = "metatls![::]:1234".parse().unwrap();
3830+
assert_eq!(
3831+
channel,
3832+
ChannelAddr::MetaTls(MetaTlsAddr::Socket("[::]:1234".parse().unwrap()))
37503833
);
37513834
}
37523835

0 commit comments

Comments
 (0)