Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion msg-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
#![cfg_attr(not(test), warn(unused_crate_dependencies))]

use std::time::SystemTime;
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
time::SystemTime,
};

use futures::future::BoxFuture;

Expand Down Expand Up @@ -33,3 +36,47 @@ pub mod constants {
pub const MiB: u32 = 1024 * KiB;
pub const GiB: u32 = 1024 * MiB;
}

/// Extension trait for `SocketAddr`.
pub trait SocketAddrExt: Sized {
/// Returns the unspecified IPv4 socket address, bound to port 0.
fn unspecified_v4() -> Self;

/// Returns the unspecified IPv6 socket address, bound to port 0.
fn unspecified_v6() -> Self;

/// Returns the unspecified socket address of the same family as `other`, bound to port 0.
fn as_unspecified(&self) -> Self;
}

impl SocketAddrExt for SocketAddr {
fn unspecified_v4() -> Self {
Self::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))
}

fn unspecified_v6() -> Self {
Self::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0))
}

fn as_unspecified(&self) -> Self {
match self {
Self::V4(_) => Self::unspecified_v4(),
Self::V6(_) => Self::unspecified_v6(),
}
}
}

/// Extension trait for IP addresses.
pub trait IpAddrExt: Sized {
/// Returns the localhost address of the same family as `other`.
fn as_localhost(&self) -> Self;
}

impl IpAddrExt for IpAddr {
fn as_localhost(&self) -> Self {
match self {
Self::V4(_) => Self::V4(Ipv4Addr::LOCALHOST),
Self::V6(_) => Self::V6(Ipv6Addr::LOCALHOST),
}
}
}
10 changes: 4 additions & 6 deletions msg-socket/src/sub/socket.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{
collections::HashSet,
net::{IpAddr, Ipv4Addr, SocketAddr},
net::SocketAddr,
path::PathBuf,
pin::Pin,
sync::Arc,
Expand All @@ -14,7 +14,7 @@ use tokio::{
sync::mpsc,
};

use msg_common::JoinMap;
use msg_common::{IpAddrExt, JoinMap};
use msg_transport::{Address, Transport};

// ADDED: Import the specific SubStats struct for the API
Expand Down Expand Up @@ -61,8 +61,7 @@ where
// Some transport implementations (e.g. Quinn) can't dial an unspecified
// IP address, so replace it with localhost.
if endpoint.ip().is_unspecified() {
// TODO: support IPv6
endpoint.set_ip(IpAddr::V4(Ipv4Addr::LOCALHOST));
endpoint.set_ip(endpoint.ip().as_localhost());
}

self.connect_inner(endpoint).await
Expand All @@ -76,8 +75,7 @@ where
// Some transport implementations (e.g. Quinn) can't dial an unspecified
// IP address, so replace it with localhost.
if endpoint.ip().is_unspecified() {
// TODO: support IPv6
endpoint.set_ip(IpAddr::V4(Ipv4Addr::LOCALHOST));
endpoint.set_ip(endpoint.ip().as_localhost());
}

self.try_connect_inner(endpoint)
Expand Down
8 changes: 4 additions & 4 deletions msg-transport/src/quic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use tracing::{debug, error};

use crate::{Acceptor, Transport, TransportExt};

use msg_common::async_error;
use msg_common::{SocketAddrExt, async_error};

mod config;
pub use config::{Config, ConfigBuilder};
Expand Down Expand Up @@ -67,10 +67,10 @@ impl Quic {
/// `addr` is given, the endpoint will be bound to the default address.
fn new_endpoint(
&self,
addr: Option<SocketAddr>,
addr: SocketAddr,
server_config: Option<quinn::ServerConfig>,
) -> Result<quinn::Endpoint, Error> {
let socket = UdpSocket::bind(addr.unwrap_or(SocketAddr::from(([0, 0, 0, 0], 0))))?;
let socket = UdpSocket::bind(addr)?;

let endpoint = quinn::Endpoint::new(
self.config.endpoint_config.clone(),
Expand Down Expand Up @@ -113,7 +113,7 @@ impl Transport<SocketAddr> for Quic {
let endpoint = if let Some(endpoint) = self.endpoint.clone() {
endpoint
} else {
let Ok(mut endpoint) = self.new_endpoint(None, None) else {
let Ok(mut endpoint) = self.new_endpoint(addr.as_unspecified(), None) else {
return async_error(Error::ClosedEndpoint);
};

Expand Down
Loading