From 9737a01f53e0a9eb04f46aff7b6dd3620b8a6ff5 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Thu, 30 Jan 2025 20:45:34 +0000 Subject: [PATCH 1/3] Call `peer_disconnected` after a handler refuses a connection If one message handler refuses a connection by returning an `Err` from `peer_connected`, other handlers which already got the `peer_connected` will not see the corresponding `peer_disconnected`, leaving them in a potentially-inconsistent state. Here we ensure we call the `peer_disconnected` handler for all handlers which received a `peer_connected` event (except the one which refused the connection). --- lightning/src/ln/msgs.rs | 4 ++++ lightning/src/ln/peer_handler.rs | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/lightning/src/ln/msgs.rs b/lightning/src/ln/msgs.rs index 659ec65f6cf..0c400fc36e0 100644 --- a/lightning/src/ln/msgs.rs +++ b/lightning/src/ln/msgs.rs @@ -1578,6 +1578,8 @@ pub trait ChannelMessageHandler : MessageSendEventsProvider { /// May return an `Err(())` if the features the peer supports are not sufficient to communicate /// with us. Implementors should be somewhat conservative about doing so, however, as other /// message handlers may still wish to communicate with this peer. + /// + /// [`Self::peer_disconnected`] will not be called if `Err(())` is returned. fn peer_connected(&self, their_node_id: PublicKey, msg: &Init, inbound: bool) -> Result<(), ()>; /// Handle an incoming `channel_reestablish` message from the given peer. fn handle_channel_reestablish(&self, their_node_id: PublicKey, msg: &ChannelReestablish); @@ -1707,6 +1709,8 @@ pub trait OnionMessageHandler { /// May return an `Err(())` if the features the peer supports are not sufficient to communicate /// with us. Implementors should be somewhat conservative about doing so, however, as other /// message handlers may still wish to communicate with this peer. + /// + /// [`Self::peer_disconnected`] will not be called if `Err(())` is returned. 
fn peer_connected(&self, their_node_id: PublicKey, init: &Init, inbound: bool) -> Result<(), ()>; /// Indicates a connection to the peer failed/an existing connection was lost. Allows handlers to diff --git a/lightning/src/ln/peer_handler.rs b/lightning/src/ln/peer_handler.rs index 80b92cec1bd..c9ffabd78f6 100644 --- a/lightning/src/ln/peer_handler.rs +++ b/lightning/src/ln/peer_handler.rs @@ -88,6 +88,8 @@ pub trait CustomMessageHandler: wire::CustomMessageReader { /// May return an `Err(())` if the features the peer supports are not sufficient to communicate /// with us. Implementors should be somewhat conservative about doing so, however, as other /// message handlers may still wish to communicate with this peer. + /// + /// [`Self::peer_disconnected`] will not be called if `Err(())` is returned. fn peer_connected(&self, their_node_id: PublicKey, msg: &Init, inbound: bool) -> Result<(), ()>; /// Gets the node feature flags which this handler itself supports. All available handlers are @@ -1718,10 +1720,13 @@ impl Date: Thu, 30 Jan 2025 20:44:48 +0000 Subject: [PATCH 2/3] Add `RoutingMessageHandler::peer_disconnected` ...to make it identical to all our other message handlers. 
--- lightning-net-tokio/src/lib.rs | 1 + lightning/src/ln/msgs.rs | 4 ++++ lightning/src/ln/peer_handler.rs | 6 ++++++ lightning/src/routing/gossip.rs | 2 ++ lightning/src/util/test_utils.rs | 2 ++ 5 files changed, 15 insertions(+) diff --git a/lightning-net-tokio/src/lib.rs b/lightning-net-tokio/src/lib.rs index 944033102c6..2ff88bc066a 100644 --- a/lightning-net-tokio/src/lib.rs +++ b/lightning-net-tokio/src/lib.rs @@ -689,6 +689,7 @@ mod tests { ) -> Result<(), ()> { Ok(()) } + fn peer_disconnected(&self, _their_node_id: PublicKey) {} fn handle_reply_channel_range( &self, _their_node_id: PublicKey, _msg: ReplyChannelRange, ) -> Result<(), LightningError> { diff --git a/lightning/src/ln/msgs.rs b/lightning/src/ln/msgs.rs index 0c400fc36e0..1323fab435f 100644 --- a/lightning/src/ln/msgs.rs +++ b/lightning/src/ln/msgs.rs @@ -1658,7 +1658,11 @@ pub trait RoutingMessageHandler : MessageSendEventsProvider { /// May return an `Err(())` if the features the peer supports are not sufficient to communicate /// with us. Implementors should be somewhat conservative about doing so, however, as other /// message handlers may still wish to communicate with this peer. + /// + /// [`Self::peer_disconnected`] will not be called if `Err(())` is returned. fn peer_connected(&self, their_node_id: PublicKey, init: &Init, inbound: bool) -> Result<(), ()>; + /// Indicates a connection to the peer failed/an existing connection was lost. + fn peer_disconnected(&self, their_node_id: PublicKey); /// Handles the reply of a query we initiated to learn about channels /// for a given range of blocks. We can expect to receive one or more /// replies to a single query. 
diff --git a/lightning/src/ln/peer_handler.rs b/lightning/src/ln/peer_handler.rs index c9ffabd78f6..34abf6130f7 100644 --- a/lightning/src/ln/peer_handler.rs +++ b/lightning/src/ln/peer_handler.rs @@ -121,6 +121,7 @@ impl RoutingMessageHandler for IgnoringMessageHandler { Option<(msgs::ChannelAnnouncement, Option, Option)> { None } fn get_next_node_announcement(&self, _starting_point: Option<&NodeId>) -> Option { None } fn peer_connected(&self, _their_node_id: PublicKey, _init: &msgs::Init, _inbound: bool) -> Result<(), ()> { Ok(()) } + fn peer_disconnected(&self, _their_node_id: PublicKey) { } fn handle_reply_channel_range(&self, _their_node_id: PublicKey, _msg: msgs::ReplyChannelRange) -> Result<(), LightningError> { Ok(()) } fn handle_reply_short_channel_ids_end(&self, _their_node_id: PublicKey, _msg: msgs::ReplyShortChannelIdsEnd) -> Result<(), LightningError> { Ok(()) } fn handle_query_channel_range(&self, _their_node_id: PublicKey, _msg: msgs::QueryChannelRange) -> Result<(), LightningError> { Ok(()) } @@ -1716,15 +1717,18 @@ impl Result<(), LightningError> { diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 48bdfe2324a..20981c8ffa7 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -1245,6 +1245,8 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler { Ok(()) } + fn peer_disconnected(&self, _their_node_id: PublicKey) {} + fn handle_reply_channel_range( &self, _their_node_id: PublicKey, _msg: msgs::ReplyChannelRange, ) -> Result<(), msgs::LightningError> { From 4bc597ae4b10c9f047c7c4ead6670c0b38ab6a94 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Thu, 30 Jan 2025 20:31:53 +0000 Subject: [PATCH 3/3] Assert `peer_{dis,}connected` consistency across test handlers This adds a `ConnectionTracker` test util which is used across `TestChannelMessageHandler`, `TestRoutingMessageHandler` and `TestCustomMessageHandler`, asserting that `peer_connected` and `peer_disconnected` 
methods are well-ordered. This expands test coverage from just `TestChannelMessageHandler` to cover all test handlers and adds some useful features which we'll use to test the fix in the next commit. This also adds an additional test which tests `peer_{dis,}connected` consistency when a handler refuses a connection by returning an `Err` from `peer_connected`. --- lightning/src/ln/peer_handler.rs | 103 ++++++++++++++++++++++++++----- lightning/src/util/test_utils.rs | 53 +++++++++++++--- 2 files changed, 134 insertions(+), 22 deletions(-) diff --git a/lightning/src/ln/peer_handler.rs b/lightning/src/ln/peer_handler.rs index 34abf6130f7..8df168fee12 100644 --- a/lightning/src/ln/peer_handler.rs +++ b/lightning/src/ln/peer_handler.rs @@ -2867,6 +2867,16 @@ mod tests { struct TestCustomMessageHandler { features: InitFeatures, + conn_tracker: test_utils::ConnectionTracker, + } + + impl TestCustomMessageHandler { + fn new(features: InitFeatures) -> Self { + Self { + features, + conn_tracker: test_utils::ConnectionTracker::new(), + } + } } impl wire::CustomMessageReader for TestCustomMessageHandler { @@ -2883,10 +2893,13 @@ mod tests { fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> { Vec::new() } + fn peer_disconnected(&self, their_node_id: PublicKey) { + self.conn_tracker.peer_disconnected(their_node_id); + } - fn peer_disconnected(&self, _their_node_id: PublicKey) {} - - fn peer_connected(&self, _their_node_id: PublicKey, _msg: &Init, _inbound: bool) -> Result<(), ()> { Ok(()) } + fn peer_connected(&self, their_node_id: PublicKey, _msg: &Init, _inbound: bool) -> Result<(), ()> { + self.conn_tracker.peer_connected(their_node_id) + } fn provided_node_features(&self) -> NodeFeatures { NodeFeatures::empty() } @@ -2909,7 +2922,7 @@ mod tests { chan_handler: test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet)), logger: test_utils::TestLogger::with_id(i.to_string()), routing_handler: 
test_utils::TestRoutingMessageHandler::new(), - custom_handler: TestCustomMessageHandler { features }, + custom_handler: TestCustomMessageHandler::new(features), node_signer: test_utils::TestNodeSigner::new(node_secret), } ); @@ -2932,7 +2945,7 @@ mod tests { chan_handler: test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet)), logger: test_utils::TestLogger::new(), routing_handler: test_utils::TestRoutingMessageHandler::new(), - custom_handler: TestCustomMessageHandler { features }, + custom_handler: TestCustomMessageHandler::new(features), node_signer: test_utils::TestNodeSigner::new(node_secret), } ); @@ -2952,7 +2965,7 @@ mod tests { chan_handler: test_utils::TestChannelMessageHandler::new(network), logger: test_utils::TestLogger::new(), routing_handler: test_utils::TestRoutingMessageHandler::new(), - custom_handler: TestCustomMessageHandler { features }, + custom_handler: TestCustomMessageHandler::new(features), node_signer: test_utils::TestNodeSigner::new(node_secret), } ); @@ -2976,19 +2989,16 @@ mod tests { peers } - fn establish_connection<'a>(peer_a: &PeerManager, peer_b: &PeerManager) -> (FileDescriptor, FileDescriptor) { + fn try_establish_connection<'a>(peer_a: &PeerManager, peer_b: &PeerManager) -> (FileDescriptor, FileDescriptor, Result, Result) { + let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000}; + let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001}; + static FD_COUNTER: AtomicUsize = AtomicUsize::new(0); let fd = FD_COUNTER.fetch_add(1, Ordering::Relaxed) as u16; let id_a = peer_a.node_signer.get_node_id(Recipient::Node).unwrap(); let mut fd_a = FileDescriptor::new(fd); - let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000}; - - let id_b = peer_b.node_signer.get_node_id(Recipient::Node).unwrap(); - let features_a = peer_a.init_features(id_b); - let features_b = peer_b.init_features(id_a); let mut fd_b = FileDescriptor::new(fd); - let addr_b = 
SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001}; let initial_data = peer_b.new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap(); peer_a.new_inbound_connection(fd_a.clone(), Some(addr_b.clone())).unwrap(); @@ -3000,11 +3010,30 @@ mod tests { peer_b.process_events(); let b_data = fd_b.outbound_data.lock().unwrap().split_off(0); - assert_eq!(peer_a.read_event(&mut fd_a, &b_data).unwrap(), false); + let a_refused = peer_a.read_event(&mut fd_a, &b_data); peer_a.process_events(); let a_data = fd_a.outbound_data.lock().unwrap().split_off(0); - assert_eq!(peer_b.read_event(&mut fd_b, &a_data).unwrap(), false); + let b_refused = peer_b.read_event(&mut fd_b, &a_data); + + (fd_a, fd_b, a_refused, b_refused) + } + + + fn establish_connection<'a>(peer_a: &PeerManager, peer_b: &PeerManager) -> (FileDescriptor, FileDescriptor) { + let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000}; + let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001}; + + let id_a = peer_a.node_signer.get_node_id(Recipient::Node).unwrap(); + let id_b = peer_b.node_signer.get_node_id(Recipient::Node).unwrap(); + + let features_a = peer_a.init_features(id_b); + let features_b = peer_b.init_features(id_a); + + let (fd_a, fd_b, a_refused, b_refused) = try_establish_connection(peer_a, peer_b); + + assert_eq!(a_refused.unwrap(), false); + assert_eq!(b_refused.unwrap(), false); assert_eq!(peer_a.peer_by_node_id(&id_b).unwrap().counterparty_node_id, id_b); assert_eq!(peer_a.peer_by_node_id(&id_b).unwrap().socket_address, Some(addr_b)); @@ -3257,6 +3286,50 @@ mod tests { assert_eq!(peers[0].peers.read().unwrap().len(), 0); } + fn do_test_peer_connected_error_disconnects(handler: usize) { + // Test that if a message handler fails a connection in `peer_connected` we reliably + // produce `peer_disconnected` events for all other message handlers (that saw a + // corresponding `peer_connected`). 
+ let cfgs = create_peermgr_cfgs(2); + let peers = create_network(2, &cfgs); + + match handler & !1 { + 0 => { + peers[handler & 1].message_handler.chan_handler.conn_tracker.fail_connections.store(true, Ordering::Release); + } + 2 => { + peers[handler & 1].message_handler.route_handler.conn_tracker.fail_connections.store(true, Ordering::Release); + } + 4 => { + peers[handler & 1].message_handler.custom_message_handler.conn_tracker.fail_connections.store(true, Ordering::Release); + } + _ => panic!(), + } + let (_sd1, _sd2, a_refused, b_refused) = try_establish_connection(&peers[0], &peers[1]); + if handler & 1 == 0 { + assert!(a_refused.is_err()); + assert!(peers[0].list_peers().is_empty()); + } else { + assert!(b_refused.is_err()); + assert!(peers[1].list_peers().is_empty()); + } + // At least one message handler should have seen the connection. + assert!(peers[handler & 1].message_handler.chan_handler.conn_tracker.had_peers.load(Ordering::Acquire) || + peers[handler & 1].message_handler.route_handler.conn_tracker.had_peers.load(Ordering::Acquire) || + peers[handler & 1].message_handler.custom_message_handler.conn_tracker.had_peers.load(Ordering::Acquire)); + // And all three message handlers doing tracking should see the disconnection + assert!(peers[handler & 1].message_handler.chan_handler.conn_tracker.connected_peers.lock().unwrap().is_empty()); + assert!(peers[handler & 1].message_handler.route_handler.conn_tracker.connected_peers.lock().unwrap().is_empty()); + assert!(peers[handler & 1].message_handler.custom_message_handler.conn_tracker.connected_peers.lock().unwrap().is_empty()); + } + + #[test] + fn test_peer_connected_error_disconnects() { + for i in 0..6 { + do_test_peer_connected_error_disconnects(i); + } + } + #[test] fn test_do_attempt_write_data() { // Create 2 peers with custom TestRoutingMessageHandlers and connect them. 
diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 20981c8ffa7..2e89c51bd51 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -889,10 +889,45 @@ impl chaininterface::BroadcasterInterface for TestBroadcaster { } } +pub struct ConnectionTracker { + pub had_peers: AtomicBool, + pub connected_peers: Mutex>, + pub fail_connections: AtomicBool, +} + +impl ConnectionTracker { + pub fn new() -> Self { + Self { + had_peers: AtomicBool::new(false), + connected_peers: Mutex::new(Vec::new()), + fail_connections: AtomicBool::new(false), + } + } + + pub fn peer_connected(&self, their_node_id: PublicKey) -> Result<(), ()> { + self.had_peers.store(true, Ordering::Release); + let mut connected_peers = self.connected_peers.lock().unwrap(); + assert!(!connected_peers.contains(&their_node_id)); + if self.fail_connections.load(Ordering::Acquire) { + Err(()) + } else { + connected_peers.push(their_node_id); + Ok(()) + } + } + + pub fn peer_disconnected(&self, their_node_id: PublicKey) { + assert!(self.had_peers.load(Ordering::Acquire)); + let mut connected_peers = self.connected_peers.lock().unwrap(); + assert!(connected_peers.contains(&their_node_id)); + connected_peers.retain(|id| *id != their_node_id); + } +} + pub struct TestChannelMessageHandler { pub pending_events: Mutex>, expected_recv_msgs: Mutex>>>, - connected_peers: Mutex>, + pub conn_tracker: ConnectionTracker, chain_hash: ChainHash, } @@ -907,7 +942,7 @@ impl TestChannelMessageHandler { TestChannelMessageHandler { pending_events: Mutex::new(Vec::new()), expected_recv_msgs: Mutex::new(None), - connected_peers: Mutex::new(new_hash_set()), + conn_tracker: ConnectionTracker::new(), chain_hash, } } @@ -1019,15 +1054,14 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler { self.received_msg(wire::Message::ChannelReestablish(msg.clone())); } fn peer_disconnected(&self, their_node_id: PublicKey) { - 
assert!(self.connected_peers.lock().unwrap().remove(&their_node_id)); + self.conn_tracker.peer_disconnected(their_node_id) } fn peer_connected( &self, their_node_id: PublicKey, _msg: &msgs::Init, _inbound: bool, ) -> Result<(), ()> { - assert!(self.connected_peers.lock().unwrap().insert(their_node_id.clone())); // Don't bother with `received_msg` for Init as its auto-generated and we don't want to // bother re-generating the expected Init message in all tests. - Ok(()) + self.conn_tracker.peer_connected(their_node_id) } fn handle_error(&self, _their_node_id: PublicKey, msg: &msgs::ErrorMessage) { self.received_msg(wire::Message::Error(msg.clone())); @@ -1157,6 +1191,7 @@ pub struct TestRoutingMessageHandler { pub pending_events: Mutex>, pub request_full_sync: AtomicBool, pub announcement_available_for_sync: AtomicBool, + pub conn_tracker: ConnectionTracker, } impl TestRoutingMessageHandler { @@ -1168,6 +1203,7 @@ impl TestRoutingMessageHandler { pending_events, request_full_sync: AtomicBool::new(false), announcement_available_for_sync: AtomicBool::new(false), + conn_tracker: ConnectionTracker::new(), } } } @@ -1242,10 +1278,13 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler { timestamp_range: u32::max_value(), }, }); - Ok(()) + + self.conn_tracker.peer_connected(their_node_id) } - fn peer_disconnected(&self, _their_node_id: PublicKey) {} + fn peer_disconnected(&self, their_node_id: PublicKey) { + self.conn_tracker.peer_disconnected(their_node_id); + } fn handle_reply_channel_range( &self, _their_node_id: PublicKey, _msg: msgs::ReplyChannelRange,