From de46baca6cc4dd055121058c7bda2d536262ff12 Mon Sep 17 00:00:00 2001
From: Sam Lurye
Date: Tue, 19 Aug 2025 12:50:43 -0700
Subject: [PATCH] Make NetTx<->NetRx handshake two-way to prevent errant reconnections (#793)

Summary:
Pull Request resolved: https://github.com/meta-pytorch/monarch/pull/793

The purpose of this diff is to handle the following scenario:

1. Process A starts serving a NetRx.
2. Process B creates a NetTx that connects to process A's NetRx.
3. B sends a few messages to A, and the messages are acked.
4. Process A dies/is killed, while B stays alive.
5. A new Process C starts serving a NetRx on the same channel as from step 1.
6. B's NetTx connects to C's NetRx, *with no way of knowing it has connected to a different process than before*.
7. B sends messages to C, starting from where it left off with A.
8. C rejects all of B's messages because of invalid sequence numbers.
9. B's NetTx eventually times out after a long time with no acks.

In order to distinguish among connections from different NetTx instances to the same NetRx instance, each NetTx generates a random unique session id. This session id gets sent as part of an initial handshake from NetTx -> NetRx before the NetTx starts sending normal messages. Currently, though, NetTx doesn't wait for any handshake before starting to send messages.

To resolve the issue described above, this diff introduces a global (per-process) "rx session id". When a NetTx first connects to a NetRx, the NetRx responds with its rx session id as part of the handshake. The NetTx waits for the handshake response and extracts the rx session id. If this is the first time the NetTx is connecting, the NetTx stores the rx session id. On subsequent connection attempts, the NetTx will validate the rx session id it receives from the handshake against the rx session id it previously stored; if there is a mismatch, the NetTx returns the appropriate error to its caller.

Differential Revision: D79607092 --- hyperactor/src/channel/net.rs | 319 +++++++++++++++++++++++++++++++--- 1 file changed, 292 insertions(+), 27 deletions(-) diff --git a/hyperactor/src/channel/net.rs b/hyperactor/src/channel/net.rs index 3557af98f..5275732d2 100644 --- a/hyperactor/src/channel/net.rs +++ b/hyperactor/src/channel/net.rs @@ -28,6 +28,7 @@ use std::mem::take; use std::net::ToSocketAddrs; use std::pin::Pin; use std::sync::Arc; +use std::sync::LazyLock; use std::task::Poll; use backoff::ExponentialBackoffBuilder; @@ -171,6 +172,65 @@ fn build_codec() -> LengthDelimitedCodec { .new_codec() } +#[derive(thiserror::Error, Debug)] +enum HandshakeError { + #[error(transparent)] + DeserializationFailed(bincode::Error), + #[error("sending init frame failed: {0}")] + SendFailed(String), + #[error("timed out waiting for handshake response")] + Timeout, + #[error("channel closed waiting for handshake response")] + ChannelClosed, + #[error(transparent)] + IoError(io::Error), + #[error("mismatched rx session ids: {0} (stored) vs. {1} (received)")] + MismatchedRxSessionIds(u64, u64), +} + +async fn do_handshake( + tx_session_id: u64, + rx_session_id: &mut Option, + sink: &mut SplitSink, Bytes>, + stream: &mut SplitStream>, +) -> Result<(), HandshakeError> { + let data = bincode::serialize(&Frame::::Init(tx_session_id)) + .expect("unexpected serialization error"); + + if let Err(err) = sink.send(data.into()).await { + return Err(HandshakeError::SendFailed(format!("{}", err))); + } + + let handshake_response = tokio::select! 
{ + _ = RealClock.sleep(config::global::get(config::MESSAGE_DELIVERY_TIMEOUT)) => None, + response = tokio_stream::StreamExt::next(stream) => Some(response) + }; + + match handshake_response { + None => Err(HandshakeError::Timeout), + Some(None) => Err(HandshakeError::ChannelClosed), + Some(Some(handshake_response)) => match handshake_response { + Err(err) => Err(HandshakeError::IoError(err)), + Ok(handshake_response) => match bincode::deserialize::(&handshake_response) { + Err(err) => Err(HandshakeError::DeserializationFailed(err)), + Ok(received_rx_session_id) => { + if rx_session_id + .is_some_and(|rx_session_id| rx_session_id != received_rx_session_id) + { + Err(HandshakeError::MismatchedRxSessionIds( + rx_session_id.unwrap(), + received_rx_session_id, + )) + } else { + rx_session_id.replace(received_rx_session_id); + Ok(()) + } + } + }, + }, + } +} + /// A Tx implemented on top of a Link. The Tx manages the link state, /// reconnections, etc. #[derive(Debug)] @@ -499,6 +559,7 @@ impl NetTx { } let session_id = rand::random(); + let mut rx_session_id = None; let log_id = format!("session {}.{}", link.dest(), session_id); let mut state = State::init(&log_id); let mut conn = Conn::reconnect_with_default(); @@ -716,10 +777,7 @@ impl NetTx { match link.connect().await { Ok(stream) => { let framed = Framed::new(stream, build_codec()); - let (mut sink, stream) = futures::StreamExt::split(framed); - let data = bincode::serialize(&Frame::::Init(session_id)) - .expect("unexpected serialization error"); - let initialized = sink.send(data.into()).await.is_ok(); + let (mut sink, mut stream) = futures::StreamExt::split(framed); metrics::CHANNEL_CONNECTIONS.add( 1, @@ -732,21 +790,71 @@ impl NetTx { // Need to resend unacked after reconnecting. let largest_acked = unacked.largest_acked; outbox.requeue_unacked(unacked); - ( - State::Running(Deliveries { - outbox, - // unacked messages are put back to outbox. 
So they are not - // considered as "sent yet unacked" message anymore. But - // we still want to keep `largest_acked` to known Rx's watermark. - unacked: Unacked::new(largest_acked, &log_id), - }), - if initialized { - backoff.reset(); - Conn::Connected { sink, stream } - } else { - Conn::reconnect(backoff) - }, + + let handshake_result = do_handshake::<_, M>( + session_id, + &mut rx_session_id, + &mut sink, + &mut stream, ) + .await; + + match handshake_result { + Ok(_) => { + ( + State::Running(Deliveries { + outbox, + // unacked messages are put back to outbox. So they are not + // considered as "sent yet unacked" message anymore. But + // we still want to keep `largest_acked` to known Rx's watermark. + unacked: Unacked::new(largest_acked, &log_id), + }), + { + backoff.reset(); + Conn::Connected { sink, stream } + }, + ) + } + Err(err) => match err { + HandshakeError::SendFailed(msg) => { + tracing::debug!( + "session {}.{}: handshake send failed: {}", + link.dest(), + session_id, + msg + ); + ( + State::Running(Deliveries { + outbox, + unacked: Unacked::new(largest_acked, &log_id), + }), + Conn::reconnect(backoff), + ) + } + err => { + let error_msg = format!( + "session {}.{}: handshake failed: {}", + link.dest(), + session_id, + err + ); + tracing::error!(error_msg); + ( + State::Closing { + deliveries: Deliveries { + outbox, + unacked: Unacked::new( + largest_acked, + &log_id, + ), + }, + reason: error_msg, + }, + Conn::reconnect_with_default(), + ) + } + }, + } } Err(err) => { tracing::debug!( @@ -1152,6 +1260,13 @@ impl ServerConn { let serialized = serialize_ack(next_seq - 1); futures::SinkExt::send(sink, serialized).await } + + pub async fn send_handshake_response(&mut self, rx_session_id: u64) -> anyhow::Result<()> { + let serialized = bincode::serialize(&rx_session_id)?; + futures::SinkExt::send(&mut self.sink, serialized.into()) + .await + .map_err(anyhow::Error::from) + } } /// An MVar is a primitive that combines synchronization and the exchange @@ 
-1240,6 +1355,10 @@ struct Next { ack: u64, } +/// Unique NetRx session identifier for the process. Used during handshake +/// to determine whether or not to allow a reconnection from a NetTx. +static GLOBAL_RX_SESSION_ID: LazyLock = LazyLock::new(rand::random); + /// Manages persistent sessions, ensuring that only one connection can own /// a session at a time, and arranging for session handover. #[derive(Clone)] @@ -1264,9 +1383,10 @@ impl SessionManager { S: AsyncRead + AsyncWrite + Send + 'static + Unpin, M: RemoteMessage, { - let session_id = conn.handshake::().await?; + let tx_session_id = conn.handshake::().await?; + conn.send_handshake_response(*GLOBAL_RX_SESSION_ID).await?; - let session_var = match self.sessions.entry(session_id) { + let session_var = match self.sessions.entry(tx_session_id) { Entry::Occupied(entry) => entry.get().clone(), Entry::Vacant(entry) => { // We haven't seen this session before. We begin with seq=0 and ack=0. @@ -1277,7 +1397,7 @@ impl SessionManager { }; let next = session_var.take().await; - let (next, res) = conn.process(session_id, tx, cancel_token, next).await; + let (next, res) = conn.process(tx_session_id, tx, cancel_token, next).await; session_var.put(next).await; if let Err(ref err) = res { @@ -2561,6 +2681,16 @@ mod tests { ) .await .unwrap(); + + // Process handshake + bincode::deserialize::( + tokio_stream::StreamExt::next(framed) + .await + .unwrap() + .unwrap() + .as_ref(), + ) + .unwrap(); } for (seq, message) in messages { @@ -2669,6 +2799,31 @@ mod tests { }; } + #[async_timed_test(timeout_secs = 60)] + async fn test_server_session_handshake_response() { + let config = config::global::lock(); + let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1); + + let manager = SessionManager::new(); + let (_handle, mut framed, _rx, _cancel_token) = serve::(&manager).await; + + framed + .send(bincode::serialize(&Frame::::Init(123)).unwrap().into()) + .await + .unwrap(); + + let rx_session_id = 
bincode::deserialize::( + tokio_stream::StreamExt::next(&mut framed) + .await + .unwrap() + .unwrap() + .as_ref(), + ) + .unwrap(); + + assert_eq!(rx_session_id, *GLOBAL_RX_SESSION_ID); + } + #[async_timed_test(timeout_secs = 60)] async fn test_ack_from_server_session() { let config = config::global::lock(); @@ -2765,8 +2920,11 @@ mod tests { assert_eq!(frame, expected, "from ln={loc}"); } + static RX_SESSION_ID: u64 = 456; + async fn verify_stream( stream: &mut SplitStream>, + sink: &mut SplitSink, Bytes>, expects: &[(u64, M)], expect_session_id: Option, loc: u32, @@ -2783,6 +2941,8 @@ mod tests { assert_eq!(session_id, expected_id, "from ln={loc}"); } + send_handshake_response(sink, RX_SESSION_ID).await; + for expect in expects { verify_message(stream, expect.clone(), loc).await; } @@ -2790,6 +2950,15 @@ mod tests { session_id } + async fn send_handshake_response( + sink: &mut SplitSink, Bytes>, + handshake_response: u64, + ) { + sink.send(bincode::serialize(&handshake_response).unwrap().into()) + .await + .unwrap(); + } + async fn net_tx_send(tx: &NetTx, msgs: &[u64]) { for msg in msgs { tx.try_post(*msg, unused_return_channel()).unwrap(); @@ -2809,6 +2978,7 @@ mod tests { let (mut sink, mut stream) = take_receiver(&receiver_storage).await; let id = verify_stream( &mut stream, + &mut sink, &[(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)], None, line!(), @@ -2827,8 +2997,15 @@ mod tests { // Sent a new message to verify all sent messages will not be resent. net_tx_send(&tx, &[105]).await; { - let (_sink, mut stream) = take_receiver(&receiver_storage).await; - verify_stream(&mut stream, &[(5, 105)], Some(session_id), line!()).await; + let (mut sink, mut stream) = take_receiver(&receiver_storage).await; + verify_stream( + &mut stream, + &mut sink, + &[(5, 105)], + Some(session_id), + line!(), + ) + .await; // client DuplexStream is dropped here. This breaks the connection. 
}; } @@ -2855,6 +3032,7 @@ mod tests { let (mut sink, mut stream) = take_receiver(&receiver_storage).await; let id = verify_stream( &mut stream, + &mut sink, &[(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)], session_id, line!(), @@ -2881,9 +3059,10 @@ mod tests { { let client = receiver_storage.take().await; let framed = Framed::new(client, build_codec()); - let (_sink, mut stream) = futures::StreamExt::split(framed); + let (mut sink, mut stream) = futures::StreamExt::split(framed); verify_stream( &mut stream, + &mut sink, &[(2, 102), (3, 103), (4, 104)], session_id, line!(), @@ -2902,6 +3081,7 @@ mod tests { let (mut sink, mut stream) = take_receiver(&receiver_storage).await; verify_stream( &mut stream, + &mut sink, &[ // From the 1st send. (2, 102), @@ -2939,6 +3119,7 @@ mod tests { let (mut sink, mut stream) = take_receiver(&receiver_storage).await; verify_stream( &mut stream, + &mut sink, &[ // From the 1st send. (4, 104), @@ -2966,9 +3147,10 @@ mod tests { for _ in 0..n { { - let (_sink, mut stream) = take_receiver(&receiver_storage).await; + let (mut sink, mut stream) = take_receiver(&receiver_storage).await; verify_stream( &mut stream, + &mut sink, &[ // From the 2nd send. (8, 108), @@ -2983,6 +3165,89 @@ mod tests { } } + #[async_timed_test(timeout_secs = 60)] + async fn test_handshake_timeout() { + let link = MockLink::::new(); + // Override the default (1m) for the purposes of this test. 
+ let config = config::global::lock(); + let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1)); + let mut tx = NetTx::::new(link); + net_tx_send(&tx, &[100]).await; + // Never send the handshake response + verify_tx_closed(&mut tx.status, "timed out waiting for handshake response").await; + } + + #[tracing_test::traced_test] + #[async_timed_test(timeout_secs = 60)] + async fn test_handshake_incorrect_rx_session_id_failure() { + let link = MockLink::::new(); + let receiver_storage = link.receiver_storage(); + let mut tx = NetTx::::new(link); + net_tx_send(&tx, &[100]).await; + + { + let (mut sink, _stream) = take_receiver(&receiver_storage).await; + send_handshake_response(&mut sink, 1234321).await; + // Wait for handshake response to be processed by NetTx before dropping sink/stream. Otherwise + // the channel will be closed and we will get the wrong error. + RealClock.sleep(tokio::time::Duration::from_secs(3)).await; + } + + { + let (mut sink, _stream) = take_receiver(&receiver_storage).await; + send_handshake_response(&mut sink, 123).await; + // Ditto + RealClock.sleep(tokio::time::Duration::from_secs(3)).await; + } + + verify_tx_closed( + &mut tx.status, + format!( + "mismatched rx session ids: {} (stored) vs. 
{} (received)", + 1234321, 123 + ) + .as_str(), + ) + .await; + } + + #[tracing_test::traced_test] + #[async_timed_test(timeout_secs = 60)] + async fn test_handshake_deserialization_failure() { + let link = MockLink::::new(); + let receiver_storage = link.receiver_storage(); + let mut tx = NetTx::::new(link); + net_tx_send(&tx, &[100]).await; + + let (mut sink, _stream) = take_receiver(&receiver_storage).await; + sink.send(Bytes::from_static(b"bad")).await.unwrap(); + + verify_tx_closed( + &mut tx.status, + "handshake failed: io error: unexpected end of file", + ) + .await; + } + + #[tracing_test::traced_test] + #[async_timed_test(timeout_secs = 60)] + async fn test_handshake_channel_closed_failure() { + let link = MockLink::::new(); + let receiver_storage = link.receiver_storage(); + let mut tx = NetTx::::new(link); + net_tx_send(&tx, &[100]).await; + + { + let _ = take_receiver(&receiver_storage).await; + } + + verify_tx_closed( + &mut tx.status, + "handshake failed: channel closed waiting for handshake response", + ) + .await; + } + #[async_timed_test(timeout_secs = 15)] async fn test_ack_before_redelivery_in_net_tx() { let link = MockLink::::new(); @@ -2994,7 +3259,7 @@ mod tests { let (return_channel_tx, return_channel_rx) = oneshot::channel(); net_tx.try_post(100, return_channel_tx).unwrap(); let (mut sink, mut stream) = take_receiver(&receiver_storage).await; - verify_stream(&mut stream, &[(0, 100)], None, line!()).await; + verify_stream(&mut stream, &mut sink, &[(0, 100)], None, line!()).await; // ack it sink.send(serialize_ack(0)).await.unwrap(); // confirm Tx received ack @@ -3037,7 +3302,7 @@ mod tests { tx.try_post(100, unused_return_channel()).unwrap(); let (mut sink, mut stream) = take_receiver(&receiver_storage).await; // Confirm message is sent to rx. 
- verify_stream(&mut stream, &[(0, 100)], None, line!()).await; + verify_stream(&mut stream, &mut sink, &[(0, 100)], None, line!()).await; // ack it sink.send(serialize_ack(0)).await.unwrap(); RealClock.sleep(Duration::from_secs(3)).await;