1
1
use anyhow:: { Context , anyhow} ;
2
- use reqwest:: { IntoUrl , Response , Url } ;
2
+ use reqwest:: { IntoUrl , Response , Url , header :: HeaderMap } ;
3
3
use semver:: Version ;
4
4
use slog:: { Logger , error, warn} ;
5
5
use std:: time:: Duration ;
@@ -18,6 +18,7 @@ const API_VERSION_MISMATCH_WARNING_MESSAGE: &str = "OpenAPI version may be incom
18
18
pub struct AggregatorClient {
19
19
pub ( super ) aggregator_endpoint : Url ,
20
20
pub ( super ) api_version_provider : APIVersionProvider ,
21
+ pub ( super ) additional_headers : HeaderMap ,
21
22
pub ( super ) timeout_duration : Option < Duration > ,
22
23
pub ( super ) client : reqwest:: Client ,
23
24
pub ( super ) logger : Logger ,
@@ -39,6 +40,7 @@ impl AggregatorClient {
39
40
QueryMethod :: Get => self . client . get ( self . join_aggregator_endpoint ( & query. route ( ) ) ?) ,
40
41
QueryMethod :: Post => self . client . post ( self . join_aggregator_endpoint ( & query. route ( ) ) ?) ,
41
42
}
43
+ . headers ( self . additional_headers . clone ( ) )
42
44
. header ( MITHRIL_API_VERSION_HEADER , current_api_version. to_string ( ) ) ;
43
45
44
46
if let Some ( body) = query. body ( ) {
@@ -244,6 +246,29 @@ mod tests {
244
246
client. send ( TestGetQuery ) . await . expect ( "should not fail" ) ;
245
247
}
246
248
249
+ #[ tokio:: test]
250
+ async fn test_get_query_send_additional_header_and_dont_override_mithril_api_version_header ( )
251
+ {
252
+ let ( server, mut client) = setup_server_and_client ( ) ;
253
+ client. api_version_provider =
254
+ APIVersionProvider :: new_with_default_version ( Version :: parse ( "1.2.9" ) . unwrap ( ) ) ;
255
+ client. additional_headers = {
256
+ let mut headers = HeaderMap :: new ( ) ;
257
+ headers. insert ( MITHRIL_API_VERSION_HEADER , "9.4.5" . parse ( ) . unwrap ( ) ) ;
258
+ headers. insert ( "foo" , "bar" . parse ( ) . unwrap ( ) ) ;
259
+ headers
260
+ } ;
261
+
262
+ server. mock ( |when, then| {
263
+ when. method ( httpmock:: Method :: GET )
264
+ . header ( MITHRIL_API_VERSION_HEADER , "1.2.9" )
265
+ . header ( "foo" , "bar" ) ;
266
+ then. status ( 200 ) . body ( r#"{"foo": "a", "bar": 1}"# ) ;
267
+ } ) ;
268
+
269
+ client. send ( TestGetQuery ) . await . expect ( "should not fail" ) ;
270
+ }
271
+
247
272
#[ tokio:: test]
248
273
async fn test_get_query_timeout ( ) {
249
274
let ( server, mut client) = setup_server_and_client ( ) ;
@@ -276,14 +301,12 @@ mod tests {
276
301
then. status ( 201 ) ;
277
302
} ) ;
278
303
279
- let response = client
304
+ client
280
305
. send ( TestPostQuery {
281
306
body : TestBody :: new ( "miaouss" , 5 ) ,
282
307
} )
283
308
. await
284
309
. unwrap ( ) ;
285
-
286
- assert_eq ! ( response, ( ) )
287
310
}
288
311
289
312
#[ tokio:: test]
@@ -305,6 +328,34 @@ mod tests {
305
328
. expect ( "should not fail" ) ;
306
329
}
307
330
331
+ #[ tokio:: test]
332
+ async fn test_post_query_send_additional_header_and_dont_override_mithril_api_version_header ( )
333
+ {
334
+ let ( server, mut client) = setup_server_and_client ( ) ;
335
+ client. api_version_provider =
336
+ APIVersionProvider :: new_with_default_version ( Version :: parse ( "1.2.9" ) . unwrap ( ) ) ;
337
+ client. additional_headers = {
338
+ let mut headers = HeaderMap :: new ( ) ;
339
+ headers. insert ( MITHRIL_API_VERSION_HEADER , "9.4.5" . parse ( ) . unwrap ( ) ) ;
340
+ headers. insert ( "foo" , "bar" . parse ( ) . unwrap ( ) ) ;
341
+ headers
342
+ } ;
343
+
344
+ server. mock ( |when, then| {
345
+ when. method ( httpmock:: Method :: POST )
346
+ . header ( MITHRIL_API_VERSION_HEADER , "1.2.9" )
347
+ . header ( "foo" , "bar" ) ;
348
+ then. status ( 201 ) . body ( r#"{"foo": "a", "bar": 1}"# ) ;
349
+ } ) ;
350
+
351
+ client
352
+ . send ( TestPostQuery {
353
+ body : TestBody :: new ( "miaouss" , 3 ) ,
354
+ } )
355
+ . await
356
+ . expect ( "should not fail" ) ;
357
+ }
358
+
308
359
#[ tokio:: test]
309
360
async fn test_post_query_timeout ( ) {
310
361
let ( server, mut client) = setup_server_and_client ( ) ;
0 commit comments