Skip to content

Commit 3a414b7

Browse files
authored
feat: add request level timeout (#287)
* feat: add request level timeout

  Signed-off-by: WenyXu <wenymedia@gmail.com>

* chore: add comments

  Signed-off-by: WenyXu <wenymedia@gmail.com>

* test: test request timeout

  Signed-off-by: WenyXu <wenymedia@gmail.com>

---------

Signed-off-by: WenyXu <wenymedia@gmail.com>
1 parent fe393e5 commit 3a414b7

File tree

4 files changed

+102
-15
lines changed

4 files changed

+102
-15
lines changed

fuzz/fuzz_targets/protocol_reader.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ where
143143
let transport = MockTransport::new(transport_data);
144144

145145
// setup messenger
146-
let mut messenger = Messenger::new(transport, message_size, Arc::from(DEFAULT_CLIENT_ID));
146+
let mut messenger =
147+
Messenger::new(transport, message_size, Arc::from(DEFAULT_CLIENT_ID), None);
147148
messenger.override_version_ranges(HashMap::from([(
148149
api_key,
149150
ApiVersionRange::new(api_version, api_version),

src/client/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ pub struct ClientBuilder {
4949
sasl_config: Option<SaslConfig>,
5050
backoff_config: Arc<BackoffConfig>,
5151
connect_timeout: Option<Duration>,
52+
timeout: Option<Duration>,
5253
}
5354

5455
impl ClientBuilder {
@@ -63,6 +64,7 @@ impl ClientBuilder {
6364
sasl_config: None,
6465
backoff_config: Default::default(),
6566
connect_timeout: Some(Duration::from_secs(30)),
67+
timeout: None,
6668
}
6769
}
6870

@@ -117,6 +119,15 @@ impl ClientBuilder {
117119
self
118120
}
119121

122+
/// Set the timeout on requests to the broker.
123+
/// By setting this to `None`, requests will never time out unless
124+
/// interrupted by an external event.
125+
/// The default timeout is `None`.
126+
pub fn timeout(mut self, timeout: Option<Duration>) -> Self {
127+
self.timeout = timeout;
128+
self
129+
}
130+
120131
/// Build [`Client`].
121132
pub async fn build(self) -> Result<Client> {
122133
let brokers = Arc::new(BrokerConnector::new(
@@ -129,6 +140,7 @@ impl ClientBuilder {
129140
self.max_message_size,
130141
Arc::clone(&self.backoff_config),
131142
self.connect_timeout,
143+
self.timeout,
132144
));
133145
brokers.refresh_metadata().await?;
134146

src/connection.rs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,15 @@ impl Display for MultiError {
8080
trait ConnectionHandler {
8181
type R: RequestHandler + Send + Sync;
8282

83+
#[allow(clippy::too_many_arguments, reason = "Method is internal")]
8384
fn connect(
8485
&self,
8586
client_id: Arc<str>,
8687
tls_config: TlsConfig,
8788
socks5_proxy: Option<String>,
8889
sasl_config: Option<SaslConfig>,
8990
max_message_size: usize,
91+
connection_timeout: Option<Duration>,
9092
timeout: Option<Duration>,
9193
) -> impl Future<Output = Result<Arc<Self::R>>> + Send;
9294
}
@@ -150,6 +152,7 @@ impl ConnectionHandler for BrokerRepresentation {
150152
socks5_proxy: Option<String>,
151153
sasl_config: Option<SaslConfig>,
152154
max_message_size: usize,
155+
connection_timeout: Option<Duration>,
153156
timeout: Option<Duration>,
154157
) -> Result<Arc<Self::R>> {
155158
let url = self.url();
@@ -158,14 +161,19 @@ impl ConnectionHandler for BrokerRepresentation {
158161
url = url.as_str(),
159162
"Establishing new connection",
160163
);
161-
let transport = Transport::connect(&url, tls_config, socks5_proxy, timeout)
164+
let transport = Transport::connect(&url, tls_config, socks5_proxy, connection_timeout)
162165
.await
163166
.map_err(|error| Error::Transport {
164167
broker: url.to_string(),
165168
error,
166169
})?;
167170

168-
let mut messenger = Messenger::new(BufStream::new(transport), max_message_size, client_id);
171+
let mut messenger = Messenger::new(
172+
BufStream::new(transport),
173+
max_message_size,
174+
client_id,
175+
timeout,
176+
);
169177
messenger.sync_versions().await?;
170178
if let Some(sasl_config) = sasl_config {
171179
messenger.do_sasl(sasl_config).await?;
@@ -217,6 +225,12 @@ pub struct BrokerConnector {
217225

218226
/// Timeout for connection attempts to the broker.
219227
connect_timeout: Option<Duration>,
228+
229+
/// Timeout for requests.
230+
///
231+
/// If set, requests will time out after the given duration.
232+
/// If not set, requests will not time out.
233+
timeout: Option<Duration>,
220234
}
221235

222236
impl BrokerConnector {
@@ -230,6 +244,7 @@ impl BrokerConnector {
230244
max_message_size: usize,
231245
backoff_config: Arc<BackoffConfig>,
232246
connect_timeout: Option<Duration>,
247+
timeout: Option<Duration>,
233248
) -> Self {
234249
Self {
235250
bootstrap_brokers,
@@ -243,6 +258,7 @@ impl BrokerConnector {
243258
sasl_config,
244259
max_message_size,
245260
connect_timeout,
261+
timeout,
246262
}
247263
}
248264

@@ -340,6 +356,7 @@ impl BrokerConnector {
340356
self.sasl_config.clone(),
341357
self.max_message_size,
342358
self.connect_timeout,
359+
self.timeout,
343360
)
344361
.await?;
345362
Ok(Some(connection))
@@ -455,6 +472,7 @@ impl BrokerCache for &BrokerConnector {
455472
self.sasl_config.clone(),
456473
self.max_message_size,
457474
self.connect_timeout,
475+
self.timeout,
458476
)
459477
.await?;
460478

@@ -493,6 +511,7 @@ async fn connect_to_a_broker_with_retry<B>(
493511
sasl_config: Option<SaslConfig>,
494512
max_message_size: usize,
495513
connect_timeout: Option<Duration>,
514+
timeout: Option<Duration>,
496515
) -> Result<Arc<B::R>>
497516
where
498517
B: ConnectionHandler + Send + Sync,
@@ -513,6 +532,7 @@ where
513532
sasl_config.clone(),
514533
max_message_size,
515534
connect_timeout,
535+
timeout,
516536
)
517537
.await;
518538

@@ -825,6 +845,7 @@ mod tests {
825845
_sasl_config: Option<SaslConfig>,
826846
_max_message_size: usize,
827847
_connect_timeout: Option<Duration>,
848+
_timeout: Option<Duration>,
828849
) -> Result<Arc<Self::R>> {
829850
(self.conn)()
830851
}
@@ -854,6 +875,7 @@ mod tests {
854875
Default::default(),
855876
Default::default(),
856877
None,
878+
None,
857879
)
858880
.await
859881
.unwrap();

src/messenger.rs

Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use rsasl::{
1616
mechname::MechanismNameError,
1717
prelude::{Mechname, SASLError, SessionError},
1818
};
19+
use std::time::Duration;
1920
use thiserror::Error;
2021
use tokio::{
2122
io::{AsyncRead, AsyncWrite, AsyncWriteExt, WriteHalf},
@@ -133,6 +134,12 @@ pub struct Messenger<RW> {
133134

134135
/// Join handle for the background worker that fetches responses.
135136
join_handle: JoinHandle<()>,
137+
138+
/// Timeout for requests.
139+
///
140+
/// If set, requests will time out after the given duration.
141+
/// If not set, requests will not time out.
142+
timeout: Option<Duration>,
136143
}
137144

138145
#[derive(Error, Debug)]
@@ -218,7 +225,12 @@ impl<RW> Messenger<RW>
218225
where
219226
RW: AsyncRead + AsyncWrite + Send + 'static,
220227
{
221-
pub fn new(stream: RW, max_message_size: usize, client_id: Arc<str>) -> Self {
228+
pub fn new(
229+
stream: RW,
230+
max_message_size: usize,
231+
client_id: Arc<str>,
232+
timeout: Option<Duration>,
233+
) -> Self {
222234
let (stream_read, stream_write) = tokio::io::split(stream);
223235
let state = Arc::new(Mutex::new(MessengerState::RequestMap(HashMap::default())));
224236
let state_captured = Arc::clone(&state);
@@ -304,6 +316,7 @@ where
304316
version_ranges: HashMap::new(),
305317
state,
306318
join_handle,
319+
timeout,
307320
}
308321
}
309322

@@ -404,7 +417,21 @@ where
404417
self.send_message(buf).await?;
405418
cleanup_on_cancel.message_sent();
406419

407-
let mut response = rx.await.expect("Who closed this channel?!")?;
420+
let mut response = if let Some(timeout) = self.timeout {
421+
// If a request times out, return a `RequestError::IO` with a timeout error.
422+
// This allows the backoff mechanism to detect transport issues and re-establish the connection as needed.
423+
//
424+
// Typically, timeouts occur due to abrupt TCP connection loss (e.g., a disconnected cable).
425+
tokio::time::timeout(timeout, rx).await.map_err(|_| {
426+
RequestError::IO(std::io::Error::new(
427+
std::io::ErrorKind::TimedOut,
428+
"Request timed out",
429+
))
430+
})?
431+
} else {
432+
rx.await
433+
}
434+
.expect("Who closed this channel?!")?;
408435
let body = R::ResponseBody::read_versioned(&mut response.data, body_api_version)?;
409436

410437
// check if we fully consumed the message, otherwise there might be a bug in our protocol code
@@ -818,7 +845,7 @@ mod tests {
818845
#[tokio::test]
819846
async fn test_sync_versions_ok() {
820847
let (sim, rx) = MessageSimulator::new();
821-
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
848+
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
822849

823850
// construct response
824851
let mut msg = vec![];
@@ -855,7 +882,7 @@ mod tests {
855882
#[tokio::test]
856883
async fn test_sync_versions_ignores_error_code() {
857884
let (sim, rx) = MessageSimulator::new();
858-
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
885+
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
859886

860887
// construct error response
861888
let mut msg = vec![];
@@ -918,7 +945,7 @@ mod tests {
918945
#[tokio::test]
919946
async fn test_sync_versions_ignores_read_code() {
920947
let (sim, rx) = MessageSimulator::new();
921-
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
948+
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
922949

923950
// construct error response
924951
let mut msg = vec![];
@@ -969,7 +996,7 @@ mod tests {
969996
#[tokio::test]
970997
async fn test_sync_versions_err_flipped_range() {
971998
let (sim, rx) = MessageSimulator::new();
972-
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
999+
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
9731000

9741001
// construct response
9751002
let mut msg = vec![];
@@ -1002,7 +1029,7 @@ mod tests {
10021029
#[tokio::test]
10031030
async fn test_sync_versions_ignores_garbage() {
10041031
let (sim, rx) = MessageSimulator::new();
1005-
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
1032+
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
10061033

10071034
// construct response
10081035
let mut msg = vec![];
@@ -1066,7 +1093,7 @@ mod tests {
10661093
#[tokio::test]
10671094
async fn test_sync_versions_err_no_working_version() {
10681095
let (sim, rx) = MessageSimulator::new();
1069-
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
1096+
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
10701097

10711098
// construct error response
10721099
for (i, v) in ((ApiVersionsRequest::API_VERSION_RANGE.min().0.0)
@@ -1105,7 +1132,7 @@ mod tests {
11051132
#[tokio::test]
11061133
async fn test_poison_hangup() {
11071134
let (sim, rx) = MessageSimulator::new();
1108-
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
1135+
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
11091136
messenger.set_version_ranges(HashMap::from([(
11101137
ApiKey::ListOffsets,
11111138
ListOffsetsRequest::API_VERSION_RANGE,
@@ -1127,7 +1154,7 @@ mod tests {
11271154
#[tokio::test]
11281155
async fn test_poison_negative_message_size() {
11291156
let (sim, rx) = MessageSimulator::new();
1130-
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
1157+
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
11311158
messenger.set_version_ranges(HashMap::from([(
11321159
ApiKey::ListOffsets,
11331160
ListOffsetsRequest::API_VERSION_RANGE,
@@ -1160,7 +1187,7 @@ mod tests {
11601187
#[tokio::test]
11611188
async fn test_broken_msg_header_does_not_poison() {
11621189
let (sim, rx) = MessageSimulator::new();
1163-
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
1190+
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
11641191
messenger.set_version_ranges(HashMap::from([(
11651192
ApiKey::ApiVersions,
11661193
ApiVersionsRequest::API_VERSION_RANGE,
@@ -1206,7 +1233,7 @@ mod tests {
12061233
let (tx_front, rx_middle) = tokio::io::duplex(1);
12071234
let (tx_middle, mut rx_back) = tokio::io::duplex(1);
12081235

1209-
let mut messenger = Messenger::new(tx_front, 1_000, Arc::from(DEFAULT_CLIENT_ID));
1236+
let mut messenger = Messenger::new(tx_front, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
12101237

12111238
// create two barriers:
12121239
// - pause: will be passed after 3 bytes were sent by the client
@@ -1352,6 +1379,31 @@ mod tests {
13521379
handle_network.abort();
13531380
}
13541381

1382+
#[tokio::test]
1383+
async fn test_request_timeout() {
1384+
let (tx, _rx) = tokio::io::duplex(1_000);
1385+
let mut messenger = Messenger::new(
1386+
tx,
1387+
1_000,
1388+
Arc::from(DEFAULT_CLIENT_ID),
1389+
Some(Duration::from_millis(200)),
1390+
);
1391+
messenger.set_version_ranges(HashMap::from([(
1392+
ApiKey::ApiVersions,
1393+
ApiVersionsRequest::API_VERSION_RANGE,
1394+
)]));
1395+
1396+
let err = messenger
1397+
.request(ApiVersionsRequest {
1398+
client_software_name: Some(CompactString(String::from("foo"))),
1399+
client_software_version: Some(CompactString(String::from("bar"))),
1400+
tagged_fields: Some(TaggedFields::default()),
1401+
})
1402+
.await
1403+
.unwrap_err();
1404+
assert_matches!(err, RequestError::IO(e) if e.kind() == std::io::ErrorKind::TimedOut);
1405+
}
1406+
13551407
#[derive(Debug)]
13561408
enum Message {
13571409
Send(Vec<u8>),

0 commit comments

Comments
 (0)