@@ -1891,15 +1891,12 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
18911891			let  flush_read_disabled = self . gossip_processing_backlog_lifted . swap ( false ,  Ordering :: Relaxed ) ; 
18921892
18931893			let  mut  peers_to_disconnect = HashMap :: new ( ) ; 
1894- 			let  mut  events_generated = self . message_handler . chan_handler . get_and_clear_pending_msg_events ( ) ; 
1895- 			events_generated. append ( & mut  self . message_handler . route_handler . get_and_clear_pending_msg_events ( ) ) ; 
1896- 
18971894			{ 
1898- 				// TODO: There are some DoS attacks here where you can flood someone's outbound send 
1899- 				// buffer by doing things like announcing channels on another node. We should be willing to 
1900- 				// drop optional-ish messages when send buffers get full! 
1901- 
19021895				let  peers_lock = self . peers . read ( ) . unwrap ( ) ; 
1896+ 
1897+ 				let  mut  events_generated = self . message_handler . chan_handler . get_and_clear_pending_msg_events ( ) ; 
1898+ 				events_generated. append ( & mut  self . message_handler . route_handler . get_and_clear_pending_msg_events ( ) ) ; 
1899+ 
19031900				let  peers = & * peers_lock; 
19041901				macro_rules!  get_peer_for_forwarding { 
19051902					( $node_id:  expr)  => { 
@@ -2520,12 +2517,11 @@ mod tests {
25202517
25212518	use  crate :: prelude:: * ; 
25222519	use  crate :: sync:: { Arc ,  Mutex } ; 
2523- 	use  core:: convert:: Infallible ; 
2524- 	use  core:: sync:: atomic:: { AtomicBool ,  Ordering } ; 
2520+ 	use  core:: sync:: atomic:: { AtomicBool ,  AtomicUsize ,  Ordering } ; 
25252521
25262522	#[ derive( Clone ) ]  
25272523	struct  FileDescriptor  { 
2528- 		fd :  u16 , 
2524+ 		fd :  u32 , 
25292525		outbound_data :  Arc < Mutex < Vec < u8 > > > , 
25302526		disconnect :  Arc < AtomicBool > , 
25312527	} 
@@ -2560,24 +2556,44 @@ mod tests {
25602556
25612557	struct  TestCustomMessageHandler  { 
25622558		features :  InitFeatures , 
2559+ 		peer_counter :  AtomicUsize , 
2560+ 		send_messages :  Option < PublicKey > , 
2561+ 	} 
2562+ 
2563+ 	impl  crate :: ln:: wire:: Type  for  u64  { 
2564+ 		fn  type_id ( & self )  -> u16  {  4242  } 
25632565	} 
25642566
25652567	impl  wire:: CustomMessageReader  for  TestCustomMessageHandler  { 
2566- 		type  CustomMessage  = Infallible ; 
2567- 		fn  read < R :  io:: Read > ( & self ,  _:  u16 ,  _:  & mut  R )  -> Result < Option < Self :: CustomMessage > ,  msgs:: DecodeError >  { 
2568- 			Ok ( None ) 
2568+ 		type  CustomMessage  = u64 ; 
2569+ 		fn  read < R :  io:: Read > ( & self ,  msg_type :  u16 ,  reader :  & mut  R )  -> Result < Option < Self :: CustomMessage > ,  msgs:: DecodeError >  { 
2570+ 			assert ! ( self . send_messages. is_some( ) ) ; 
2571+ 			assert_eq ! ( msg_type,  4242 ) ; 
2572+ 			let  mut  msg = [ 0u8 ;  8 ] ; 
2573+ 			reader. read_exact ( & mut  msg) . unwrap ( ) ; 
2574+ 			Ok ( Some ( u64:: from_be_bytes ( msg) ) ) 
25692575		} 
25702576	} 
25712577
25722578	impl  CustomMessageHandler  for  TestCustomMessageHandler  { 
2573- 		fn  handle_custom_message ( & self ,  _:  Infallible ,  _:  & PublicKey )  -> Result < ( ) ,  LightningError >  { 
2574- 			unreachable ! ( ) ; 
2579+ 		fn  handle_custom_message ( & self ,  msg :  u64 ,  _:  & PublicKey )  -> Result < ( ) ,  LightningError >  { 
2580+ 			assert_eq ! ( self . peer_counter. load( Ordering :: Acquire )  as  u64 ,  msg) ; 
2581+ 			Ok ( ( ) ) 
25752582		} 
25762583
2577- 		fn  get_and_clear_pending_msg ( & self )  -> Vec < ( PublicKey ,  Self :: CustomMessage ) >  {  Vec :: new ( )  } 
2584+ 		fn  get_and_clear_pending_msg ( & self )  -> Vec < ( PublicKey ,  Self :: CustomMessage ) >  { 
2585+ 			if  let  Some ( peer_node_id)  = & self . send_messages  { 
2586+ 				vec ! [ ( * peer_node_id,  self . peer_counter. load( Ordering :: Acquire )  as  u64 ) ;  1000 ] 
2587+ 			}  else  {  Vec :: new ( )  } 
2588+ 		} 
25782589
2579- 		fn  peer_disconnected ( & self ,  _:  & PublicKey )  { } 
2580- 		fn  peer_connected ( & self ,  _:  & PublicKey ,  _:  & msgs:: Init ,  _:  bool )  -> Result < ( ) ,  ( ) >  {  Ok ( ( ) )  } 
2590+ 		fn  peer_disconnected ( & self ,  _:  & PublicKey )  { 
2591+ 			self . peer_counter . fetch_sub ( 1 ,  Ordering :: AcqRel ) ; 
2592+ 		} 
2593+ 		fn  peer_connected ( & self ,  _:  & PublicKey ,  _:  & msgs:: Init ,  _:  bool )  -> Result < ( ) ,  ( ) >  { 
2594+ 			self . peer_counter . fetch_add ( 2 ,  Ordering :: AcqRel ) ; 
2595+ 			Ok ( ( ) ) 
2596+ 		} 
25812597
25822598		fn  provided_node_features ( & self )  -> NodeFeatures  {  NodeFeatures :: empty ( )  } 
25832599
@@ -2600,7 +2616,9 @@ mod tests {
26002616					chan_handler :  test_utils:: TestChannelMessageHandler :: new ( ChainHash :: using_genesis_block ( Network :: Testnet ) ) , 
26012617					logger :  test_utils:: TestLogger :: new ( ) , 
26022618					routing_handler :  test_utils:: TestRoutingMessageHandler :: new ( ) , 
2603- 					custom_handler :  TestCustomMessageHandler  {  features } , 
2619+ 					custom_handler :  TestCustomMessageHandler  { 
2620+ 						features,  peer_counter :  AtomicUsize :: new ( 0 ) ,  send_messages :  None , 
2621+ 					} , 
26042622					node_signer :  test_utils:: TestNodeSigner :: new ( node_secret) , 
26052623				} 
26062624			) ; 
@@ -2623,7 +2641,9 @@ mod tests {
26232641					chan_handler :  test_utils:: TestChannelMessageHandler :: new ( ChainHash :: using_genesis_block ( Network :: Testnet ) ) , 
26242642					logger :  test_utils:: TestLogger :: new ( ) , 
26252643					routing_handler :  test_utils:: TestRoutingMessageHandler :: new ( ) , 
2626- 					custom_handler :  TestCustomMessageHandler  {  features } , 
2644+ 					custom_handler :  TestCustomMessageHandler  { 
2645+ 						features,  peer_counter :  AtomicUsize :: new ( 0 ) ,  send_messages :  None , 
2646+ 					} , 
26272647					node_signer :  test_utils:: TestNodeSigner :: new ( node_secret) , 
26282648				} 
26292649			) ; 
@@ -2643,7 +2663,9 @@ mod tests {
26432663					chan_handler :  test_utils:: TestChannelMessageHandler :: new ( network) , 
26442664					logger :  test_utils:: TestLogger :: new ( ) , 
26452665					routing_handler :  test_utils:: TestRoutingMessageHandler :: new ( ) , 
2646- 					custom_handler :  TestCustomMessageHandler  {  features } , 
2666+ 					custom_handler :  TestCustomMessageHandler  { 
2667+ 						features,  peer_counter :  AtomicUsize :: new ( 0 ) ,  send_messages :  None , 
2668+ 					} , 
26472669					node_signer :  test_utils:: TestNodeSigner :: new ( node_secret) , 
26482670				} 
26492671			) ; 
@@ -3191,4 +3213,100 @@ mod tests {
31913213		thread_c. join ( ) . unwrap ( ) ; 
31923214		assert ! ( cfg[ 0 ] . chan_handler. message_fetch_counter. load( Ordering :: Acquire )  >= 1 ) ; 
31933215	} 
3216+ 
3217+ 	#[ test]  
3218+ 	#[ cfg( feature = "std" ) ]  
3219+ 	fn  test_rapid_connect_events_order_multithreaded ( )  { 
3220+ 		// Previously, outbound messages held in `process_events` could race with peer 
3221+ 		// disconnection, allowing a message intended for a peer before disconnection to be sent 
3222+ 		// to the same peer after disconnection. Here we stress the handling of such messages by 
3223+ 		// connecting two peers repeatedly in a loop with a `CustomMessageHandler` set to stream 
3224+ 		// custom messages with a "connection id" to each other. That "connection id" (just the 
3225+ 		// number of reconnections seen) should always line up across both peers, which we assert 
3226+ 		// in the message handler. 
3227+ 		let  mut  cfg = create_peermgr_cfgs ( 2 ) ; 
3228+ 		cfg[ 0 ] . custom_handler . send_messages  =
3229+ 			Some ( cfg[ 1 ] . node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ) ; 
3230+ 		cfg[ 1 ] . custom_handler . send_messages  =
3231+ 			Some ( cfg[ 1 ] . node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ) ; 
3232+ 		let  cfg = Arc :: new ( cfg) ; 
3233+ 		// Until we have std::thread::scoped we have to unsafe { turn off the borrow checker }. 
3234+ 		let  mut  peers = create_network ( 2 ,  unsafe  {  & * ( & * cfg as  * const  _ )  as  & ' static  _  } ) ; 
3235+ 		let  peer_a = Arc :: new ( peers. pop ( ) . unwrap ( ) ) ; 
3236+ 		let  peer_b = Arc :: new ( peers. pop ( ) . unwrap ( ) ) ; 
3237+ 
3238+ 		let  exit_flag = Arc :: new ( AtomicBool :: new ( false ) ) ; 
3239+ 		macro_rules!  spawn_thread {  ( $id:  expr)  => {  { 
3240+ 			let  thread_peer_a = Arc :: clone( & peer_a) ; 
3241+ 			let  thread_peer_b = Arc :: clone( & peer_b) ; 
3242+ 			let  thread_exit = Arc :: clone( & exit_flag) ; 
3243+ 			std:: thread:: spawn( move || { 
3244+ 				let  id_a = thread_peer_a. node_signer. get_node_id( Recipient :: Node ) . unwrap( ) ; 
3245+ 				let  mut  fd_a = FileDescriptor  { 
3246+ 					fd:  $id,  outbound_data:  Arc :: new( Mutex :: new( Vec :: new( ) ) ) , 
3247+ 					disconnect:  Arc :: new( AtomicBool :: new( false ) ) , 
3248+ 				} ; 
3249+ 				let  addr_a = SocketAddress :: TcpIpV4 { addr:  [ 127 ,  0 ,  0 ,  1 ] ,  port:  1000 } ; 
3250+ 				let  mut  fd_b = FileDescriptor  { 
3251+ 					fd:  $id,  outbound_data:  Arc :: new( Mutex :: new( Vec :: new( ) ) ) , 
3252+ 					disconnect:  Arc :: new( AtomicBool :: new( false ) ) , 
3253+ 				} ; 
3254+ 				let  addr_b = SocketAddress :: TcpIpV4 { addr:  [ 127 ,  0 ,  0 ,  1 ] ,  port:  1001 } ; 
3255+ 				let  initial_data = thread_peer_b. new_outbound_connection( id_a,  fd_b. clone( ) ,  Some ( addr_a. clone( ) ) ) . unwrap( ) ; 
3256+ 				thread_peer_a. new_inbound_connection( fd_a. clone( ) ,  Some ( addr_b. clone( ) ) ) . unwrap( ) ; 
3257+ 				if  thread_peer_a. read_event( & mut  fd_a,  & initial_data) . is_err( )  { 
3258+ 					thread_peer_b. socket_disconnected( & fd_b) ; 
3259+ 					return ; 
3260+ 				} 
3261+ 
3262+ 				loop  { 
3263+ 					if  thread_exit. load( Ordering :: Relaxed )  { 
3264+ 						thread_peer_a. socket_disconnected( & fd_a) ; 
3265+ 						thread_peer_b. socket_disconnected( & fd_b) ; 
3266+ 						return ; 
3267+ 					} 
3268+ 					if  fd_a. disconnect. load( Ordering :: Relaxed )  {  return ;  } 
3269+ 					if  fd_b. disconnect. load( Ordering :: Relaxed )  {  return ;  } 
3270+ 
3271+ 					let  data_a = fd_a. outbound_data. lock( ) . unwrap( ) . split_off( 0 ) ; 
3272+ 					if  !data_a. is_empty( )  { 
3273+ 						if  thread_peer_b. read_event( & mut  fd_b,  & data_a) . is_err( )  { 
3274+ 							thread_peer_a. socket_disconnected( & fd_a) ; 
3275+ 							return ; 
3276+ 						} 
3277+ 					} 
3278+ 
3279+ 					let  data_b = fd_b. outbound_data. lock( ) . unwrap( ) . split_off( 0 ) ; 
3280+ 					if  !data_b. is_empty( )  { 
3281+ 						if  thread_peer_a. read_event( & mut  fd_a,  & data_b) . is_err( )  { 
3282+ 							thread_peer_b. socket_disconnected( & fd_b) ; 
3283+ 							return ; 
3284+ 						} 
3285+ 					} 
3286+ 				} 
3287+ 			} ) 
3288+ 		}  }  } 
3289+ 
3290+ 		let  mut  threads = Vec :: new ( ) ; 
3291+ 		{ 
3292+ 			let  thread_peer_a = Arc :: clone ( & peer_a) ; 
3293+ 			let  thread_peer_b = Arc :: clone ( & peer_b) ; 
3294+ 			let  thread_exit = Arc :: clone ( & exit_flag) ; 
3295+ 			threads. push ( std:: thread:: spawn ( move  || { 
3296+ 				while  !thread_exit. load ( Ordering :: Relaxed )  { 
3297+ 					thread_peer_a. process_events ( ) ; 
3298+ 					thread_peer_b. process_events ( ) ; 
3299+ 				} 
3300+ 			} ) ) ; 
3301+ 		} 
3302+ 		for  i in  0 ..1000  { 
3303+ 			threads. push ( spawn_thread ! ( i) ) ; 
3304+ 		} 
3305+ 		exit_flag. store ( true ,  Ordering :: Relaxed ) ; 
3306+ 		for  thread in  threads { 
3307+ 			thread. join ( ) . unwrap ( ) ; 
3308+ 		} 
3309+ 		assert_eq ! ( peer_a. peers. read( ) . unwrap( ) . len( ) ,  0 ) ; 
3310+ 		assert_eq ! ( peer_b. peers. read( ) . unwrap( ) . len( ) ,  0 ) ; 
3311+ 	} 
31943312} 
0 commit comments