diff --git a/fuzz/fuzz_targets/protocol_reader.rs b/fuzz/fuzz_targets/protocol_reader.rs index 641c0fa..b1e4877 100644 --- a/fuzz/fuzz_targets/protocol_reader.rs +++ b/fuzz/fuzz_targets/protocol_reader.rs @@ -143,7 +143,8 @@ where let transport = MockTransport::new(transport_data); // setup messenger - let mut messenger = Messenger::new(transport, message_size, Arc::from(DEFAULT_CLIENT_ID)); + let mut messenger = + Messenger::new(transport, message_size, Arc::from(DEFAULT_CLIENT_ID), None); messenger.override_version_ranges(HashMap::from([( api_key, ApiVersionRange::new(api_version, api_version), diff --git a/src/client/mod.rs b/src/client/mod.rs index 7661332..5120bb6 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -49,6 +49,7 @@ pub struct ClientBuilder { sasl_config: Option, backoff_config: Arc, connect_timeout: Option, + timeout: Option, } impl ClientBuilder { @@ -63,6 +64,7 @@ impl ClientBuilder { sasl_config: None, backoff_config: Default::default(), connect_timeout: Some(Duration::from_secs(30)), + timeout: None, } } @@ -117,6 +119,15 @@ impl ClientBuilder { self } + /// Set the timeout on requests to the broker. + /// By setting this to `None`, requests will never time out unless + /// interrupted by an external event. + /// The default timeout is `None`. + pub fn timeout(mut self, timeout: Option) -> Self { + self.timeout = timeout; + self + } + /// Build [`Client`]. 
pub async fn build(self) -> Result { let brokers = Arc::new(BrokerConnector::new( @@ -129,6 +140,7 @@ impl ClientBuilder { self.max_message_size, Arc::clone(&self.backoff_config), self.connect_timeout, + self.timeout, )); brokers.refresh_metadata().await?; diff --git a/src/connection.rs b/src/connection.rs index 78e0d78..9d0392c 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -80,6 +80,7 @@ impl Display for MultiError { trait ConnectionHandler { type R: RequestHandler + Send + Sync; + #[allow(clippy::too_many_arguments, reason = "Method is internal")] fn connect( &self, client_id: Arc, @@ -87,6 +88,7 @@ trait ConnectionHandler { socks5_proxy: Option, sasl_config: Option, max_message_size: usize, + connection_timeout: Option, timeout: Option, ) -> impl Future>> + Send; } @@ -150,6 +152,7 @@ impl ConnectionHandler for BrokerRepresentation { socks5_proxy: Option, sasl_config: Option, max_message_size: usize, + connection_timeout: Option, timeout: Option, ) -> Result> { let url = self.url(); @@ -158,14 +161,19 @@ impl ConnectionHandler for BrokerRepresentation { url = url.as_str(), "Establishing new connection", ); - let transport = Transport::connect(&url, tls_config, socks5_proxy, timeout) + let transport = Transport::connect(&url, tls_config, socks5_proxy, connection_timeout) .await .map_err(|error| Error::Transport { broker: url.to_string(), error, })?; - let mut messenger = Messenger::new(BufStream::new(transport), max_message_size, client_id); + let mut messenger = Messenger::new( + BufStream::new(transport), + max_message_size, + client_id, + timeout, + ); messenger.sync_versions().await?; if let Some(sasl_config) = sasl_config { messenger.do_sasl(sasl_config).await?; @@ -217,6 +225,12 @@ pub struct BrokerConnector { /// Timeout for connection attempts to the broker. connect_timeout: Option, + + /// Timeout for requests. + /// + /// If set, requests will time out after the given duration. + /// If not set, requests will not time out. 
+ timeout: Option, } impl BrokerConnector { @@ -230,6 +244,7 @@ impl BrokerConnector { max_message_size: usize, backoff_config: Arc, connect_timeout: Option, + timeout: Option, ) -> Self { Self { bootstrap_brokers, @@ -243,6 +258,7 @@ impl BrokerConnector { sasl_config, max_message_size, connect_timeout, + timeout, } } @@ -340,6 +356,7 @@ impl BrokerConnector { self.sasl_config.clone(), self.max_message_size, self.connect_timeout, + self.timeout, ) .await?; Ok(Some(connection)) @@ -455,6 +472,7 @@ impl BrokerCache for &BrokerConnector { self.sasl_config.clone(), self.max_message_size, self.connect_timeout, + self.timeout, ) .await?; @@ -493,6 +511,7 @@ async fn connect_to_a_broker_with_retry( sasl_config: Option, max_message_size: usize, connect_timeout: Option, + timeout: Option, ) -> Result> where B: ConnectionHandler + Send + Sync, @@ -513,6 +532,7 @@ where sasl_config.clone(), max_message_size, connect_timeout, + timeout, ) .await; @@ -825,6 +845,7 @@ mod tests { _sasl_config: Option, _max_message_size: usize, _connect_timeout: Option, + _timeout: Option, ) -> Result> { (self.conn)() } @@ -854,6 +875,7 @@ mod tests { Default::default(), Default::default(), None, + None, ) .await .unwrap(); diff --git a/src/messenger.rs b/src/messenger.rs index ab42083..b8e42f7 100644 --- a/src/messenger.rs +++ b/src/messenger.rs @@ -16,6 +16,7 @@ use rsasl::{ mechname::MechanismNameError, prelude::{Mechname, SASLError, SessionError}, }; +use std::time::Duration; use thiserror::Error; use tokio::{ io::{AsyncRead, AsyncWrite, AsyncWriteExt, WriteHalf}, @@ -133,6 +134,12 @@ pub struct Messenger { /// Join handle for the background worker that fetches responses. join_handle: JoinHandle<()>, + + /// Timeout for requests. + /// + /// If set, requests will time out after the given duration. + /// If not set, requests will not time out. 
+ timeout: Option, } #[derive(Error, Debug)] @@ -218,7 +225,12 @@ impl Messenger where RW: AsyncRead + AsyncWrite + Send + 'static, { - pub fn new(stream: RW, max_message_size: usize, client_id: Arc) -> Self { + pub fn new( + stream: RW, + max_message_size: usize, + client_id: Arc, + timeout: Option, + ) -> Self { let (stream_read, stream_write) = tokio::io::split(stream); let state = Arc::new(Mutex::new(MessengerState::RequestMap(HashMap::default()))); let state_captured = Arc::clone(&state); @@ -304,6 +316,7 @@ where version_ranges: HashMap::new(), state, join_handle, + timeout, } } @@ -404,7 +417,21 @@ where self.send_message(buf).await?; cleanup_on_cancel.message_sent(); - let mut response = rx.await.expect("Who closed this channel?!")?; + let mut response = if let Some(timeout) = self.timeout { + // If a request times out, return a `RequestError::IO` with a timeout error. + // This allows the backoff mechanism to detect transport issues and re-establish the connection as needed. + // + // Typically, timeouts occur due to abrupt TCP connection loss (e.g., a disconnected cable). + tokio::time::timeout(timeout, rx).await.map_err(|_| { + RequestError::IO(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "Request timed out", + )) + })? 
+ } else { + rx.await + } + .expect("Who closed this channel?!")?; let body = R::ResponseBody::read_versioned(&mut response.data, body_api_version)?; // check if we fully consumed the message, otherwise there might be a bug in our protocol code @@ -818,7 +845,7 @@ mod tests { #[tokio::test] async fn test_sync_versions_ok() { let (sim, rx) = MessageSimulator::new(); - let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID)); + let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None); // construct response let mut msg = vec![]; @@ -855,7 +882,7 @@ mod tests { #[tokio::test] async fn test_sync_versions_ignores_error_code() { let (sim, rx) = MessageSimulator::new(); - let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID)); + let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None); // construct error response let mut msg = vec![]; @@ -918,7 +945,7 @@ mod tests { #[tokio::test] async fn test_sync_versions_ignores_read_code() { let (sim, rx) = MessageSimulator::new(); - let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID)); + let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None); // construct error response let mut msg = vec![]; @@ -969,7 +996,7 @@ mod tests { #[tokio::test] async fn test_sync_versions_err_flipped_range() { let (sim, rx) = MessageSimulator::new(); - let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID)); + let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None); // construct response let mut msg = vec![]; @@ -1002,7 +1029,7 @@ mod tests { #[tokio::test] async fn test_sync_versions_ignores_garbage() { let (sim, rx) = MessageSimulator::new(); - let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID)); + let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None); // construct response let mut msg = vec![]; @@ -1066,7 +1093,7 @@ mod 
tests { #[tokio::test] async fn test_sync_versions_err_no_working_version() { let (sim, rx) = MessageSimulator::new(); - let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID)); + let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None); // construct error response for (i, v) in ((ApiVersionsRequest::API_VERSION_RANGE.min().0.0) @@ -1105,7 +1132,7 @@ mod tests { #[tokio::test] async fn test_poison_hangup() { let (sim, rx) = MessageSimulator::new(); - let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID)); + let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None); messenger.set_version_ranges(HashMap::from([( ApiKey::ListOffsets, ListOffsetsRequest::API_VERSION_RANGE, @@ -1127,7 +1154,7 @@ mod tests { #[tokio::test] async fn test_poison_negative_message_size() { let (sim, rx) = MessageSimulator::new(); - let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID)); + let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None); messenger.set_version_ranges(HashMap::from([( ApiKey::ListOffsets, ListOffsetsRequest::API_VERSION_RANGE, @@ -1160,7 +1187,7 @@ mod tests { #[tokio::test] async fn test_broken_msg_header_does_not_poison() { let (sim, rx) = MessageSimulator::new(); - let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID)); + let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None); messenger.set_version_ranges(HashMap::from([( ApiKey::ApiVersions, ApiVersionsRequest::API_VERSION_RANGE, @@ -1206,7 +1233,7 @@ mod tests { let (tx_front, rx_middle) = tokio::io::duplex(1); let (tx_middle, mut rx_back) = tokio::io::duplex(1); - let mut messenger = Messenger::new(tx_front, 1_000, Arc::from(DEFAULT_CLIENT_ID)); + let mut messenger = Messenger::new(tx_front, 1_000, Arc::from(DEFAULT_CLIENT_ID), None); // create two barriers: // - pause: will be passed after 3 bytes were sent by the client @@ -1352,6 
+1379,31 @@ mod tests { handle_network.abort(); } + #[tokio::test] + async fn test_request_timeout() { + let (tx, _rx) = tokio::io::duplex(1_000); + let mut messenger = Messenger::new( + tx, + 1_000, + Arc::from(DEFAULT_CLIENT_ID), + Some(Duration::from_millis(200)), + ); + messenger.set_version_ranges(HashMap::from([( + ApiKey::ApiVersions, + ApiVersionsRequest::API_VERSION_RANGE, + )])); + + let err = messenger + .request(ApiVersionsRequest { + client_software_name: Some(CompactString(String::from("foo"))), + client_software_version: Some(CompactString(String::from("bar"))), + tagged_fields: Some(TaggedFields::default()), + }) + .await + .unwrap_err(); + assert_matches!(err, RequestError::IO(e) if e.kind() == std::io::ErrorKind::TimedOut); + } + #[derive(Debug)] enum Message { Send(Vec),