Skip to content

Commit 2cdae3a

Browse files
committed
Add tests for 412 error in mithril-client aggregator client
1 parent 744ea0c commit 2cdae3a

File tree

1 file changed

+126
-21
lines changed

1 file changed

+126
-21
lines changed

mithril-client/src/aggregator_client.rs

Lines changed: 126 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use std::sync::Arc;
1212
use anyhow::{anyhow, Context};
1313
use async_recursion::async_recursion;
1414
use async_trait::async_trait;
15+
use reqwest::header::HeaderMap;
1516
use reqwest::{Response, StatusCode, Url};
1617
use semver::Version;
1718
use slog::{debug, Logger};
@@ -254,7 +255,7 @@ impl AggregatorHTTPClient {
254255
return self.get(url).await;
255256
}
256257

257-
Err(self.handle_api_error(&response).await)
258+
Err(self.handle_api_error(response.headers()).await)
258259
}
259260
StatusCode::NOT_FOUND => Err(AggregatorClientError::RemoteServerLogical(anyhow!(
260261
"Url='{url}' not found"
@@ -298,7 +299,7 @@ impl AggregatorHTTPClient {
298299
return self.post(url, json).await;
299300
}
300301

301-
Err(self.handle_api_error(&response).await)
302+
Err(self.handle_api_error(response.headers()).await)
302303
}
303304
StatusCode::NOT_FOUND => Err(AggregatorClientError::RemoteServerLogical(anyhow!(
304305
"Url='{url} not found"
@@ -310,9 +311,21 @@ impl AggregatorHTTPClient {
310311
}
311312
}
312313

314+
fn get_url_for_route(&self, endpoint: &str) -> Result<Url, AggregatorClientError> {
315+
self.aggregator_endpoint
316+
.join(endpoint)
317+
.with_context(|| {
318+
format!(
319+
"Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'",
320+
self.aggregator_endpoint
321+
)
322+
})
323+
.map_err(AggregatorClientError::SubsystemError)
324+
}
325+
313326
/// API version error handling
314-
async fn handle_api_error(&self, response: &Response) -> AggregatorClientError {
315-
if let Some(version) = response.headers().get(MITHRIL_API_VERSION_HEADER) {
327+
async fn handle_api_error(&self, response_header: &HeaderMap) -> AggregatorClientError {
328+
if let Some(version) = response_header.get(MITHRIL_API_VERSION_HEADER) {
316329
AggregatorClientError::ApiVersionMismatch(anyhow!(
317330
"server version: '{}', signer version: '{}'",
318331
version.to_str().unwrap(),
@@ -326,18 +339,6 @@ impl AggregatorHTTPClient {
326339
}
327340
}
328341

329-
fn get_url_for_route(&self, endpoint: &str) -> Result<Url, AggregatorClientError> {
330-
self.aggregator_endpoint
331-
.join(endpoint)
332-
.with_context(|| {
333-
format!(
334-
"Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'",
335-
self.aggregator_endpoint
336-
)
337-
})
338-
.map_err(AggregatorClientError::SubsystemError)
339-
}
340-
341342
async fn remote_logical_error(response: Response) -> AggregatorClientError {
342343
let status_code = response.status();
343344
let client_error = response
@@ -402,6 +403,7 @@ impl AggregatorClient for AggregatorHTTPClient {
402403
#[cfg(test)]
403404
mod tests {
404405
use httpmock::MockServer;
406+
use reqwest::header::{HeaderName, HeaderValue};
405407

406408
use mithril_common::api_version::APIVersionProvider;
407409
use mithril_common::entities::{ClientError, ServerError};
@@ -414,17 +416,31 @@ mod tests {
414416
};
415417
}
416418

419+
fn setup_client(server_url: &str, api_versions: Vec<Version>) -> AggregatorHTTPClient {
420+
AggregatorHTTPClient::new(
421+
Url::parse(server_url).unwrap(),
422+
api_versions,
423+
crate::test_utils::test_logger(),
424+
)
425+
.expect("building aggregator http client should not fail")
426+
}
427+
417428
fn setup_server_and_client() -> (MockServer, AggregatorHTTPClient) {
418429
let server = MockServer::start();
419-
let client = AggregatorHTTPClient::new(
420-
Url::parse(&server.url("")).unwrap(),
430+
let client = setup_client(
431+
&server.url(""),
421432
APIVersionProvider::compute_all_versions_sorted().unwrap(),
422-
crate::test_utils::test_logger(),
423-
)
424-
.expect("building aggregator http client should not fail");
433+
);
425434
(server, client)
426435
}
427436

437+
fn mithril_api_version_headers(version: &str) -> HeaderMap {
438+
HeaderMap::from_iter([(
439+
HeaderName::from_static(MITHRIL_API_VERSION_HEADER),
440+
HeaderValue::from_str(version).unwrap(),
441+
)])
442+
}
443+
428444
#[test]
429445
fn always_append_trailing_slash_at_build() {
430446
for (expected, url) in [
@@ -579,4 +595,93 @@ mod tests {
579595
.unwrap_err();
580596
assert_error_eq!(post_content_error, expected_error);
581597
}
598+
599+
#[tokio::test]
600+
async fn test_client_handle_412_api_version_mismatch_with_version_in_response_header() {
601+
let version = "0.0.0";
602+
603+
let (aggregator, client) = setup_server_and_client();
604+
aggregator.mock(|_when, then| {
605+
then.status(StatusCode::PRECONDITION_FAILED.as_u16())
606+
.header(MITHRIL_API_VERSION_HEADER, version);
607+
});
608+
609+
let expected_error = client
610+
.handle_api_error(&mithril_api_version_headers(version))
611+
.await;
612+
613+
let get_content_error = client
614+
.get_content(AggregatorRequest::ListCertificates)
615+
.await
616+
.unwrap_err();
617+
assert_error_eq!(get_content_error, expected_error);
618+
619+
let post_content_error = client
620+
.post_content(AggregatorRequest::ListCertificates)
621+
.await
622+
.unwrap_err();
623+
assert_error_eq!(post_content_error, expected_error);
624+
}
625+
626+
#[tokio::test]
627+
async fn test_client_handle_412_api_version_mismatch_without_version_in_response_header() {
628+
let (aggregator, client) = setup_server_and_client();
629+
aggregator.mock(|_when, then| {
630+
then.status(StatusCode::PRECONDITION_FAILED.as_u16());
631+
});
632+
633+
let expected_error = client.handle_api_error(&HeaderMap::new()).await;
634+
635+
let get_content_error = client
636+
.get_content(AggregatorRequest::ListCertificates)
637+
.await
638+
.unwrap_err();
639+
assert_error_eq!(get_content_error, expected_error);
640+
641+
let post_content_error = client
642+
.post_content(AggregatorRequest::ListCertificates)
643+
.await
644+
.unwrap_err();
645+
assert_error_eq!(post_content_error, expected_error);
646+
}
647+
648+
#[tokio::test]
649+
async fn test_client_can_fallback_to_a_second_version_when_412_api_version_mistmatch() {
650+
let bad_version = "0.0.0";
651+
let good_version = "1.0.0";
652+
653+
let aggregator = MockServer::start();
654+
let client = setup_client(
655+
&aggregator.url(""),
656+
vec![
657+
Version::parse(bad_version).unwrap(),
658+
Version::parse(good_version).unwrap(),
659+
],
660+
);
661+
aggregator.mock(|when, then| {
662+
when.header(MITHRIL_API_VERSION_HEADER, bad_version);
663+
then.status(StatusCode::PRECONDITION_FAILED.as_u16())
664+
.header(MITHRIL_API_VERSION_HEADER, bad_version);
665+
});
666+
aggregator.mock(|when, then| {
667+
when.header(MITHRIL_API_VERSION_HEADER, good_version);
668+
then.status(StatusCode::OK.as_u16());
669+
});
670+
671+
assert_eq!(
672+
client.compute_current_api_version().await,
673+
Some(Version::parse(bad_version).unwrap()),
674+
"Bad version should be tried first"
675+
);
676+
677+
client
678+
.get_content(AggregatorRequest::ListCertificates)
679+
.await
680+
.expect("should have run with a fallback version");
681+
682+
client
683+
.post_content(AggregatorRequest::ListCertificates)
684+
.await
685+
.expect("should have run with a fallback version");
686+
}
582687
}

0 commit comments

Comments
 (0)