@@ -12,6 +12,7 @@ use std::sync::Arc;
12
12
use anyhow:: { anyhow, Context } ;
13
13
use async_recursion:: async_recursion;
14
14
use async_trait:: async_trait;
15
+ use reqwest:: header:: HeaderMap ;
15
16
use reqwest:: { Response , StatusCode , Url } ;
16
17
use semver:: Version ;
17
18
use slog:: { debug, Logger } ;
@@ -254,7 +255,7 @@ impl AggregatorHTTPClient {
254
255
return self . get ( url) . await ;
255
256
}
256
257
257
- Err ( self . handle_api_error ( & response) . await )
258
+ Err ( self . handle_api_error ( response. headers ( ) ) . await )
258
259
}
259
260
StatusCode :: NOT_FOUND => Err ( AggregatorClientError :: RemoteServerLogical ( anyhow ! (
260
261
"Url='{url}' not found"
@@ -298,7 +299,7 @@ impl AggregatorHTTPClient {
298
299
return self . post ( url, json) . await ;
299
300
}
300
301
301
- Err ( self . handle_api_error ( & response) . await )
302
+ Err ( self . handle_api_error ( response. headers ( ) ) . await )
302
303
}
303
304
StatusCode :: NOT_FOUND => Err ( AggregatorClientError :: RemoteServerLogical ( anyhow ! (
304
305
"Url='{url} not found"
@@ -310,9 +311,21 @@ impl AggregatorHTTPClient {
310
311
}
311
312
}
312
313
314
+ fn get_url_for_route ( & self , endpoint : & str ) -> Result < Url , AggregatorClientError > {
315
+ self . aggregator_endpoint
316
+ . join ( endpoint)
317
+ . with_context ( || {
318
+ format ! (
319
+ "Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'" ,
320
+ self . aggregator_endpoint
321
+ )
322
+ } )
323
+ . map_err ( AggregatorClientError :: SubsystemError )
324
+ }
325
+
313
326
/// API version error handling
314
- async fn handle_api_error ( & self , response : & Response ) -> AggregatorClientError {
315
- if let Some ( version) = response . headers ( ) . get ( MITHRIL_API_VERSION_HEADER ) {
327
+ async fn handle_api_error ( & self , response_header : & HeaderMap ) -> AggregatorClientError {
328
+ if let Some ( version) = response_header . get ( MITHRIL_API_VERSION_HEADER ) {
316
329
AggregatorClientError :: ApiVersionMismatch ( anyhow ! (
317
330
"server version: '{}', signer version: '{}'" ,
318
331
version. to_str( ) . unwrap( ) ,
@@ -326,18 +339,6 @@ impl AggregatorHTTPClient {
326
339
}
327
340
}
328
341
329
- fn get_url_for_route ( & self , endpoint : & str ) -> Result < Url , AggregatorClientError > {
330
- self . aggregator_endpoint
331
- . join ( endpoint)
332
- . with_context ( || {
333
- format ! (
334
- "Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'" ,
335
- self . aggregator_endpoint
336
- )
337
- } )
338
- . map_err ( AggregatorClientError :: SubsystemError )
339
- }
340
-
341
342
async fn remote_logical_error ( response : Response ) -> AggregatorClientError {
342
343
let status_code = response. status ( ) ;
343
344
let client_error = response
@@ -402,6 +403,7 @@ impl AggregatorClient for AggregatorHTTPClient {
402
403
#[ cfg( test) ]
403
404
mod tests {
404
405
use httpmock:: MockServer ;
406
+ use reqwest:: header:: { HeaderName , HeaderValue } ;
405
407
406
408
use mithril_common:: api_version:: APIVersionProvider ;
407
409
use mithril_common:: entities:: { ClientError , ServerError } ;
@@ -414,17 +416,31 @@ mod tests {
414
416
} ;
415
417
}
416
418
419
+ fn setup_client ( server_url : & str , api_versions : Vec < Version > ) -> AggregatorHTTPClient {
420
+ AggregatorHTTPClient :: new (
421
+ Url :: parse ( server_url) . unwrap ( ) ,
422
+ api_versions,
423
+ crate :: test_utils:: test_logger ( ) ,
424
+ )
425
+ . expect ( "building aggregator http client should not fail" )
426
+ }
427
+
417
428
fn setup_server_and_client ( ) -> ( MockServer , AggregatorHTTPClient ) {
418
429
let server = MockServer :: start ( ) ;
419
- let client = AggregatorHTTPClient :: new (
420
- Url :: parse ( & server. url ( "" ) ) . unwrap ( ) ,
430
+ let client = setup_client (
431
+ & server. url ( "" ) ,
421
432
APIVersionProvider :: compute_all_versions_sorted ( ) . unwrap ( ) ,
422
- crate :: test_utils:: test_logger ( ) ,
423
- )
424
- . expect ( "building aggregator http client should not fail" ) ;
433
+ ) ;
425
434
( server, client)
426
435
}
427
436
437
+ fn mithril_api_version_headers ( version : & str ) -> HeaderMap {
438
+ HeaderMap :: from_iter ( [ (
439
+ HeaderName :: from_static ( MITHRIL_API_VERSION_HEADER ) ,
440
+ HeaderValue :: from_str ( version) . unwrap ( ) ,
441
+ ) ] )
442
+ }
443
+
428
444
#[ test]
429
445
fn always_append_trailing_slash_at_build ( ) {
430
446
for ( expected, url) in [
@@ -579,4 +595,93 @@ mod tests {
579
595
. unwrap_err ( ) ;
580
596
assert_error_eq ! ( post_content_error, expected_error) ;
581
597
}
598
+
599
+ #[ tokio:: test]
600
+ async fn test_client_handle_412_api_version_mismatch_with_version_in_response_header ( ) {
601
+ let version = "0.0.0" ;
602
+
603
+ let ( aggregator, client) = setup_server_and_client ( ) ;
604
+ aggregator. mock ( |_when, then| {
605
+ then. status ( StatusCode :: PRECONDITION_FAILED . as_u16 ( ) )
606
+ . header ( MITHRIL_API_VERSION_HEADER , version) ;
607
+ } ) ;
608
+
609
+ let expected_error = client
610
+ . handle_api_error ( & mithril_api_version_headers ( version) )
611
+ . await ;
612
+
613
+ let get_content_error = client
614
+ . get_content ( AggregatorRequest :: ListCertificates )
615
+ . await
616
+ . unwrap_err ( ) ;
617
+ assert_error_eq ! ( get_content_error, expected_error) ;
618
+
619
+ let post_content_error = client
620
+ . post_content ( AggregatorRequest :: ListCertificates )
621
+ . await
622
+ . unwrap_err ( ) ;
623
+ assert_error_eq ! ( post_content_error, expected_error) ;
624
+ }
625
+
626
+ #[ tokio:: test]
627
+ async fn test_client_handle_412_api_version_mismatch_without_version_in_response_header ( ) {
628
+ let ( aggregator, client) = setup_server_and_client ( ) ;
629
+ aggregator. mock ( |_when, then| {
630
+ then. status ( StatusCode :: PRECONDITION_FAILED . as_u16 ( ) ) ;
631
+ } ) ;
632
+
633
+ let expected_error = client. handle_api_error ( & HeaderMap :: new ( ) ) . await ;
634
+
635
+ let get_content_error = client
636
+ . get_content ( AggregatorRequest :: ListCertificates )
637
+ . await
638
+ . unwrap_err ( ) ;
639
+ assert_error_eq ! ( get_content_error, expected_error) ;
640
+
641
+ let post_content_error = client
642
+ . post_content ( AggregatorRequest :: ListCertificates )
643
+ . await
644
+ . unwrap_err ( ) ;
645
+ assert_error_eq ! ( post_content_error, expected_error) ;
646
+ }
647
+
648
+ #[ tokio:: test]
649
+ async fn test_client_can_fallback_to_a_second_version_when_412_api_version_mistmatch ( ) {
650
+ let bad_version = "0.0.0" ;
651
+ let good_version = "1.0.0" ;
652
+
653
+ let aggregator = MockServer :: start ( ) ;
654
+ let client = setup_client (
655
+ & aggregator. url ( "" ) ,
656
+ vec ! [
657
+ Version :: parse( bad_version) . unwrap( ) ,
658
+ Version :: parse( good_version) . unwrap( ) ,
659
+ ] ,
660
+ ) ;
661
+ aggregator. mock ( |when, then| {
662
+ when. header ( MITHRIL_API_VERSION_HEADER , bad_version) ;
663
+ then. status ( StatusCode :: PRECONDITION_FAILED . as_u16 ( ) )
664
+ . header ( MITHRIL_API_VERSION_HEADER , bad_version) ;
665
+ } ) ;
666
+ aggregator. mock ( |when, then| {
667
+ when. header ( MITHRIL_API_VERSION_HEADER , good_version) ;
668
+ then. status ( StatusCode :: OK . as_u16 ( ) ) ;
669
+ } ) ;
670
+
671
+ assert_eq ! (
672
+ client. compute_current_api_version( ) . await ,
673
+ Some ( Version :: parse( bad_version) . unwrap( ) ) ,
674
+ "Bad version should be tried first"
675
+ ) ;
676
+
677
+ client
678
+ . get_content ( AggregatorRequest :: ListCertificates )
679
+ . await
680
+ . expect ( "should have run with a fallback version" ) ;
681
+
682
+ client
683
+ . post_content ( AggregatorRequest :: ListCertificates )
684
+ . await
685
+ . expect ( "should have run with a fallback version" ) ;
686
+ }
582
687
}
0 commit comments