diff --git a/hyperactor/src/channel/net.rs b/hyperactor/src/channel/net.rs index 3557af98f..5275732d2 100644 --- a/hyperactor/src/channel/net.rs +++ b/hyperactor/src/channel/net.rs @@ -28,6 +28,7 @@ use std::mem::take; use std::net::ToSocketAddrs; use std::pin::Pin; use std::sync::Arc; +use std::sync::LazyLock; use std::task::Poll; use backoff::ExponentialBackoffBuilder; @@ -171,6 +172,65 @@ fn build_codec() -> LengthDelimitedCodec { .new_codec() } +#[derive(thiserror::Error, Debug)] +enum HandshakeError { + #[error(transparent)] + DeserializationFailed(bincode::Error), + #[error("sending init frame failed: {0}")] + SendFailed(String), + #[error("timed out waiting for handshake response")] + Timeout, + #[error("channel closed waiting for handshake response")] + ChannelClosed, + #[error(transparent)] + IoError(io::Error), + #[error("mismatched rx session ids: {0} (stored) vs. {1} (received)")] + MismatchedRxSessionIds(u64, u64), +} + +async fn do_handshake( + tx_session_id: u64, + rx_session_id: &mut Option, + sink: &mut SplitSink, Bytes>, + stream: &mut SplitStream>, +) -> Result<(), HandshakeError> { + let data = bincode::serialize(&Frame::::Init(tx_session_id)) + .expect("unexpected serialization error"); + + if let Err(err) = sink.send(data.into()).await { + return Err(HandshakeError::SendFailed(format!("{}", err))); + } + + let handshake_response = tokio::select! { + _ = RealClock.sleep(config::global::get(config::MESSAGE_DELIVERY_TIMEOUT)) => None, + response = tokio_stream::StreamExt::next(stream) => Some(response) + }; + + match handshake_response { + None => Err(HandshakeError::Timeout), + Some(None) => Err(HandshakeError::ChannelClosed), + Some(Some(handshake_response)) => match handshake_response { + Err(err) => Err(HandshakeError::IoError(err)), + Ok(handshake_response) => match bincode::deserialize::(&handshake_response) { + Err(err) => Err(HandshakeError::DeserializationFailed(err)), + Ok(received_rx_session_id) => { + if rx_session_id + .is_some_and(|rx_session_id| rx_session_id != received_rx_session_id) + { + Err(HandshakeError::MismatchedRxSessionIds( + rx_session_id.unwrap(), + received_rx_session_id, + )) + } else { + rx_session_id.replace(received_rx_session_id); + Ok(()) + } + } + }, + }, + } +} + /// A Tx implemented on top of a Link. The Tx manages the link state, /// reconnections, etc. #[derive(Debug)] @@ -499,6 +559,7 @@ impl NetTx { } let session_id = rand::random(); + let mut rx_session_id = None; let log_id = format!("session {}.{}", link.dest(), session_id); let mut state = State::init(&log_id); let mut conn = Conn::reconnect_with_default(); @@ -716,10 +777,7 @@ impl NetTx { match link.connect().await { Ok(stream) => { let framed = Framed::new(stream, build_codec()); - let (mut sink, stream) = futures::StreamExt::split(framed); - let data = bincode::serialize(&Frame::::Init(session_id)) - .expect("unexpected serialization error"); - let initialized = sink.send(data.into()).await.is_ok(); + let (mut sink, mut stream) = futures::StreamExt::split(framed); metrics::CHANNEL_CONNECTIONS.add( 1, @@ -732,21 +790,71 @@ impl NetTx { // Need to resend unacked after reconnecting. let largest_acked = unacked.largest_acked; outbox.requeue_unacked(unacked); - ( - State::Running(Deliveries { - outbox, - // unacked messages are put back to outbox. So they are not - // considered as "sent yet unacked" message anymore. But - // we still want to keep `largest_acked` to known Rx's watermark. - unacked: Unacked::new(largest_acked, &log_id), - }), - if initialized { - backoff.reset(); - Conn::Connected { sink, stream } - } else { - Conn::reconnect(backoff) - }, + + let handshake_result = do_handshake::<_, M>( + session_id, + &mut rx_session_id, + &mut sink, + &mut stream, ) + .await; + + match handshake_result { + Ok(_) => { + ( + State::Running(Deliveries { + outbox, + // unacked messages are put back to outbox. So they are not + // considered as "sent yet unacked" message anymore. But + // we still want to keep `largest_acked` to known Rx's watermark. + unacked: Unacked::new(largest_acked, &log_id), + }), + { + backoff.reset(); + Conn::Connected { sink, stream } + }, + ) + } + Err(err) => match err { + HandshakeError::SendFailed(msg) => { + tracing::debug!( + "session {}.{}: handshake send failed: {}", + link.dest(), + session_id, + msg + ); + ( + State::Running(Deliveries { + outbox, + unacked: Unacked::new(largest_acked, &log_id), + }), + Conn::reconnect(backoff), + ) + } + err => { + let error_msg = format!( + "session {}.{}: handshake failed: {}", + link.dest(), + session_id, + err + ); + tracing::error!(error_msg); + ( + State::Closing { + deliveries: Deliveries { + outbox, + unacked: Unacked::new( + largest_acked, + &log_id, + ), + }, + reason: error_msg, + }, + Conn::reconnect_with_default(), + ) + } + }, + } } Err(err) => { tracing::debug!( @@ -1152,6 +1260,13 @@ impl ServerConn { let serialized = serialize_ack(next_seq - 1); futures::SinkExt::send(sink, serialized).await } + + pub async fn send_handshake_response(&mut self, rx_session_id: u64) -> anyhow::Result<()> { + let serialized = bincode::serialize(&rx_session_id)?; + futures::SinkExt::send(&mut self.sink, serialized.into()) + .await + .map_err(anyhow::Error::from) + } } /// An MVar is a primitive that combines synchronization and the exchange @@ -1240,6 +1355,10 @@ struct Next { ack: u64, } +/// Unique NetRx session identifier for the process. Used during handshake +/// to determine whether or not to allow a reconnection from a NetTx. +static GLOBAL_RX_SESSION_ID: LazyLock = LazyLock::new(rand::random); + /// Manages persistent sessions, ensuring that only one connection can own /// a session at a time, and arranging for session handover. #[derive(Clone)] @@ -1264,9 +1383,10 @@ impl SessionManager { S: AsyncRead + AsyncWrite + Send + 'static + Unpin, M: RemoteMessage, { - let session_id = conn.handshake::().await?; + let tx_session_id = conn.handshake::().await?; + conn.send_handshake_response(*GLOBAL_RX_SESSION_ID).await?; - let session_var = match self.sessions.entry(session_id) { + let session_var = match self.sessions.entry(tx_session_id) { Entry::Occupied(entry) => entry.get().clone(), Entry::Vacant(entry) => { // We haven't seen this session before. We begin with seq=0 and ack=0. @@ -1277,7 +1397,7 @@ impl SessionManager { }; let next = session_var.take().await; - let (next, res) = conn.process(session_id, tx, cancel_token, next).await; + let (next, res) = conn.process(tx_session_id, tx, cancel_token, next).await; session_var.put(next).await; if let Err(ref err) = res { @@ -2561,6 +2681,16 @@ mod tests { ) .await .unwrap(); + + // Process handshake + bincode::deserialize::( + tokio_stream::StreamExt::next(framed) + .await + .unwrap() + .unwrap() + .as_ref(), + ) + .unwrap(); } for (seq, message) in messages { @@ -2669,6 +2799,31 @@ mod tests { }; } + #[async_timed_test(timeout_secs = 60)] + async fn test_server_session_handshake_response() { + let config = config::global::lock(); + let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1); + + let manager = SessionManager::new(); + let (_handle, mut framed, _rx, _cancel_token) = serve::(&manager).await; + + framed + .send(bincode::serialize(&Frame::::Init(123)).unwrap().into()) + .await + .unwrap(); + + let rx_session_id = bincode::deserialize::( + tokio_stream::StreamExt::next(&mut framed) + .await + .unwrap() + .unwrap() + .as_ref(), + ) + .unwrap(); + + assert_eq!(rx_session_id, *GLOBAL_RX_SESSION_ID); + } + #[async_timed_test(timeout_secs = 60)] async fn test_ack_from_server_session() { let config = config::global::lock(); @@ -2765,8 +2920,11 @@ mod tests { assert_eq!(frame, expected, "from ln={loc}"); } + static RX_SESSION_ID: u64 = 456; + async fn verify_stream( stream: &mut SplitStream>, + sink: &mut SplitSink, Bytes>, expects: &[(u64, M)], expect_session_id: Option, loc: u32, @@ -2783,6 +2941,8 @@ mod tests { assert_eq!(session_id, expected_id, "from ln={loc}"); } + send_handshake_response(sink, RX_SESSION_ID).await; + for expect in expects { verify_message(stream, expect.clone(), loc).await; } @@ -2790,6 +2950,15 @@ mod tests { session_id } + async fn send_handshake_response( + sink: &mut SplitSink, Bytes>, + handshake_response: u64, + ) { + sink.send(bincode::serialize(&handshake_response).unwrap().into()) + .await + .unwrap(); + } + async fn net_tx_send(tx: &NetTx, msgs: &[u64]) { for msg in msgs { tx.try_post(*msg, unused_return_channel()).unwrap(); @@ -2809,6 +2978,7 @@ mod tests { let (mut sink, mut stream) = take_receiver(&receiver_storage).await; let id = verify_stream( &mut stream, + &mut sink, &[(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)], None, line!(), @@ -2827,8 +2997,15 @@ mod tests { // Sent a new message to verify all sent messages will not be resent. net_tx_send(&tx, &[105]).await; { - let (_sink, mut stream) = take_receiver(&receiver_storage).await; - verify_stream(&mut stream, &[(5, 105)], Some(session_id), line!()).await; + let (mut sink, mut stream) = take_receiver(&receiver_storage).await; + verify_stream( + &mut stream, + &mut sink, + &[(5, 105)], + Some(session_id), + line!(), + ) + .await; // client DuplexStream is dropped here. This breaks the connection. }; } @@ -2855,6 +3032,7 @@ mod tests { let (mut sink, mut stream) = take_receiver(&receiver_storage).await; let id = verify_stream( &mut stream, + &mut sink, &[(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)], session_id, line!(), @@ -2881,9 +3059,10 @@ mod tests { { let client = receiver_storage.take().await; let framed = Framed::new(client, build_codec()); - let (_sink, mut stream) = futures::StreamExt::split(framed); + let (mut sink, mut stream) = futures::StreamExt::split(framed); verify_stream( &mut stream, + &mut sink, &[(2, 102), (3, 103), (4, 104)], session_id, line!(), @@ -2902,6 +3081,7 @@ mod tests { let (mut sink, mut stream) = take_receiver(&receiver_storage).await; verify_stream( &mut stream, + &mut sink, &[ // From the 1st send. (2, 102), @@ -2939,6 +3119,7 @@ mod tests { let (mut sink, mut stream) = take_receiver(&receiver_storage).await; verify_stream( &mut stream, + &mut sink, &[ // From the 1st send. (4, 104), @@ -2966,9 +3147,10 @@ mod tests { for _ in 0..n { { - let (_sink, mut stream) = take_receiver(&receiver_storage).await; + let (mut sink, mut stream) = take_receiver(&receiver_storage).await; verify_stream( &mut stream, + &mut sink, &[ // From the 2nd send. (8, 108), @@ -2983,6 +3165,89 @@ mod tests { } } + #[async_timed_test(timeout_secs = 60)] + async fn test_handshake_timeout() { + let link = MockLink::::new(); + // Override the default (1m) for the purposes of this test. + let config = config::global::lock(); + let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1)); + let mut tx = NetTx::::new(link); + net_tx_send(&tx, &[100]).await; + // Never send the handshake response + verify_tx_closed(&mut tx.status, "timed out waiting for handshake response").await; + } + + #[tracing_test::traced_test] + #[async_timed_test(timeout_secs = 60)] + async fn test_handshake_incorrect_rx_session_id_failure() { + let link = MockLink::::new(); + let receiver_storage = link.receiver_storage(); + let mut tx = NetTx::::new(link); + net_tx_send(&tx, &[100]).await; + + { + let (mut sink, _stream) = take_receiver(&receiver_storage).await; + send_handshake_response(&mut sink, 1234321).await; + // Wait for handshake response to be processed by NetTx before dropping sink/stream. Otherwise + // the channel will be closed and we will get the wrong error. + RealClock.sleep(tokio::time::Duration::from_secs(3)).await; + } + + { + let (mut sink, _stream) = take_receiver(&receiver_storage).await; + send_handshake_response(&mut sink, 123).await; + // Ditto + RealClock.sleep(tokio::time::Duration::from_secs(3)).await; + } + + verify_tx_closed( + &mut tx.status, + format!( + "mismatched rx session ids: {} (stored) vs. {} (received)", + 1234321, 123 + ) + .as_str(), + ) + .await; + } + + #[tracing_test::traced_test] + #[async_timed_test(timeout_secs = 60)] + async fn test_handshake_deserialization_failure() { + let link = MockLink::::new(); + let receiver_storage = link.receiver_storage(); + let mut tx = NetTx::::new(link); + net_tx_send(&tx, &[100]).await; + + let (mut sink, _stream) = take_receiver(&receiver_storage).await; + sink.send(Bytes::from_static(b"bad")).await.unwrap(); + + verify_tx_closed( + &mut tx.status, + "handshake failed: io error: unexpected end of file", + ) + .await; + } + + #[tracing_test::traced_test] + #[async_timed_test(timeout_secs = 60)] + async fn test_handshake_channel_closed_failure() { + let link = MockLink::::new(); + let receiver_storage = link.receiver_storage(); + let mut tx = NetTx::::new(link); + net_tx_send(&tx, &[100]).await; + + { + let _ = take_receiver(&receiver_storage).await; + } + + verify_tx_closed( + &mut tx.status, + "handshake failed: channel closed waiting for handshake response", + ) + .await; + } + #[async_timed_test(timeout_secs = 15)] async fn test_ack_before_redelivery_in_net_tx() { let link = MockLink::::new(); @@ -2994,7 +3259,7 @@ mod tests { let (return_channel_tx, return_channel_rx) = oneshot::channel(); net_tx.try_post(100, return_channel_tx).unwrap(); let (mut sink, mut stream) = take_receiver(&receiver_storage).await; - verify_stream(&mut stream, &[(0, 100)], None, line!()).await; + verify_stream(&mut stream, &mut sink, &[(0, 100)], None, line!()).await; // ack it sink.send(serialize_ack(0)).await.unwrap(); // confirm Tx received ack @@ -3037,7 +3302,7 @@ mod tests { tx.try_post(100, unused_return_channel()).unwrap(); let (mut sink, mut stream) = take_receiver(&receiver_storage).await; // Confirm message is sent to rx. - verify_stream(&mut stream, &[(0, 100)], None, line!()).await; + verify_stream(&mut stream, &mut sink, &[(0, 100)], None, line!()).await; // ack it sink.send(serialize_ack(0)).await.unwrap(); RealClock.sleep(Duration::from_secs(3)).await;