@@ -16,6 +16,7 @@ use rsasl::{
1616 mechname:: MechanismNameError ,
1717 prelude:: { Mechname , SASLError , SessionError } ,
1818} ;
19+ use std:: time:: Duration ;
1920use thiserror:: Error ;
2021use tokio:: {
2122 io:: { AsyncRead , AsyncWrite , AsyncWriteExt , WriteHalf } ,
@@ -133,6 +134,12 @@ pub struct Messenger<RW> {
133134
134135 /// Join handle for the background worker that fetches responses.
135136 join_handle : JoinHandle < ( ) > ,
137+
138+ /// Timeout for requests.
139+ ///
140+ /// If set, requests will timeout after the given duration.
141+ /// If not set, requests will not timeout.
142+ timeout : Option < Duration > ,
136143}
137144
138145#[ derive( Error , Debug ) ]
@@ -218,7 +225,12 @@ impl<RW> Messenger<RW>
218225where
219226 RW : AsyncRead + AsyncWrite + Send + ' static ,
220227{
221- pub fn new ( stream : RW , max_message_size : usize , client_id : Arc < str > ) -> Self {
228+ pub fn new (
229+ stream : RW ,
230+ max_message_size : usize ,
231+ client_id : Arc < str > ,
232+ timeout : Option < Duration > ,
233+ ) -> Self {
222234 let ( stream_read, stream_write) = tokio:: io:: split ( stream) ;
223235 let state = Arc :: new ( Mutex :: new ( MessengerState :: RequestMap ( HashMap :: default ( ) ) ) ) ;
224236 let state_captured = Arc :: clone ( & state) ;
@@ -304,6 +316,7 @@ where
304316 version_ranges : HashMap :: new ( ) ,
305317 state,
306318 join_handle,
319+ timeout,
307320 }
308321 }
309322
@@ -404,7 +417,21 @@ where
404417 self . send_message ( buf) . await ?;
405418 cleanup_on_cancel. message_sent ( ) ;
406419
407- let mut response = rx. await . expect ( "Who closed this channel?!" ) ?;
420+ let mut response = if let Some ( timeout) = self . timeout {
421+ // If a request times out, return a `RequestError::IO` with a timeout error.
422+ // This allows the backoff mechanism to detect transport issues and re-establish the connection as needed.
423+ //
424+ // Typically, timeouts occur due to abrupt TCP connection loss (e.g., a disconnected cable).
425+ tokio:: time:: timeout ( timeout, rx) . await . map_err ( |_| {
426+ RequestError :: IO ( std:: io:: Error :: new (
427+ std:: io:: ErrorKind :: TimedOut ,
428+ "Request timed out" ,
429+ ) )
430+ } ) ?
431+ } else {
432+ rx. await
433+ }
434+ . expect ( "Who closed this channel?!" ) ?;
408435 let body = R :: ResponseBody :: read_versioned ( & mut response. data , body_api_version) ?;
409436
410437 // check if we fully consumed the message, otherwise there might be a bug in our protocol code
@@ -818,7 +845,7 @@ mod tests {
818845 #[ tokio:: test]
819846 async fn test_sync_versions_ok ( ) {
820847 let ( sim, rx) = MessageSimulator :: new ( ) ;
821- let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
848+ let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
822849
823850 // construct response
824851 let mut msg = vec ! [ ] ;
@@ -855,7 +882,7 @@ mod tests {
855882 #[ tokio:: test]
856883 async fn test_sync_versions_ignores_error_code ( ) {
857884 let ( sim, rx) = MessageSimulator :: new ( ) ;
858- let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
885+ let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
859886
860887 // construct error response
861888 let mut msg = vec ! [ ] ;
@@ -918,7 +945,7 @@ mod tests {
918945 #[ tokio:: test]
919946 async fn test_sync_versions_ignores_read_code ( ) {
920947 let ( sim, rx) = MessageSimulator :: new ( ) ;
921- let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
948+ let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
922949
923950 // construct error response
924951 let mut msg = vec ! [ ] ;
@@ -969,7 +996,7 @@ mod tests {
969996 #[ tokio:: test]
970997 async fn test_sync_versions_err_flipped_range ( ) {
971998 let ( sim, rx) = MessageSimulator :: new ( ) ;
972- let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
999+ let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
9731000
9741001 // construct response
9751002 let mut msg = vec ! [ ] ;
@@ -1002,7 +1029,7 @@ mod tests {
10021029 #[ tokio:: test]
10031030 async fn test_sync_versions_ignores_garbage ( ) {
10041031 let ( sim, rx) = MessageSimulator :: new ( ) ;
1005- let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
1032+ let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
10061033
10071034 // construct response
10081035 let mut msg = vec ! [ ] ;
@@ -1066,7 +1093,7 @@ mod tests {
10661093 #[ tokio:: test]
10671094 async fn test_sync_versions_err_no_working_version ( ) {
10681095 let ( sim, rx) = MessageSimulator :: new ( ) ;
1069- let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
1096+ let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
10701097
10711098 // construct error response
10721099 for ( i, v) in ( ( ApiVersionsRequest :: API_VERSION_RANGE . min ( ) . 0 . 0 )
@@ -1105,7 +1132,7 @@ mod tests {
11051132 #[ tokio:: test]
11061133 async fn test_poison_hangup ( ) {
11071134 let ( sim, rx) = MessageSimulator :: new ( ) ;
1108- let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
1135+ let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
11091136 messenger. set_version_ranges ( HashMap :: from ( [ (
11101137 ApiKey :: ListOffsets ,
11111138 ListOffsetsRequest :: API_VERSION_RANGE ,
@@ -1127,7 +1154,7 @@ mod tests {
11271154 #[ tokio:: test]
11281155 async fn test_poison_negative_message_size ( ) {
11291156 let ( sim, rx) = MessageSimulator :: new ( ) ;
1130- let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
1157+ let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
11311158 messenger. set_version_ranges ( HashMap :: from ( [ (
11321159 ApiKey :: ListOffsets ,
11331160 ListOffsetsRequest :: API_VERSION_RANGE ,
@@ -1160,7 +1187,7 @@ mod tests {
11601187 #[ tokio:: test]
11611188 async fn test_broken_msg_header_does_not_poison ( ) {
11621189 let ( sim, rx) = MessageSimulator :: new ( ) ;
1163- let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
1190+ let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
11641191 messenger. set_version_ranges ( HashMap :: from ( [ (
11651192 ApiKey :: ApiVersions ,
11661193 ApiVersionsRequest :: API_VERSION_RANGE ,
@@ -1206,7 +1233,7 @@ mod tests {
12061233 let ( tx_front, rx_middle) = tokio:: io:: duplex ( 1 ) ;
12071234 let ( tx_middle, mut rx_back) = tokio:: io:: duplex ( 1 ) ;
12081235
1209- let mut messenger = Messenger :: new ( tx_front, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
1236+ let mut messenger = Messenger :: new ( tx_front, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
12101237
12111238 // create two barriers:
12121239 // - pause: will be passed after 3 bytes were sent by the client
@@ -1352,6 +1379,31 @@ mod tests {
13521379 handle_network. abort ( ) ;
13531380 }
13541381
1382+ #[ tokio:: test]
1383+ async fn test_request_timeout ( ) {
1384+ let ( tx, _rx) = tokio:: io:: duplex ( 1_000 ) ;
1385+ let mut messenger = Messenger :: new (
1386+ tx,
1387+ 1_000 ,
1388+ Arc :: from ( DEFAULT_CLIENT_ID ) ,
1389+ Some ( Duration :: from_millis ( 200 ) ) ,
1390+ ) ;
1391+ messenger. set_version_ranges ( HashMap :: from ( [ (
1392+ ApiKey :: ApiVersions ,
1393+ ApiVersionsRequest :: API_VERSION_RANGE ,
1394+ ) ] ) ) ;
1395+
1396+ let err = messenger
1397+ . request ( ApiVersionsRequest {
1398+ client_software_name : Some ( CompactString ( String :: from ( "foo" ) ) ) ,
1399+ client_software_version : Some ( CompactString ( String :: from ( "bar" ) ) ) ,
1400+ tagged_fields : Some ( TaggedFields :: default ( ) ) ,
1401+ } )
1402+ . await
1403+ . unwrap_err ( ) ;
1404+ assert_matches ! ( err, RequestError :: IO ( e) if e. kind( ) == std:: io:: ErrorKind :: TimedOut ) ;
1405+ }
1406+
13551407 #[ derive( Debug ) ]
13561408 enum Message {
13571409 Send ( Vec < u8 > ) ,
0 commit comments