@@ -2699,11 +2699,12 @@ mod tests {
26992699 use crate :: ln:: types:: ChannelId ;
27002700 use crate :: ln:: features:: { InitFeatures , NodeFeatures } ;
27012701 use crate :: ln:: peer_channel_encryptor:: PeerChannelEncryptor ;
2702- use crate :: ln:: peer_handler:: { CustomMessageHandler , PeerManager , MessageHandler , SocketDescriptor , IgnoringMessageHandler , filter_addresses, ErroringMessageHandler , MAX_BUFFER_DRAIN_TICK_INTERVALS_PER_PEER } ;
2702+ use crate :: ln:: peer_handler:: { CustomMessageHandler , OnionMessageHandler , PeerManager , MessageHandler , SocketDescriptor , IgnoringMessageHandler , filter_addresses, ErroringMessageHandler , MAX_BUFFER_DRAIN_TICK_INTERVALS_PER_PEER } ;
27032703 use crate :: ln:: { msgs, wire} ;
27042704 use crate :: ln:: msgs:: { Init , LightningError , SocketAddress } ;
27052705 use crate :: util:: test_utils;
27062706
2707+
27072708 use bitcoin:: Network ;
27082709 use bitcoin:: blockdata:: constants:: ChainHash ;
27092710 use bitcoin:: secp256k1:: { PublicKey , SecretKey } ;
@@ -2780,6 +2781,93 @@ mod tests {
27802781 }
27812782 }
27822783
2784+ struct TestPeerTrackingMessageHandler {
2785+ features : InitFeatures ,
2786+ pub peer_connected_called : Mutex < bool > ,
2787+ pub peer_disconnected_called : Mutex < bool > ,
2788+ }
2789+
2790+ impl TestPeerTrackingMessageHandler {
2791+ pub fn new ( features : InitFeatures ) -> Self {
2792+ Self {
2793+ features,
2794+ peer_connected_called : Mutex :: new ( false ) ,
2795+ peer_disconnected_called : Mutex :: new ( false ) ,
2796+ }
2797+ }
2798+ }
2799+
2800+ impl wire:: CustomMessageReader for TestPeerTrackingMessageHandler {
2801+ type CustomMessage = Infallible ;
2802+ fn read < R : io:: Read > ( & self , _: u16 , _: & mut R ) -> Result < Option < Self :: CustomMessage > , msgs:: DecodeError > {
2803+ Ok ( None )
2804+ }
2805+ }
2806+
2807+ impl CustomMessageHandler for TestPeerTrackingMessageHandler {
2808+ fn handle_custom_message ( & self , _: Infallible , _: & PublicKey ) -> Result < ( ) , LightningError > {
2809+ unreachable ! ( ) ;
2810+ }
2811+
2812+ fn get_and_clear_pending_msg ( & self ) -> Vec < ( PublicKey , Self :: CustomMessage ) > { Vec :: new ( ) }
2813+
2814+ fn peer_disconnected ( & self , _their_node_id : & PublicKey ) {
2815+ let connected = {
2816+ * self . peer_connected_called . lock ( ) . unwrap ( )
2817+ } ;
2818+ assert ! ( connected) ;
2819+
2820+ let mut disconnected = self . peer_disconnected_called . lock ( ) . unwrap ( ) ;
2821+ * disconnected = true ;
2822+ }
2823+
2824+ fn peer_connected ( & self , _their_node_id : & PublicKey , _msg : & Init , _inbound : bool ) -> Result < ( ) , ( ) > {
2825+ let disconnected = {
2826+ * self . peer_disconnected_called . lock ( ) . unwrap ( )
2827+ } ;
2828+ assert ! ( !disconnected) ;
2829+
2830+ let mut connected = self . peer_connected_called . lock ( ) . unwrap ( ) ;
2831+ assert ! ( !* connected) ;
2832+ * connected = true ;
2833+ Err ( ( ) )
2834+ }
2835+
2836+ fn provided_node_features ( & self ) -> NodeFeatures { NodeFeatures :: empty ( ) }
2837+
2838+ fn provided_init_features ( & self , _: & PublicKey ) -> InitFeatures {
2839+ self . features . clone ( )
2840+ }
2841+ }
2842+
2843+ impl OnionMessageHandler for TestPeerTrackingMessageHandler {
2844+ fn handle_onion_message ( & self , _peer_node_id : & PublicKey , _msg : & msgs:: OnionMessage ) { }
2845+ fn next_onion_message_for_peer ( & self , _peer_node_id : PublicKey ) -> Option < msgs:: OnionMessage > { None }
2846+ fn peer_connected ( & self , _their_node_id : & PublicKey , _init : & Init , _inbound : bool ) -> Result < ( ) , ( ) > {
2847+ let disconnected = {
2848+ * self . peer_disconnected_called . lock ( ) . unwrap ( )
2849+ } ;
2850+ assert ! ( !disconnected) ;
2851+
2852+ let mut connected = self . peer_connected_called . lock ( ) . unwrap ( ) ;
2853+ assert ! ( !* connected) ;
2854+ * connected = true ;
2855+ Err ( ( ) )
2856+ }
2857+ fn peer_disconnected ( & self , _their_node_id : & PublicKey ) {
2858+ let connected = {
2859+ * self . peer_connected_called . lock ( ) . unwrap ( )
2860+ } ;
2861+ assert ! ( connected) ;
2862+
2863+ let mut disconnected = self . peer_disconnected_called . lock ( ) . unwrap ( ) ;
2864+ * disconnected = true ;
2865+ }
2866+ fn timer_tick_occurred ( & self ) { }
2867+ fn provided_node_features ( & self ) -> NodeFeatures { NodeFeatures :: empty ( ) }
2868+ fn provided_init_features ( & self , _their_node_id : & PublicKey ) -> InitFeatures { self . features . clone ( ) }
2869+ }
2870+
27832871 fn create_peermgr_cfgs ( peer_count : usize ) -> Vec < PeerManagerCfg > {
27842872 let mut cfgs = Vec :: new ( ) ;
27852873 for i in 0 ..peer_count {
@@ -3164,6 +3252,103 @@ mod tests {
31643252 assert_eq ! ( peers[ 0 ] . peers. read( ) . unwrap( ) . len( ) , 0 ) ;
31653253 }
31663254
3255+ #[ test]
3256+ fn test_peer_connected_error_disconnects ( ) {
3257+
3258+ struct PeerTrackingPeerManagerConfig {
3259+ logger : test_utils:: TestLogger ,
3260+ node_signer : test_utils:: TestNodeSigner ,
3261+ chan_handler : test_utils:: TestChannelMessageHandler ,
3262+ route_handler : test_utils:: TestRoutingMessageHandler ,
3263+ onion_message_handler : TestPeerTrackingMessageHandler ,
3264+ custom_message_handler : TestPeerTrackingMessageHandler ,
3265+ }
3266+
3267+ fn create_cfgs ( peers : u8 ) -> Vec < PeerTrackingPeerManagerConfig > {
3268+ let mut cfgs = vec ! [ ] ;
3269+ for i in 0 ..peers {
3270+ let features = {
3271+ let mut feature_bits = vec ! [ 0u8 ; 33 ] ;
3272+ feature_bits[ 32 ] = 0b00000001 ;
3273+ InitFeatures :: from_le_bytes ( feature_bits)
3274+ } ;
3275+ let node_secret = SecretKey :: from_slice ( & [ 42 + i as u8 ; 32 ] ) . unwrap ( ) ;
3276+ cfgs. push ( PeerTrackingPeerManagerConfig {
3277+ logger : test_utils:: TestLogger :: new ( ) ,
3278+ node_signer : test_utils:: TestNodeSigner :: new ( node_secret) ,
3279+ chan_handler : test_utils:: TestChannelMessageHandler :: new ( ChainHash :: using_genesis_block ( Network :: Testnet ) ) ,
3280+ route_handler : test_utils:: TestRoutingMessageHandler :: new ( ) ,
3281+ onion_message_handler : TestPeerTrackingMessageHandler :: new ( features. clone ( ) ) ,
3282+ custom_message_handler : TestPeerTrackingMessageHandler :: new ( features. clone ( ) ) ,
3283+ } ) ;
3284+ }
3285+ cfgs
3286+ }
3287+
3288+ type PeerTrackingPeerManager < ' a > = PeerManager < FileDescriptor , & ' a test_utils:: TestChannelMessageHandler , & ' a test_utils:: TestRoutingMessageHandler , & ' a TestPeerTrackingMessageHandler , & ' a test_utils:: TestLogger , & ' a TestPeerTrackingMessageHandler , & ' a test_utils:: TestNodeSigner > ;
3289+
3290+ fn create_network < ' a > ( peer_count : usize , cfgs : & ' a Vec < PeerTrackingPeerManagerConfig > ) -> Vec < PeerTrackingPeerManager < ' a > > {
3291+ let mut peers = Vec :: new ( ) ;
3292+ for i in 0 ..peer_count {
3293+ let ephemeral_bytes = [ i as u8 ; 32 ] ;
3294+ let msg_handler = MessageHandler {
3295+ chan_handler : & cfgs[ i] . chan_handler , route_handler : & cfgs[ i] . route_handler ,
3296+ onion_message_handler : & cfgs[ i] . onion_message_handler , custom_message_handler : & cfgs[ i] . custom_message_handler
3297+ } ;
3298+ let peer = PeerManager :: new ( msg_handler, 0 , & ephemeral_bytes, & cfgs[ i] . logger , & cfgs[ i] . node_signer ) ;
3299+ peers. push ( peer) ;
3300+ }
3301+
3302+ peers
3303+ }
3304+
3305+ fn try_establish_connection < ' a > ( peer_a : & PeerTrackingPeerManager < ' a > , peer_b : & PeerTrackingPeerManager < ' a > ) -> ( FileDescriptor , FileDescriptor ) {
3306+ let id_a = peer_a. node_signer . get_node_id ( Recipient :: Node ) . unwrap ( ) ;
3307+ let mut fd_a = FileDescriptor {
3308+ fd : 1 , outbound_data : Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ,
3309+ disconnect : Arc :: new ( AtomicBool :: new ( false ) ) ,
3310+ } ;
3311+ let addr_a = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1000 } ;
3312+ let mut fd_b = FileDescriptor {
3313+ fd : 1 , outbound_data : Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ,
3314+ disconnect : Arc :: new ( AtomicBool :: new ( false ) ) ,
3315+ } ;
3316+ let addr_b = SocketAddress :: TcpIpV4 { addr : [ 127 , 0 , 0 , 1 ] , port : 1001 } ;
3317+ let initial_data = peer_b. new_outbound_connection ( id_a, fd_b. clone ( ) , Some ( addr_a. clone ( ) ) ) . unwrap ( ) ;
3318+ peer_a. new_inbound_connection ( fd_a. clone ( ) , Some ( addr_b. clone ( ) ) ) . unwrap ( ) ;
3319+
3320+ let _res = peer_a. read_event ( & mut fd_a, & initial_data) ;
3321+ peer_a. process_events ( ) ;
3322+
3323+ let a_data = fd_a. outbound_data . lock ( ) . unwrap ( ) . split_off ( 0 ) ;
3324+
3325+ let _res = peer_b. read_event ( & mut fd_b, & a_data) ;
3326+
3327+ peer_b. process_events ( ) ;
3328+ let b_data = fd_b. outbound_data . lock ( ) . unwrap ( ) . split_off ( 0 ) ;
3329+ let _res = peer_a. read_event ( & mut fd_a, & b_data) ;
3330+
3331+ peer_a. process_events ( ) ;
3332+ let a_data = fd_a. outbound_data . lock ( ) . unwrap ( ) . split_off ( 0 ) ;
3333+
3334+ let _res = peer_b. read_event ( & mut fd_b, & a_data) ;
3335+ ( fd_a. clone ( ) , fd_b. clone ( ) )
3336+ }
3337+
3338+ let cfgs = create_cfgs ( 2 ) ;
3339+ let peers = create_network ( 2 , & cfgs) ;
3340+ let ( _sd1, _sd2) = try_establish_connection ( & peers[ 0 ] , & peers[ 1 ] ) ;
3341+
3342+ let cmh_peer_connected_called = cfgs[ 0 ] . custom_message_handler . peer_connected_called . lock ( ) . unwrap ( ) ;
3343+ let cmh_peer_disconnected_called = cfgs[ 0 ] . custom_message_handler . peer_disconnected_called . lock ( ) . unwrap ( ) ;
3344+ let om_peer_connected_called = cfgs[ 0 ] . onion_message_handler . peer_connected_called . lock ( ) . unwrap ( ) ;
3345+ let om_peer_disconnected_called = cfgs[ 0 ] . onion_message_handler . peer_disconnected_called . lock ( ) . unwrap ( ) ;
3346+ assert ! ( * cmh_peer_connected_called) ;
3347+ assert ! ( * cmh_peer_disconnected_called) ;
3348+ assert ! ( * om_peer_connected_called) ;
3349+ assert ! ( * om_peer_disconnected_called) ;
3350+ }
3351+
31673352 #[ test]
31683353 fn test_do_attempt_write_data ( ) {
31693354 // Create 2 peers with custom TestRoutingMessageHandlers and connect them.
0 commit comments