Skip to content

Commit 2804756

Browse files
transports/tcp: Fix port reuse using Arc<RwLock> for listen_addrs (#2670)
Fix bug introduced in 2ad905f. Make sure set of listen addresses is shared between GenTcpConfig and TcpListenStream.
1 parent 4aa84bf commit 2804756

File tree

1 file changed

+65
-13
lines changed

1 file changed

+65
-13
lines changed

transports/tcp/src/lib.rs

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ use std::{
5959
io,
6060
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener},
6161
pin::Pin,
62+
sync::{Arc, RwLock},
6263
task::{Context, Poll},
6364
time::Duration,
6465
};
@@ -95,7 +96,7 @@ enum PortReuse {
9596
Enabled {
9697
/// The addresses and ports of the listening sockets
9798
/// registered as eligible for port reuse when dialing.
98-
listen_addrs: HashSet<(IpAddr, Port)>,
99+
listen_addrs: Arc<RwLock<HashSet<(IpAddr, Port)>>>,
99100
},
100101
}
101102

@@ -106,7 +107,10 @@ impl PortReuse {
106107
fn register(&mut self, ip: IpAddr, port: Port) {
107108
if let PortReuse::Enabled { listen_addrs } = self {
108109
log::trace!("Registering for port reuse: {}:{}", ip, port);
109-
listen_addrs.insert((ip, port));
110+
listen_addrs
111+
.write()
112+
.expect("`register()` and `unregister()` never panic while holding the lock")
113+
.insert((ip, port));
110114
}
111115
}
112116

@@ -116,7 +120,10 @@ impl PortReuse {
116120
fn unregister(&mut self, ip: IpAddr, port: Port) {
117121
if let PortReuse::Enabled { listen_addrs } = self {
118122
log::trace!("Unregistering for port reuse: {}:{}", ip, port);
119-
listen_addrs.remove(&(ip, port));
123+
listen_addrs
124+
.write()
125+
.expect("`register()` and `unregister()` never panic while holding the lock")
126+
.remove(&(ip, port));
120127
}
121128
}
122129

@@ -131,7 +138,11 @@ impl PortReuse {
131138
/// listening socket address is found.
132139
fn local_dial_addr(&self, remote_ip: &IpAddr) -> Option<SocketAddr> {
133140
if let PortReuse::Enabled { listen_addrs } = self {
134-
for (ip, port) in listen_addrs.iter() {
141+
for (ip, port) in listen_addrs
142+
.read()
143+
.expect("`local_dial_addr` never panic while holding the lock")
144+
.iter()
145+
{
135146
if ip.is_ipv4() == remote_ip.is_ipv4()
136147
&& ip.is_loopback() == remote_ip.is_loopback()
137148
{
@@ -286,7 +297,7 @@ where
286297
pub fn port_reuse(mut self, port_reuse: bool) -> Self {
287298
self.port_reuse = if port_reuse {
288299
PortReuse::Enabled {
289-
listen_addrs: HashSet::new(),
300+
listen_addrs: Arc::new(RwLock::new(HashSet::new())),
290301
}
291302
} else {
292303
PortReuse::Disabled
@@ -707,7 +718,7 @@ fn ip_to_multiaddr(ip: IpAddr, port: u16) -> Multiaddr {
707718
#[cfg(test)]
708719
mod tests {
709720
use super::*;
710-
use futures::channel::mpsc;
721+
use futures::channel::{mpsc, oneshot};
711722

712723
#[test]
713724
fn multiaddr_to_tcp_conversion() {
@@ -900,15 +911,28 @@ mod tests {
900911
fn port_reuse_dialing() {
901912
env_logger::try_init().ok();
902913

903-
async fn listener<T: Provider>(addr: Multiaddr, mut ready_tx: mpsc::Sender<Multiaddr>) {
914+
async fn listener<T: Provider>(
915+
addr: Multiaddr,
916+
mut ready_tx: mpsc::Sender<Multiaddr>,
917+
port_reuse_rx: oneshot::Receiver<Protocol<'_>>,
918+
) {
904919
let mut tcp = GenTcpConfig::<T>::new();
905920
let mut listener = tcp.listen_on(addr).unwrap();
906921
loop {
907922
match listener.next().await.unwrap().unwrap() {
908923
ListenerEvent::NewAddress(listen_addr) => {
909924
ready_tx.send(listen_addr).await.ok();
910925
}
911-
ListenerEvent::Upgrade { upgrade, .. } => {
926+
ListenerEvent::Upgrade {
927+
upgrade,
928+
local_addr: _,
929+
mut remote_addr,
930+
} => {
931+
// Receive the dialer tcp port reuse
932+
let remote_port_reuse = port_reuse_rx.await.unwrap();
933+
// And check it is the same as the remote port used for upgrade
934+
assert_eq!(remote_addr.pop().unwrap(), remote_port_reuse);
935+
912936
let mut upgrade = upgrade.await.unwrap();
913937
let mut buf = [0u8; 3];
914938
upgrade.read_exact(&mut buf).await.unwrap();
@@ -921,12 +945,29 @@ mod tests {
921945
}
922946
}
923947

924-
async fn dialer<T: Provider>(addr: Multiaddr, mut ready_rx: mpsc::Receiver<Multiaddr>) {
948+
async fn dialer<T: Provider>(
949+
addr: Multiaddr,
950+
mut ready_rx: mpsc::Receiver<Multiaddr>,
951+
port_reuse_tx: oneshot::Sender<Protocol<'_>>,
952+
) {
925953
let dest_addr = ready_rx.next().await.unwrap();
926954
let mut tcp = GenTcpConfig::<T>::new().port_reuse(true);
927955
let mut listener = tcp.clone().listen_on(addr).unwrap();
928956
match listener.next().await.unwrap().unwrap() {
929957
ListenerEvent::NewAddress(_) => {
958+
// Check that tcp and listener share the same port reuse SocketAddr
959+
let port_reuse_tcp = tcp.port_reuse.local_dial_addr(&listener.listen_addr.ip());
960+
let port_reuse_listener = listener
961+
.port_reuse
962+
.local_dial_addr(&listener.listen_addr.ip());
963+
assert!(port_reuse_tcp.is_some());
964+
assert_eq!(port_reuse_tcp, port_reuse_listener);
965+
966+
// Send the dialer tcp port reuse to the listener
967+
port_reuse_tx
968+
.send(Protocol::Tcp(port_reuse_tcp.unwrap().port()))
969+
.ok();
970+
930971
// Obtain a future socket through dialing
931972
let mut socket = tcp.dial(dest_addr).unwrap().await.unwrap();
932973
socket.write_all(&[0x1, 0x2, 0x3]).await.unwrap();
@@ -943,8 +984,9 @@ mod tests {
943984
#[cfg(feature = "async-io")]
944985
{
945986
let (ready_tx, ready_rx) = mpsc::channel(1);
946-
let listener = listener::<async_io::Tcp>(addr.clone(), ready_tx);
947-
let dialer = dialer::<async_io::Tcp>(addr.clone(), ready_rx);
987+
let (port_reuse_tx, port_reuse_rx) = oneshot::channel();
988+
let listener = listener::<async_io::Tcp>(addr.clone(), ready_tx, port_reuse_rx);
989+
let dialer = dialer::<async_io::Tcp>(addr.clone(), ready_rx, port_reuse_tx);
948990
let listener = async_std::task::spawn(listener);
949991
async_std::task::block_on(dialer);
950992
async_std::task::block_on(listener);
@@ -953,8 +995,9 @@ mod tests {
953995
#[cfg(feature = "tokio")]
954996
{
955997
let (ready_tx, ready_rx) = mpsc::channel(1);
956-
let listener = listener::<tokio::Tcp>(addr.clone(), ready_tx);
957-
let dialer = dialer::<tokio::Tcp>(addr.clone(), ready_rx);
998+
let (port_reuse_tx, port_reuse_rx) = oneshot::channel();
999+
let listener = listener::<tokio::Tcp>(addr.clone(), ready_tx, port_reuse_rx);
1000+
let dialer = dialer::<tokio::Tcp>(addr.clone(), ready_rx, port_reuse_tx);
9581001
let rt = tokio_crate::runtime::Builder::new_current_thread()
9591002
.enable_io()
9601003
.build()
@@ -979,6 +1022,15 @@ mod tests {
9791022
let mut listener1 = tcp.clone().listen_on(addr).unwrap();
9801023
match listener1.next().await.unwrap().unwrap() {
9811024
ListenerEvent::NewAddress(addr1) => {
1025+
// Check that tcp and listener share the same port reuse SocketAddr
1026+
let port_reuse_tcp =
1027+
tcp.port_reuse.local_dial_addr(&listener1.listen_addr.ip());
1028+
let port_reuse_listener1 = listener1
1029+
.port_reuse
1030+
.local_dial_addr(&listener1.listen_addr.ip());
1031+
assert!(port_reuse_tcp.is_some());
1032+
assert_eq!(port_reuse_tcp, port_reuse_listener1);
1033+
9821034
// Listen on the same address a second time.
9831035
let mut listener2 = tcp.clone().listen_on(addr1.clone()).unwrap();
9841036
match listener2.next().await.unwrap().unwrap() {

0 commit comments

Comments
 (0)