Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion fuzz/fuzz_targets/protocol_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ where
let transport = MockTransport::new(transport_data);

// setup messenger
let mut messenger = Messenger::new(transport, message_size, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger =
Messenger::new(transport, message_size, Arc::from(DEFAULT_CLIENT_ID), None);
messenger.override_version_ranges(HashMap::from([(
api_key,
ApiVersionRange::new(api_version, api_version),
Expand Down
12 changes: 12 additions & 0 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ pub struct ClientBuilder {
sasl_config: Option<SaslConfig>,
backoff_config: Arc<BackoffConfig>,
connect_timeout: Option<Duration>,
timeout: Option<Duration>,
}

impl ClientBuilder {
Expand All @@ -63,6 +64,7 @@ impl ClientBuilder {
sasl_config: None,
backoff_config: Default::default(),
connect_timeout: Some(Duration::from_secs(30)),
timeout: None,
}
}

Expand Down Expand Up @@ -117,6 +119,15 @@ impl ClientBuilder {
self
}

/// Set the timeout on requests to the broker.
/// By setting this to `None`, requests will never time out unless
/// interrupted by an external event.
/// The default timeout is `None`.
pub fn timeout(mut self, timeout: Option<Duration>) -> Self {
self.timeout = timeout;
self
}

/// Build [`Client`].
pub async fn build(self) -> Result<Client> {
let brokers = Arc::new(BrokerConnector::new(
Expand All @@ -129,6 +140,7 @@ impl ClientBuilder {
self.max_message_size,
Arc::clone(&self.backoff_config),
self.connect_timeout,
self.timeout,
));
brokers.refresh_metadata().await?;

Expand Down
26 changes: 24 additions & 2 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,15 @@ impl Display for MultiError {
trait ConnectionHandler {
type R: RequestHandler + Send + Sync;

#[allow(clippy::too_many_arguments, reason = "Method is internal")]
fn connect(
&self,
client_id: Arc<str>,
tls_config: TlsConfig,
socks5_proxy: Option<String>,
sasl_config: Option<SaslConfig>,
max_message_size: usize,
connection_timeout: Option<Duration>,
timeout: Option<Duration>,
) -> impl Future<Output = Result<Arc<Self::R>>> + Send;
}
Expand Down Expand Up @@ -150,6 +152,7 @@ impl ConnectionHandler for BrokerRepresentation {
socks5_proxy: Option<String>,
sasl_config: Option<SaslConfig>,
max_message_size: usize,
connection_timeout: Option<Duration>,
timeout: Option<Duration>,
) -> Result<Arc<Self::R>> {
let url = self.url();
Expand All @@ -158,14 +161,19 @@ impl ConnectionHandler for BrokerRepresentation {
url = url.as_str(),
"Establishing new connection",
);
let transport = Transport::connect(&url, tls_config, socks5_proxy, timeout)
let transport = Transport::connect(&url, tls_config, socks5_proxy, connection_timeout)
.await
.map_err(|error| Error::Transport {
broker: url.to_string(),
error,
})?;

let mut messenger = Messenger::new(BufStream::new(transport), max_message_size, client_id);
let mut messenger = Messenger::new(
BufStream::new(transport),
max_message_size,
client_id,
timeout,
);
messenger.sync_versions().await?;
if let Some(sasl_config) = sasl_config {
messenger.do_sasl(sasl_config).await?;
Expand Down Expand Up @@ -217,6 +225,12 @@ pub struct BrokerConnector {

/// Timeout for connection attempts to the broker.
connect_timeout: Option<Duration>,

/// Timeout for requests.
///
/// If set, requests will timeout after the given duration.
/// If not set, requests will not timeout.
timeout: Option<Duration>,
}

impl BrokerConnector {
Expand All @@ -230,6 +244,7 @@ impl BrokerConnector {
max_message_size: usize,
backoff_config: Arc<BackoffConfig>,
connect_timeout: Option<Duration>,
timeout: Option<Duration>,
) -> Self {
Self {
bootstrap_brokers,
Expand All @@ -243,6 +258,7 @@ impl BrokerConnector {
sasl_config,
max_message_size,
connect_timeout,
timeout,
}
}

Expand Down Expand Up @@ -340,6 +356,7 @@ impl BrokerConnector {
self.sasl_config.clone(),
self.max_message_size,
self.connect_timeout,
self.timeout,
)
.await?;
Ok(Some(connection))
Expand Down Expand Up @@ -455,6 +472,7 @@ impl BrokerCache for &BrokerConnector {
self.sasl_config.clone(),
self.max_message_size,
self.connect_timeout,
self.timeout,
)
.await?;

Expand Down Expand Up @@ -493,6 +511,7 @@ async fn connect_to_a_broker_with_retry<B>(
sasl_config: Option<SaslConfig>,
max_message_size: usize,
connect_timeout: Option<Duration>,
timeout: Option<Duration>,
) -> Result<Arc<B::R>>
where
B: ConnectionHandler + Send + Sync,
Expand All @@ -513,6 +532,7 @@ where
sasl_config.clone(),
max_message_size,
connect_timeout,
timeout,
)
.await;

Expand Down Expand Up @@ -825,6 +845,7 @@ mod tests {
_sasl_config: Option<SaslConfig>,
_max_message_size: usize,
_connect_timeout: Option<Duration>,
_timeout: Option<Duration>,
) -> Result<Arc<Self::R>> {
(self.conn)()
}
Expand Down Expand Up @@ -854,6 +875,7 @@ mod tests {
Default::default(),
Default::default(),
None,
None,
)
.await
.unwrap();
Expand Down
76 changes: 64 additions & 12 deletions src/messenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use rsasl::{
mechname::MechanismNameError,
prelude::{Mechname, SASLError, SessionError},
};
use std::time::Duration;
use thiserror::Error;
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt, WriteHalf},
Expand Down Expand Up @@ -133,6 +134,12 @@ pub struct Messenger<RW> {

/// Join handle for the background worker that fetches responses.
join_handle: JoinHandle<()>,

/// Timeout for requests.
///
/// If set, requests will timeout after the given duration.
/// If not set, requests will not timeout.
timeout: Option<Duration>,
}

#[derive(Error, Debug)]
Expand Down Expand Up @@ -218,7 +225,12 @@ impl<RW> Messenger<RW>
where
RW: AsyncRead + AsyncWrite + Send + 'static,
{
pub fn new(stream: RW, max_message_size: usize, client_id: Arc<str>) -> Self {
pub fn new(
stream: RW,
max_message_size: usize,
client_id: Arc<str>,
timeout: Option<Duration>,
) -> Self {
let (stream_read, stream_write) = tokio::io::split(stream);
let state = Arc::new(Mutex::new(MessengerState::RequestMap(HashMap::default())));
let state_captured = Arc::clone(&state);
Expand Down Expand Up @@ -304,6 +316,7 @@ where
version_ranges: HashMap::new(),
state,
join_handle,
timeout,
}
}

Expand Down Expand Up @@ -404,7 +417,21 @@ where
self.send_message(buf).await?;
cleanup_on_cancel.message_sent();

let mut response = rx.await.expect("Who closed this channel?!")?;
let mut response = if let Some(timeout) = self.timeout {
// If a request times out, return a `RequestError::IO` with a timeout error.
// This allows the backoff mechanism to detect transport issues and re-establish the connection as needed.
//
// Typically, timeouts occur due to abrupt TCP connection loss (e.g., a disconnected cable).
tokio::time::timeout(timeout, rx).await.map_err(|_| {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a code comment that captures #285 (comment) here? That would be great for future maintainers & contributors.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resolved in 2c88e16

RequestError::IO(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"Request timed out",
))
})?
} else {
rx.await
}
.expect("Who closed this channel?!")?;
let body = R::ResponseBody::read_versioned(&mut response.data, body_api_version)?;

// check if we fully consumed the message, otherwise there might be a bug in our protocol code
Expand Down Expand Up @@ -818,7 +845,7 @@ mod tests {
#[tokio::test]
async fn test_sync_versions_ok() {
let (sim, rx) = MessageSimulator::new();
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);

// construct response
let mut msg = vec![];
Expand Down Expand Up @@ -855,7 +882,7 @@ mod tests {
#[tokio::test]
async fn test_sync_versions_ignores_error_code() {
let (sim, rx) = MessageSimulator::new();
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);

// construct error response
let mut msg = vec![];
Expand Down Expand Up @@ -918,7 +945,7 @@ mod tests {
#[tokio::test]
async fn test_sync_versions_ignores_read_code() {
let (sim, rx) = MessageSimulator::new();
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);

// construct error response
let mut msg = vec![];
Expand Down Expand Up @@ -969,7 +996,7 @@ mod tests {
#[tokio::test]
async fn test_sync_versions_err_flipped_range() {
let (sim, rx) = MessageSimulator::new();
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);

// construct response
let mut msg = vec![];
Expand Down Expand Up @@ -1002,7 +1029,7 @@ mod tests {
#[tokio::test]
async fn test_sync_versions_ignores_garbage() {
let (sim, rx) = MessageSimulator::new();
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);

// construct response
let mut msg = vec![];
Expand Down Expand Up @@ -1066,7 +1093,7 @@ mod tests {
#[tokio::test]
async fn test_sync_versions_err_no_working_version() {
let (sim, rx) = MessageSimulator::new();
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);

// construct error response
for (i, v) in ((ApiVersionsRequest::API_VERSION_RANGE.min().0.0)
Expand Down Expand Up @@ -1105,7 +1132,7 @@ mod tests {
#[tokio::test]
async fn test_poison_hangup() {
let (sim, rx) = MessageSimulator::new();
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
messenger.set_version_ranges(HashMap::from([(
ApiKey::ListOffsets,
ListOffsetsRequest::API_VERSION_RANGE,
Expand All @@ -1127,7 +1154,7 @@ mod tests {
#[tokio::test]
async fn test_poison_negative_message_size() {
let (sim, rx) = MessageSimulator::new();
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
messenger.set_version_ranges(HashMap::from([(
ApiKey::ListOffsets,
ListOffsetsRequest::API_VERSION_RANGE,
Expand Down Expand Up @@ -1160,7 +1187,7 @@ mod tests {
#[tokio::test]
async fn test_broken_msg_header_does_not_poison() {
let (sim, rx) = MessageSimulator::new();
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
messenger.set_version_ranges(HashMap::from([(
ApiKey::ApiVersions,
ApiVersionsRequest::API_VERSION_RANGE,
Expand Down Expand Up @@ -1206,7 +1233,7 @@ mod tests {
let (tx_front, rx_middle) = tokio::io::duplex(1);
let (tx_middle, mut rx_back) = tokio::io::duplex(1);

let mut messenger = Messenger::new(tx_front, 1_000, Arc::from(DEFAULT_CLIENT_ID));
let mut messenger = Messenger::new(tx_front, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);

// create two barriers:
// - pause: will be passed after 3 bytes were sent by the client
Expand Down Expand Up @@ -1352,6 +1379,31 @@ mod tests {
handle_network.abort();
}

#[tokio::test]
async fn test_request_timeout() {
let (tx, _rx) = tokio::io::duplex(1_000);
let mut messenger = Messenger::new(
tx,
1_000,
Arc::from(DEFAULT_CLIENT_ID),
Some(Duration::from_millis(200)),
);
messenger.set_version_ranges(HashMap::from([(
ApiKey::ApiVersions,
ApiVersionsRequest::API_VERSION_RANGE,
)]));

let err = messenger
.request(ApiVersionsRequest {
client_software_name: Some(CompactString(String::from("foo"))),
client_software_version: Some(CompactString(String::from("bar"))),
tagged_fields: Some(TaggedFields::default()),
})
.await
.unwrap_err();
assert_matches!(err, RequestError::IO(e) if e.kind() == std::io::ErrorKind::TimedOut);
}

#[derive(Debug)]
enum Message {
Send(Vec<u8>),
Expand Down