@@ -2,6 +2,7 @@ use anyhow::{Context, anyhow};
2
2
use reqwest:: { IntoUrl , Response , Url } ;
3
3
use semver:: Version ;
4
4
use slog:: { Logger , error, warn} ;
5
+ use std:: time:: Duration ;
5
6
6
7
use mithril_common:: MITHRIL_API_VERSION_HEADER ;
7
8
use mithril_common:: api_version:: APIVersionProvider ;
@@ -16,6 +17,7 @@ const API_VERSION_MISMATCH_WARNING_MESSAGE: &str = "OpenAPI version may be incom
16
17
pub struct AggregatorClient {
17
18
pub ( super ) aggregator_endpoint : Url ,
18
19
pub ( super ) api_version_provider : APIVersionProvider ,
20
+ pub ( super ) timeout_duration : Option < Duration > ,
19
21
pub ( super ) client : reqwest:: Client ,
20
22
pub ( super ) logger : Logger ,
21
23
}
@@ -41,6 +43,10 @@ impl AggregatorClient {
41
43
request_builder = request_builder. json ( & body) ;
42
44
}
43
45
46
+ if let Some ( timeout) = self . timeout_duration {
47
+ request_builder = request_builder. timeout ( timeout) ;
48
+ }
49
+
44
50
match request_builder. send ( ) . await {
45
51
Ok ( response) => {
46
52
self . warn_if_api_version_mismatch ( & response) ;
@@ -159,6 +165,15 @@ mod tests {
159
165
chu : u8 ,
160
166
}
161
167
168
+ impl TestBody {
169
+ fn new < P : Into < String > > ( pika : P , chu : u8 ) -> Self {
170
+ Self {
171
+ pika : pika. into ( ) ,
172
+ chu,
173
+ }
174
+ }
175
+ }
176
+
162
177
struct TestPostQuery {
163
178
body : TestBody ,
164
179
}
@@ -226,6 +241,23 @@ mod tests {
226
241
227
242
client. send ( TestGetQuery ) . await . expect ( "should not fail" ) ;
228
243
}
244
+
245
+ #[ tokio:: test]
246
+ async fn test_get_query_timeout ( ) {
247
+ let ( server, mut client) = setup_server_and_client ( ) ;
248
+ client. timeout_duration = Some ( Duration :: from_millis ( 10 ) ) ;
249
+ let _server_mock = server. mock ( |when, then| {
250
+ when. method ( httpmock:: Method :: GET ) ;
251
+ then. delay ( Duration :: from_millis ( 100 ) ) ;
252
+ } ) ;
253
+
254
+ let error = client. send ( TestGetQuery ) . await . expect_err ( "should not fail" ) ;
255
+
256
+ assert ! (
257
+ matches!( error, AggregatorClientError :: RemoteServerUnreachable ( _) ) ,
258
+ "unexpected error type: {error:?}"
259
+ ) ;
260
+ }
229
261
}
230
262
231
263
mod post {
@@ -238,22 +270,13 @@ mod tests {
238
270
when. method ( httpmock:: Method :: POST )
239
271
. path ( "/dummy-post-route" )
240
272
. header ( "content-type" , "application/json" )
241
- . body (
242
- serde_json:: to_string ( & TestBody {
243
- pika : "miaouss" . to_string ( ) ,
244
- chu : 5 ,
245
- } )
246
- . unwrap ( ) ,
247
- ) ;
273
+ . body ( serde_json:: to_string ( & TestBody :: new ( "miaouss" , 5 ) ) . unwrap ( ) ) ;
248
274
then. status ( 201 ) ;
249
275
} ) ;
250
276
251
277
let response = client
252
278
. send ( TestPostQuery {
253
- body : TestBody {
254
- pika : "miaouss" . to_string ( ) ,
255
- chu : 5 ,
256
- } ,
279
+ body : TestBody :: new ( "miaouss" , 5 ) ,
257
280
} )
258
281
. await
259
282
. unwrap ( ) ;
@@ -274,14 +297,33 @@ mod tests {
274
297
275
298
client
276
299
. send ( TestPostQuery {
277
- body : TestBody {
278
- pika : "a" . to_string ( ) ,
279
- chu : 3 ,
280
- } ,
300
+ body : TestBody :: new ( "miaouss" , 3 ) ,
281
301
} )
282
302
. await
283
303
. expect ( "should not fail" ) ;
284
304
}
305
+
306
+ #[ tokio:: test]
307
+ async fn test_post_query_timeout ( ) {
308
+ let ( server, mut client) = setup_server_and_client ( ) ;
309
+ client. timeout_duration = Some ( Duration :: from_millis ( 10 ) ) ;
310
+ let _server_mock = server. mock ( |when, then| {
311
+ when. method ( httpmock:: Method :: POST ) ;
312
+ then. delay ( Duration :: from_millis ( 100 ) ) ;
313
+ } ) ;
314
+
315
+ let error = client
316
+ . send ( TestPostQuery {
317
+ body : TestBody :: new ( "miaouss" , 3 ) ,
318
+ } )
319
+ . await
320
+ . expect_err ( "should not fail" ) ;
321
+
322
+ assert ! (
323
+ matches!( error, AggregatorClientError :: RemoteServerUnreachable ( _) ) ,
324
+ "unexpected error type: {error:?}"
325
+ ) ;
326
+ }
285
327
}
286
328
287
329
mod warn_if_api_version_mismatch {
@@ -481,10 +523,7 @@ mod tests {
481
523
482
524
client
483
525
. send ( TestPostQuery {
484
- body : TestBody {
485
- pika : "miaouss" . to_string ( ) ,
486
- chu : 5 ,
487
- } ,
526
+ body : TestBody :: new ( "miaouss" , 3 ) ,
488
527
} )
489
528
. await
490
529
. unwrap ( ) ;
0 commit comments