@@ -59,6 +59,7 @@ use std::{
5959 io,
6060 net:: { IpAddr , Ipv4Addr , Ipv6Addr , SocketAddr , TcpListener } ,
6161 pin:: Pin ,
62+ sync:: { Arc , RwLock } ,
6263 task:: { Context , Poll } ,
6364 time:: Duration ,
6465} ;
@@ -95,7 +96,7 @@ enum PortReuse {
9596 Enabled {
9697 /// The addresses and ports of the listening sockets
9798 /// registered as eligible for port reuse when dialing.
98- listen_addrs : HashSet < ( IpAddr , Port ) > ,
99+ listen_addrs : Arc < RwLock < HashSet < ( IpAddr , Port ) > > > ,
99100 } ,
100101}
101102
@@ -106,7 +107,10 @@ impl PortReuse {
106107 fn register ( & mut self , ip : IpAddr , port : Port ) {
107108 if let PortReuse :: Enabled { listen_addrs } = self {
108109 log:: trace!( "Registering for port reuse: {}:{}" , ip, port) ;
109- listen_addrs. insert ( ( ip, port) ) ;
110+ listen_addrs
111+ . write ( )
112+ . expect ( "`register()` and `unregister()` never panic while holding the lock" )
113+ . insert ( ( ip, port) ) ;
110114 }
111115 }
112116
@@ -116,7 +120,10 @@ impl PortReuse {
116120 fn unregister ( & mut self , ip : IpAddr , port : Port ) {
117121 if let PortReuse :: Enabled { listen_addrs } = self {
118122 log:: trace!( "Unregistering for port reuse: {}:{}" , ip, port) ;
119- listen_addrs. remove ( & ( ip, port) ) ;
123+ listen_addrs
124+ . write ( )
125+ . expect ( "`register()` and `unregister()` never panic while holding the lock" )
126+ . remove ( & ( ip, port) ) ;
120127 }
121128 }
122129
@@ -131,7 +138,11 @@ impl PortReuse {
131138 /// listening socket address is found.
132139 fn local_dial_addr ( & self , remote_ip : & IpAddr ) -> Option < SocketAddr > {
133140 if let PortReuse :: Enabled { listen_addrs } = self {
134- for ( ip, port) in listen_addrs. iter ( ) {
141+ for ( ip, port) in listen_addrs
142+ . read ( )
143+ . expect ( "`local_dial_addr` never panic while holding the lock" )
144+ . iter ( )
145+ {
135146 if ip. is_ipv4 ( ) == remote_ip. is_ipv4 ( )
136147 && ip. is_loopback ( ) == remote_ip. is_loopback ( )
137148 {
@@ -286,7 +297,7 @@ where
286297 pub fn port_reuse ( mut self , port_reuse : bool ) -> Self {
287298 self . port_reuse = if port_reuse {
288299 PortReuse :: Enabled {
289- listen_addrs : HashSet :: new ( ) ,
300+ listen_addrs : Arc :: new ( RwLock :: new ( HashSet :: new ( ) ) ) ,
290301 }
291302 } else {
292303 PortReuse :: Disabled
@@ -707,7 +718,7 @@ fn ip_to_multiaddr(ip: IpAddr, port: u16) -> Multiaddr {
707718#[ cfg( test) ]
708719mod tests {
709720 use super :: * ;
710- use futures:: channel:: mpsc;
721+ use futures:: channel:: { mpsc, oneshot } ;
711722
712723 #[ test]
713724 fn multiaddr_to_tcp_conversion ( ) {
@@ -900,15 +911,28 @@ mod tests {
900911 fn port_reuse_dialing ( ) {
901912 env_logger:: try_init ( ) . ok ( ) ;
902913
903- async fn listener < T : Provider > ( addr : Multiaddr , mut ready_tx : mpsc:: Sender < Multiaddr > ) {
914+ async fn listener < T : Provider > (
915+ addr : Multiaddr ,
916+ mut ready_tx : mpsc:: Sender < Multiaddr > ,
917+ port_reuse_rx : oneshot:: Receiver < Protocol < ' _ > > ,
918+ ) {
904919 let mut tcp = GenTcpConfig :: < T > :: new ( ) ;
905920 let mut listener = tcp. listen_on ( addr) . unwrap ( ) ;
906921 loop {
907922 match listener. next ( ) . await . unwrap ( ) . unwrap ( ) {
908923 ListenerEvent :: NewAddress ( listen_addr) => {
909924 ready_tx. send ( listen_addr) . await . ok ( ) ;
910925 }
911- ListenerEvent :: Upgrade { upgrade, .. } => {
926+ ListenerEvent :: Upgrade {
927+ upgrade,
928+ local_addr : _,
929+ mut remote_addr,
930+ } => {
931+ // Receive the dialer tcp port reuse
932+ let remote_port_reuse = port_reuse_rx. await . unwrap ( ) ;
933+ // And check it is the same as the remote port used for upgrade
934+ assert_eq ! ( remote_addr. pop( ) . unwrap( ) , remote_port_reuse) ;
935+
912936 let mut upgrade = upgrade. await . unwrap ( ) ;
913937 let mut buf = [ 0u8 ; 3 ] ;
914938 upgrade. read_exact ( & mut buf) . await . unwrap ( ) ;
@@ -921,12 +945,29 @@ mod tests {
921945 }
922946 }
923947
924- async fn dialer < T : Provider > ( addr : Multiaddr , mut ready_rx : mpsc:: Receiver < Multiaddr > ) {
948+ async fn dialer < T : Provider > (
949+ addr : Multiaddr ,
950+ mut ready_rx : mpsc:: Receiver < Multiaddr > ,
951+ port_reuse_tx : oneshot:: Sender < Protocol < ' _ > > ,
952+ ) {
925953 let dest_addr = ready_rx. next ( ) . await . unwrap ( ) ;
926954 let mut tcp = GenTcpConfig :: < T > :: new ( ) . port_reuse ( true ) ;
927955 let mut listener = tcp. clone ( ) . listen_on ( addr) . unwrap ( ) ;
928956 match listener. next ( ) . await . unwrap ( ) . unwrap ( ) {
929957 ListenerEvent :: NewAddress ( _) => {
958+ // Check that tcp and listener share the same port reuse SocketAddr
959+ let port_reuse_tcp = tcp. port_reuse . local_dial_addr ( & listener. listen_addr . ip ( ) ) ;
960+ let port_reuse_listener = listener
961+ . port_reuse
962+ . local_dial_addr ( & listener. listen_addr . ip ( ) ) ;
963+ assert ! ( port_reuse_tcp. is_some( ) ) ;
964+ assert_eq ! ( port_reuse_tcp, port_reuse_listener) ;
965+
966+ // Send the dialer tcp port reuse to the listener
967+ port_reuse_tx
968+ . send ( Protocol :: Tcp ( port_reuse_tcp. unwrap ( ) . port ( ) ) )
969+ . ok ( ) ;
970+
930971 // Obtain a future socket through dialing
931972 let mut socket = tcp. dial ( dest_addr) . unwrap ( ) . await . unwrap ( ) ;
932973 socket. write_all ( & [ 0x1 , 0x2 , 0x3 ] ) . await . unwrap ( ) ;
@@ -943,8 +984,9 @@ mod tests {
943984 #[ cfg( feature = "async-io" ) ]
944985 {
945986 let ( ready_tx, ready_rx) = mpsc:: channel ( 1 ) ;
946- let listener = listener :: < async_io:: Tcp > ( addr. clone ( ) , ready_tx) ;
947- let dialer = dialer :: < async_io:: Tcp > ( addr. clone ( ) , ready_rx) ;
987+ let ( port_reuse_tx, port_reuse_rx) = oneshot:: channel ( ) ;
988+ let listener = listener :: < async_io:: Tcp > ( addr. clone ( ) , ready_tx, port_reuse_rx) ;
989+ let dialer = dialer :: < async_io:: Tcp > ( addr. clone ( ) , ready_rx, port_reuse_tx) ;
948990 let listener = async_std:: task:: spawn ( listener) ;
949991 async_std:: task:: block_on ( dialer) ;
950992 async_std:: task:: block_on ( listener) ;
@@ -953,8 +995,9 @@ mod tests {
953995 #[ cfg( feature = "tokio" ) ]
954996 {
955997 let ( ready_tx, ready_rx) = mpsc:: channel ( 1 ) ;
956- let listener = listener :: < tokio:: Tcp > ( addr. clone ( ) , ready_tx) ;
957- let dialer = dialer :: < tokio:: Tcp > ( addr. clone ( ) , ready_rx) ;
998+ let ( port_reuse_tx, port_reuse_rx) = oneshot:: channel ( ) ;
999+ let listener = listener :: < tokio:: Tcp > ( addr. clone ( ) , ready_tx, port_reuse_rx) ;
1000+ let dialer = dialer :: < tokio:: Tcp > ( addr. clone ( ) , ready_rx, port_reuse_tx) ;
9581001 let rt = tokio_crate:: runtime:: Builder :: new_current_thread ( )
9591002 . enable_io ( )
9601003 . build ( )
@@ -979,6 +1022,15 @@ mod tests {
9791022 let mut listener1 = tcp. clone ( ) . listen_on ( addr) . unwrap ( ) ;
9801023 match listener1. next ( ) . await . unwrap ( ) . unwrap ( ) {
9811024 ListenerEvent :: NewAddress ( addr1) => {
1025+ // Check that tcp and listener share the same port reuse SocketAddr
1026+ let port_reuse_tcp =
1027+ tcp. port_reuse . local_dial_addr ( & listener1. listen_addr . ip ( ) ) ;
1028+ let port_reuse_listener1 = listener1
1029+ . port_reuse
1030+ . local_dial_addr ( & listener1. listen_addr . ip ( ) ) ;
1031+ assert ! ( port_reuse_tcp. is_some( ) ) ;
1032+ assert_eq ! ( port_reuse_tcp, port_reuse_listener1) ;
1033+
9821034 // Listen on the same address a second time.
9831035 let mut listener2 = tcp. clone ( ) . listen_on ( addr1. clone ( ) ) . unwrap ( ) ;
9841036 match listener2. next ( ) . await . unwrap ( ) . unwrap ( ) {
0 commit comments