Skip to content

Commit da54e2c

Browse files
WenyXucrepererum
authored andcommitted
feat: add request level timeout
Signed-off-by: WenyXu <wenymedia@gmail.com>
1 parent fe393e5 commit da54e2c

File tree

4 files changed

+73
-15
lines changed

4 files changed

+73
-15
lines changed

fuzz/fuzz_targets/protocol_reader.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ where
143143
let transport = MockTransport::new(transport_data);
144144

145145
// setup messenger
146-
let mut messenger = Messenger::new(transport, message_size, Arc::from(DEFAULT_CLIENT_ID));
146+
let mut messenger =
147+
Messenger::new(transport, message_size, Arc::from(DEFAULT_CLIENT_ID), None);
147148
messenger.override_version_ranges(HashMap::from([(
148149
api_key,
149150
ApiVersionRange::new(api_version, api_version),

src/client/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ pub struct ClientBuilder {
4949
sasl_config: Option<SaslConfig>,
5050
backoff_config: Arc<BackoffConfig>,
5151
connect_timeout: Option<Duration>,
52+
timeout: Option<Duration>,
5253
}
5354

5455
impl ClientBuilder {
@@ -63,6 +64,7 @@ impl ClientBuilder {
6364
sasl_config: None,
6465
backoff_config: Default::default(),
6566
connect_timeout: Some(Duration::from_secs(30)),
67+
timeout: None,
6668
}
6769
}
6870

@@ -117,6 +119,15 @@ impl ClientBuilder {
117119
self
118120
}
119121

122+
/// Set the timeout on requests to the broker.
123+
/// By setting this to `None`, requests will never time out unless
124+
/// interrupted by an external event.
125+
/// The default timeout is `None`.
126+
pub fn timeout(mut self, timeout: Option<Duration>) -> Self {
127+
self.timeout = timeout;
128+
self
129+
}
130+
120131
/// Build [`Client`].
121132
pub async fn build(self) -> Result<Client> {
122133
let brokers = Arc::new(BrokerConnector::new(
@@ -129,6 +140,7 @@ impl ClientBuilder {
129140
self.max_message_size,
130141
Arc::clone(&self.backoff_config),
131142
self.connect_timeout,
143+
self.timeout,
132144
));
133145
brokers.refresh_metadata().await?;
134146

src/connection.rs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,15 @@ impl Display for MultiError {
8080
trait ConnectionHandler {
8181
type R: RequestHandler + Send + Sync;
8282

83+
#[allow(clippy::too_many_arguments, reason = "Method is internal")]
8384
fn connect(
8485
&self,
8586
client_id: Arc<str>,
8687
tls_config: TlsConfig,
8788
socks5_proxy: Option<String>,
8889
sasl_config: Option<SaslConfig>,
8990
max_message_size: usize,
91+
connection_timeout: Option<Duration>,
9092
timeout: Option<Duration>,
9193
) -> impl Future<Output = Result<Arc<Self::R>>> + Send;
9294
}
@@ -150,6 +152,7 @@ impl ConnectionHandler for BrokerRepresentation {
150152
socks5_proxy: Option<String>,
151153
sasl_config: Option<SaslConfig>,
152154
max_message_size: usize,
155+
connection_timeout: Option<Duration>,
153156
timeout: Option<Duration>,
154157
) -> Result<Arc<Self::R>> {
155158
let url = self.url();
@@ -158,14 +161,19 @@ impl ConnectionHandler for BrokerRepresentation {
158161
url = url.as_str(),
159162
"Establishing new connection",
160163
);
161-
let transport = Transport::connect(&url, tls_config, socks5_proxy, timeout)
164+
let transport = Transport::connect(&url, tls_config, socks5_proxy, connection_timeout)
162165
.await
163166
.map_err(|error| Error::Transport {
164167
broker: url.to_string(),
165168
error,
166169
})?;
167170

168-
let mut messenger = Messenger::new(BufStream::new(transport), max_message_size, client_id);
171+
let mut messenger = Messenger::new(
172+
BufStream::new(transport),
173+
max_message_size,
174+
client_id,
175+
timeout,
176+
);
169177
messenger.sync_versions().await?;
170178
if let Some(sasl_config) = sasl_config {
171179
messenger.do_sasl(sasl_config).await?;
@@ -217,6 +225,12 @@ pub struct BrokerConnector {
217225

218226
/// Timeout for connection attempts to the broker.
219227
connect_timeout: Option<Duration>,
228+
229+
/// Timeout for requests.
230+
///
231+
/// If set, requests will timeout after the given duration.
232+
/// If not set, requests will not timeout.
233+
timeout: Option<Duration>,
220234
}
221235

222236
impl BrokerConnector {
@@ -230,6 +244,7 @@ impl BrokerConnector {
230244
max_message_size: usize,
231245
backoff_config: Arc<BackoffConfig>,
232246
connect_timeout: Option<Duration>,
247+
timeout: Option<Duration>,
233248
) -> Self {
234249
Self {
235250
bootstrap_brokers,
@@ -243,6 +258,7 @@ impl BrokerConnector {
243258
sasl_config,
244259
max_message_size,
245260
connect_timeout,
261+
timeout,
246262
}
247263
}
248264

@@ -340,6 +356,7 @@ impl BrokerConnector {
340356
self.sasl_config.clone(),
341357
self.max_message_size,
342358
self.connect_timeout,
359+
self.timeout,
343360
)
344361
.await?;
345362
Ok(Some(connection))
@@ -455,6 +472,7 @@ impl BrokerCache for &BrokerConnector {
455472
self.sasl_config.clone(),
456473
self.max_message_size,
457474
self.connect_timeout,
475+
self.timeout,
458476
)
459477
.await?;
460478

@@ -493,6 +511,7 @@ async fn connect_to_a_broker_with_retry<B>(
493511
sasl_config: Option<SaslConfig>,
494512
max_message_size: usize,
495513
connect_timeout: Option<Duration>,
514+
timeout: Option<Duration>,
496515
) -> Result<Arc<B::R>>
497516
where
498517
B: ConnectionHandler + Send + Sync,
@@ -513,6 +532,7 @@ where
513532
sasl_config.clone(),
514533
max_message_size,
515534
connect_timeout,
535+
timeout,
516536
)
517537
.await;
518538

@@ -825,6 +845,7 @@ mod tests {
825845
_sasl_config: Option<SaslConfig>,
826846
_max_message_size: usize,
827847
_connect_timeout: Option<Duration>,
848+
_timeout: Option<Duration>,
828849
) -> Result<Arc<Self::R>> {
829850
(self.conn)()
830851
}
@@ -854,6 +875,7 @@ mod tests {
854875
Default::default(),
855876
Default::default(),
856877
None,
878+
None,
857879
)
858880
.await
859881
.unwrap();

src/messenger.rs

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use rsasl::{
1616
mechname::MechanismNameError,
1717
prelude::{Mechname, SASLError, SessionError},
1818
};
19+
use std::time::Duration;
1920
use thiserror::Error;
2021
use tokio::{
2122
io::{AsyncRead, AsyncWrite, AsyncWriteExt, WriteHalf},
@@ -133,6 +134,12 @@ pub struct Messenger<RW> {
133134

134135
/// Join handle for the background worker that fetches responses.
135136
join_handle: JoinHandle<()>,
137+
138+
/// Timeout for requests.
139+
///
140+
/// If set, requests will timeout after the given duration.
141+
/// If not set, requests will not timeout.
142+
timeout: Option<Duration>,
136143
}
137144

138145
#[derive(Error, Debug)]
@@ -218,7 +225,12 @@ impl<RW> Messenger<RW>
218225
where
219226
RW: AsyncRead + AsyncWrite + Send + 'static,
220227
{
221-
pub fn new(stream: RW, max_message_size: usize, client_id: Arc<str>) -> Self {
228+
pub fn new(
229+
stream: RW,
230+
max_message_size: usize,
231+
client_id: Arc<str>,
232+
timeout: Option<Duration>,
233+
) -> Self {
222234
let (stream_read, stream_write) = tokio::io::split(stream);
223235
let state = Arc::new(Mutex::new(MessengerState::RequestMap(HashMap::default())));
224236
let state_captured = Arc::clone(&state);
@@ -304,6 +316,7 @@ where
304316
version_ranges: HashMap::new(),
305317
state,
306318
join_handle,
319+
timeout,
307320
}
308321
}
309322

@@ -404,7 +417,17 @@ where
404417
self.send_message(buf).await?;
405418
cleanup_on_cancel.message_sent();
406419

407-
let mut response = rx.await.expect("Who closed this channel?!")?;
420+
let mut response = if let Some(timeout) = self.timeout {
421+
tokio::time::timeout(timeout, rx).await.map_err(|_| {
422+
RequestError::IO(std::io::Error::new(
423+
std::io::ErrorKind::TimedOut,
424+
"Request timed out",
425+
))
426+
})?
427+
} else {
428+
rx.await
429+
}
430+
.expect("Who closed this channel?!")?;
408431
let body = R::ResponseBody::read_versioned(&mut response.data, body_api_version)?;
409432

410433
// check if we fully consumed the message, otherwise there might be a bug in our protocol code
@@ -818,7 +841,7 @@ mod tests {
818841
#[tokio::test]
819842
async fn test_sync_versions_ok() {
820843
let (sim, rx) = MessageSimulator::new();
821-
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
844+
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
822845

823846
// construct response
824847
let mut msg = vec![];
@@ -855,7 +878,7 @@ mod tests {
855878
#[tokio::test]
856879
async fn test_sync_versions_ignores_error_code() {
857880
let (sim, rx) = MessageSimulator::new();
858-
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
881+
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
859882

860883
// construct error response
861884
let mut msg = vec![];
@@ -918,7 +941,7 @@ mod tests {
918941
#[tokio::test]
919942
async fn test_sync_versions_ignores_read_code() {
920943
let (sim, rx) = MessageSimulator::new();
921-
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
944+
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
922945

923946
// construct error response
924947
let mut msg = vec![];
@@ -969,7 +992,7 @@ mod tests {
969992
#[tokio::test]
970993
async fn test_sync_versions_err_flipped_range() {
971994
let (sim, rx) = MessageSimulator::new();
972-
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
995+
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
973996

974997
// construct response
975998
let mut msg = vec![];
@@ -1002,7 +1025,7 @@ mod tests {
10021025
#[tokio::test]
10031026
async fn test_sync_versions_ignores_garbage() {
10041027
let (sim, rx) = MessageSimulator::new();
1005-
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
1028+
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
10061029

10071030
// construct response
10081031
let mut msg = vec![];
@@ -1066,7 +1089,7 @@ mod tests {
10661089
#[tokio::test]
10671090
async fn test_sync_versions_err_no_working_version() {
10681091
let (sim, rx) = MessageSimulator::new();
1069-
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
1092+
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
10701093

10711094
// construct error response
10721095
for (i, v) in ((ApiVersionsRequest::API_VERSION_RANGE.min().0.0)
@@ -1105,7 +1128,7 @@ mod tests {
11051128
#[tokio::test]
11061129
async fn test_poison_hangup() {
11071130
let (sim, rx) = MessageSimulator::new();
1108-
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
1131+
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
11091132
messenger.set_version_ranges(HashMap::from([(
11101133
ApiKey::ListOffsets,
11111134
ListOffsetsRequest::API_VERSION_RANGE,
@@ -1127,7 +1150,7 @@ mod tests {
11271150
#[tokio::test]
11281151
async fn test_poison_negative_message_size() {
11291152
let (sim, rx) = MessageSimulator::new();
1130-
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
1153+
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
11311154
messenger.set_version_ranges(HashMap::from([(
11321155
ApiKey::ListOffsets,
11331156
ListOffsetsRequest::API_VERSION_RANGE,
@@ -1160,7 +1183,7 @@ mod tests {
11601183
#[tokio::test]
11611184
async fn test_broken_msg_header_does_not_poison() {
11621185
let (sim, rx) = MessageSimulator::new();
1163-
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
1186+
let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
11641187
messenger.set_version_ranges(HashMap::from([(
11651188
ApiKey::ApiVersions,
11661189
ApiVersionsRequest::API_VERSION_RANGE,
@@ -1206,7 +1229,7 @@ mod tests {
12061229
let (tx_front, rx_middle) = tokio::io::duplex(1);
12071230
let (tx_middle, mut rx_back) = tokio::io::duplex(1);
12081231

1209-
let mut messenger = Messenger::new(tx_front, 1_000, Arc::from(DEFAULT_CLIENT_ID));
1232+
let mut messenger = Messenger::new(tx_front, 1_000, Arc::from(DEFAULT_CLIENT_ID), None);
12101233

12111234
// create two barriers:
12121235
// - pause: will be passed after 3 bytes were sent by the client

0 commit comments

Comments
 (0)