@@ -16,6 +16,7 @@ use rsasl::{
1616 mechname:: MechanismNameError ,
1717 prelude:: { Mechname , SASLError , SessionError } ,
1818} ;
19+ use std:: time:: Duration ;
1920use thiserror:: Error ;
2021use tokio:: {
2122 io:: { AsyncRead , AsyncWrite , AsyncWriteExt , WriteHalf } ,
@@ -133,6 +134,12 @@ pub struct Messenger<RW> {
133134
134135 /// Join handle for the background worker that fetches responses.
135136 join_handle : JoinHandle < ( ) > ,
137+
138+ /// Timeout for requests.
139+ ///
140+ /// If set, requests will timeout after the given duration.
141+ /// If not set, requests will not timeout.
142+ timeout : Option < Duration > ,
136143}
137144
138145#[ derive( Error , Debug ) ]
@@ -218,7 +225,12 @@ impl<RW> Messenger<RW>
218225where
219226 RW : AsyncRead + AsyncWrite + Send + ' static ,
220227{
221- pub fn new ( stream : RW , max_message_size : usize , client_id : Arc < str > ) -> Self {
228+ pub fn new (
229+ stream : RW ,
230+ max_message_size : usize ,
231+ client_id : Arc < str > ,
232+ timeout : Option < Duration > ,
233+ ) -> Self {
222234 let ( stream_read, stream_write) = tokio:: io:: split ( stream) ;
223235 let state = Arc :: new ( Mutex :: new ( MessengerState :: RequestMap ( HashMap :: default ( ) ) ) ) ;
224236 let state_captured = Arc :: clone ( & state) ;
@@ -304,6 +316,7 @@ where
304316 version_ranges : HashMap :: new ( ) ,
305317 state,
306318 join_handle,
319+ timeout,
307320 }
308321 }
309322
@@ -404,7 +417,17 @@ where
404417 self . send_message ( buf) . await ?;
405418 cleanup_on_cancel. message_sent ( ) ;
406419
407- let mut response = rx. await . expect ( "Who closed this channel?!" ) ?;
420+ let mut response = if let Some ( timeout) = self . timeout {
421+ tokio:: time:: timeout ( timeout, rx) . await . map_err ( |_| {
422+ RequestError :: IO ( std:: io:: Error :: new (
423+ std:: io:: ErrorKind :: TimedOut ,
424+ "Request timed out" ,
425+ ) )
426+ } ) ?
427+ } else {
428+ rx. await
429+ }
430+ . expect ( "Who closed this channel?!" ) ?;
408431 let body = R :: ResponseBody :: read_versioned ( & mut response. data , body_api_version) ?;
409432
410433 // check if we fully consumed the message, otherwise there might be a bug in our protocol code
@@ -818,7 +841,7 @@ mod tests {
818841 #[ tokio:: test]
819842 async fn test_sync_versions_ok ( ) {
820843 let ( sim, rx) = MessageSimulator :: new ( ) ;
821- let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
844+ let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
822845
823846 // construct response
824847 let mut msg = vec ! [ ] ;
@@ -855,7 +878,7 @@ mod tests {
855878 #[ tokio:: test]
856879 async fn test_sync_versions_ignores_error_code ( ) {
857880 let ( sim, rx) = MessageSimulator :: new ( ) ;
858- let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
881+ let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
859882
860883 // construct error response
861884 let mut msg = vec ! [ ] ;
@@ -918,7 +941,7 @@ mod tests {
918941 #[ tokio:: test]
919942 async fn test_sync_versions_ignores_read_code ( ) {
920943 let ( sim, rx) = MessageSimulator :: new ( ) ;
921- let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
944+ let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
922945
923946 // construct error response
924947 let mut msg = vec ! [ ] ;
@@ -969,7 +992,7 @@ mod tests {
969992 #[ tokio:: test]
970993 async fn test_sync_versions_err_flipped_range ( ) {
971994 let ( sim, rx) = MessageSimulator :: new ( ) ;
972- let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
995+ let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
973996
974997 // construct response
975998 let mut msg = vec ! [ ] ;
@@ -1002,7 +1025,7 @@ mod tests {
10021025 #[ tokio:: test]
10031026 async fn test_sync_versions_ignores_garbage ( ) {
10041027 let ( sim, rx) = MessageSimulator :: new ( ) ;
1005- let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
1028+ let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
10061029
10071030 // construct response
10081031 let mut msg = vec ! [ ] ;
@@ -1066,7 +1089,7 @@ mod tests {
10661089 #[ tokio:: test]
10671090 async fn test_sync_versions_err_no_working_version ( ) {
10681091 let ( sim, rx) = MessageSimulator :: new ( ) ;
1069- let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
1092+ let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
10701093
10711094 // construct error response
10721095 for ( i, v) in ( ( ApiVersionsRequest :: API_VERSION_RANGE . min ( ) . 0 . 0 )
@@ -1105,7 +1128,7 @@ mod tests {
11051128 #[ tokio:: test]
11061129 async fn test_poison_hangup ( ) {
11071130 let ( sim, rx) = MessageSimulator :: new ( ) ;
1108- let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
1131+ let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
11091132 messenger. set_version_ranges ( HashMap :: from ( [ (
11101133 ApiKey :: ListOffsets ,
11111134 ListOffsetsRequest :: API_VERSION_RANGE ,
@@ -1127,7 +1150,7 @@ mod tests {
11271150 #[ tokio:: test]
11281151 async fn test_poison_negative_message_size ( ) {
11291152 let ( sim, rx) = MessageSimulator :: new ( ) ;
1130- let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
1153+ let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
11311154 messenger. set_version_ranges ( HashMap :: from ( [ (
11321155 ApiKey :: ListOffsets ,
11331156 ListOffsetsRequest :: API_VERSION_RANGE ,
@@ -1160,7 +1183,7 @@ mod tests {
11601183 #[ tokio:: test]
11611184 async fn test_broken_msg_header_does_not_poison ( ) {
11621185 let ( sim, rx) = MessageSimulator :: new ( ) ;
1163- let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
1186+ let mut messenger = Messenger :: new ( rx, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
11641187 messenger. set_version_ranges ( HashMap :: from ( [ (
11651188 ApiKey :: ApiVersions ,
11661189 ApiVersionsRequest :: API_VERSION_RANGE ,
@@ -1206,7 +1229,7 @@ mod tests {
12061229 let ( tx_front, rx_middle) = tokio:: io:: duplex ( 1 ) ;
12071230 let ( tx_middle, mut rx_back) = tokio:: io:: duplex ( 1 ) ;
12081231
1209- let mut messenger = Messenger :: new ( tx_front, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) ) ;
1232+ let mut messenger = Messenger :: new ( tx_front, 1_000 , Arc :: from ( DEFAULT_CLIENT_ID ) , None ) ;
12101233
12111234 // create two barriers:
12121235 // - pause: will be passed after 3 bytes were sent by the client
0 commit comments