Skip to content

Commit c625805

Browse files
authored
Merge pull request #641 from input-output-hk/greg/633/api_version
check API version
2 parents e6c1527 + b1504ae commit c625805

File tree

4 files changed

+295
-14
lines changed

4 files changed

+295
-14
lines changed

mithril-aggregator/src/http_server/routes/router.rs

Lines changed: 78 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,16 @@ use crate::DependencyManager;
77
use mithril_common::MITHRIL_API_VERSION;
88

99
use reqwest::header::{HeaderMap, HeaderValue};
10+
use reqwest::StatusCode;
1011
use std::sync::Arc;
1112
use warp::http::Method;
12-
use warp::Filter;
13+
use warp::reject::Reject;
14+
use warp::{Filter, Rejection, Reply};
15+
16+
#[derive(Debug)]
17+
pub struct VersionMismatchError;
18+
19+
impl Reject for VersionMismatchError {}
1320

1421
/// Routes
1522
pub fn routes(
@@ -24,14 +31,75 @@ pub fn routes(
2431
"mithril-api-version",
2532
HeaderValue::from_static(MITHRIL_API_VERSION),
2633
);
34+
warp::any()
35+
.and(header_must_be())
36+
.and(warp::path(SERVER_BASE_PATH))
37+
.and(
38+
certificate_routes::routes(dependency_manager.clone())
39+
.or(snapshot_routes::routes(dependency_manager.clone()))
40+
.or(signer_routes::routes(dependency_manager.clone()))
41+
.or(signatures_routes::routes(dependency_manager.clone()))
42+
.or(epoch_routes::routes(dependency_manager))
43+
.with(cors),
44+
)
45+
.recover(handle_custom)
46+
.with(warp::reply::with::headers(headers))
47+
}
48+
49+
/// API Version verification
50+
fn header_must_be() -> impl Filter<Extract = (), Error = Rejection> + Copy {
51+
warp::header::optional("mithril-api-version")
52+
.and_then(|maybe_header: Option<String>| async move {
53+
match maybe_header {
54+
None => Ok(()),
55+
Some(version) if version == MITHRIL_API_VERSION => Ok(()),
56+
Some(_version) => Err(warp::reject::custom(VersionMismatchError)),
57+
}
58+
})
59+
.untuple_one()
60+
}
61+
62+
pub async fn handle_custom(reject: Rejection) -> Result<impl Reply, Rejection> {
63+
if reject.find::<VersionMismatchError>().is_some() {
64+
Ok(StatusCode::PRECONDITION_FAILED)
65+
} else {
66+
Err(reject)
67+
}
68+
}
69+
70+
#[cfg(test)]
71+
mod tests {
72+
use super::*;
73+
74+
#[tokio::test]
75+
async fn test_no_version() {
76+
let filters = header_must_be();
77+
warp::test::request()
78+
.path("/aggregator/whatever")
79+
.filter(&filters)
80+
.await
81+
.unwrap();
82+
}
83+
84+
#[tokio::test]
85+
async fn test_bad_version() {
86+
let filters = header_must_be();
87+
warp::test::request()
88+
.header("mithril-api-version", "0.0.999")
89+
.path("/aggregator/whatever")
90+
.filter(&filters)
91+
.await
92+
.unwrap_err();
93+
}
2794

28-
warp::any().and(warp::path(SERVER_BASE_PATH)).and(
29-
certificate_routes::routes(dependency_manager.clone())
30-
.or(snapshot_routes::routes(dependency_manager.clone()))
31-
.or(signer_routes::routes(dependency_manager.clone()))
32-
.or(signatures_routes::routes(dependency_manager.clone()))
33-
.or(epoch_routes::routes(dependency_manager))
34-
.with(cors)
35-
.with(warp::reply::with::headers(headers)),
36-
)
95+
#[tokio::test]
96+
async fn test_good_version() {
97+
let filters = header_must_be();
98+
warp::test::request()
99+
.header("mithril-api-version", MITHRIL_API_VERSION)
100+
.path("/aggregator/whatever")
101+
.filter(&filters)
102+
.await
103+
.unwrap();
104+
}
37105
}

mithril-client/src/aggregator.rs

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use async_trait::async_trait;
22
use flate2::read::GzDecoder;
33
use futures::StreamExt;
4-
use reqwest::{self, StatusCode};
4+
use reqwest::{self, Response, StatusCode};
55
use reqwest::{Client, RequestBuilder};
66
use slog_scope::debug;
77
use std::env;
@@ -48,6 +48,17 @@ pub enum AggregatorHandlerError {
4848
/// [AggregatorHandler::download_snapshot] beforehand.
4949
#[error("archive not found, did you download it beforehand ? Expected path: '{0}'")]
5050
ArchiveNotFound(PathBuf),
51+
52+
/// Error raised when the server API version mismatch the client API version.
53+
#[error("API version mismatch: {0}")]
54+
ApiVersionMismatch(String),
55+
}
56+
57+
#[cfg(test)]
58+
impl AggregatorHandlerError {
59+
pub fn is_api_version_mismatch(&self) -> bool {
60+
matches!(self, Self::ApiVersionMismatch(_))
61+
}
5162
}
5263

5364
/// AggregatorHandler represents a read interactor with an aggregator
@@ -128,6 +139,22 @@ impl AggregatorHTTPClient {
128139
)),
129140
}
130141
}
142+
143+
/// API version error handling
144+
fn handle_api_error(&self, response: &Response) -> AggregatorHandlerError {
145+
if let Some(version) = response.headers().get("mithril-api-version") {
146+
AggregatorHandlerError::ApiVersionMismatch(format!(
147+
"server version: '{}', signer version: '{}'",
148+
version.to_str().unwrap(),
149+
MITHRIL_API_VERSION
150+
))
151+
} else {
152+
AggregatorHandlerError::ApiVersionMismatch(format!(
153+
"version precondition failed, sent version '{}'.",
154+
MITHRIL_API_VERSION
155+
))
156+
}
157+
}
131158
}
132159

133160
#[async_trait]
@@ -147,6 +174,7 @@ impl AggregatorHandler for AggregatorHTTPClient {
147174
Ok(snapshots) => Ok(snapshots),
148175
Err(err) => Err(AggregatorHandlerError::JsonParseFailed(err.to_string())),
149176
},
177+
StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
150178
status_error => Err(AggregatorHandlerError::RemoteServerTechnical(
151179
status_error.to_string(),
152180
)),
@@ -172,6 +200,7 @@ impl AggregatorHandler for AggregatorHTTPClient {
172200
Ok(snapshot) => Ok(snapshot),
173201
Err(err) => Err(AggregatorHandlerError::JsonParseFailed(err.to_string())),
174202
},
203+
StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
175204
StatusCode::NOT_FOUND => Err(AggregatorHandlerError::RemoteServerLogical(
176205
"snapshot not found".to_string(),
177206
)),
@@ -226,6 +255,7 @@ impl AggregatorHandler for AggregatorHTTPClient {
226255
}
227256
Ok(local_path.into_os_string().into_string().unwrap())
228257
}
258+
StatusCode::PRECONDITION_FAILED => Err(self.handle_api_error(&response)),
229259
StatusCode::NOT_FOUND => Err(AggregatorHandlerError::RemoteServerLogical(
230260
"snapshot archive not found".to_string(),
231261
)),
@@ -352,6 +382,20 @@ mod tests {
352382
assert_eq!(snapshots.unwrap(), snapshots_expected);
353383
}
354384

385+
#[tokio::test]
386+
async fn test_list_snapshots_ko_412() {
387+
let (server, config) = setup_test();
388+
let _snapshots_mock = server.mock(|when, then| {
389+
when.path("/snapshots");
390+
then.status(412).header("mithril-api-version", "0.0.999");
391+
});
392+
let aggregator_client =
393+
AggregatorHTTPClient::new(config.network, config.aggregator_endpoint);
394+
let error = aggregator_client.list_snapshots().await.unwrap_err();
395+
396+
assert!(error.is_api_version_mismatch());
397+
}
398+
355399
#[tokio::test]
356400
async fn test_list_snapshots_ko_500() {
357401
let (server, config) = setup_test();
@@ -403,6 +447,24 @@ mod tests {
403447
assert!(snapshot.is_err());
404448
}
405449

450+
#[tokio::test]
451+
async fn test_snapshot_details_ko_412() {
452+
let (server, config) = setup_test();
453+
let digest = "digest123";
454+
let _snapshots_mock = server.mock(|when, then| {
455+
when.path(format!("/snapshot/{}", digest));
456+
then.status(412).header("mithril-api-version", "0.0.999");
457+
});
458+
let aggregator_client =
459+
AggregatorHTTPClient::new(config.network, config.aggregator_endpoint);
460+
let error = aggregator_client
461+
.get_snapshot_details(digest)
462+
.await
463+
.unwrap_err();
464+
465+
assert!(error.is_api_version_mismatch());
466+
}
467+
406468
#[tokio::test]
407469
async fn get_snapshot_details_ko_500() {
408470
let digest = "digest123";
@@ -448,6 +510,26 @@ mod tests {
448510
assert_eq!(data_downloaded, data_expected);
449511
}
450512

513+
#[tokio::test]
514+
async fn test_download_snapshot_ko_412() {
515+
let (server, config) = setup_test();
516+
let digest = "digest123";
517+
let url_path = "/download";
518+
let _snapshots_mock = server.mock(|when, then| {
519+
when.path(url_path.to_string());
520+
then.status(412).header("mithril-api-version", "0.0.999");
521+
});
522+
let aggregator_client =
523+
AggregatorHTTPClient::new(config.network, config.aggregator_endpoint);
524+
let location = server.url(url_path);
525+
let error = aggregator_client
526+
.download_snapshot(digest, &location)
527+
.await
528+
.unwrap_err();
529+
530+
assert!(error.is_api_version_mismatch());
531+
}
532+
451533
#[tokio::test]
452534
async fn get_download_snapshot_ko_unreachable() {
453535
let digest = "digest123";
@@ -502,6 +584,22 @@ mod tests {
502584
assert!(local_dir_path.is_err());
503585
}
504586

587+
#[tokio::test]
588+
async fn test_certificate_details_412() {
589+
let (server, config) = setup_test();
590+
let certificate_hash = "certificate-hash-123";
591+
let _snapshots_mock = server.mock(|when, then| {
592+
when.path(format!("/certificate/{}", certificate_hash));
593+
then.status(412).header("mithril-api-version", "0.0.999");
594+
});
595+
let aggregator_client =
596+
AggregatorHTTPClient::new(config.network, config.aggregator_endpoint);
597+
let _error = aggregator_client
598+
.get_certificate_details(certificate_hash)
599+
.await
600+
.unwrap_err();
601+
}
602+
505603
#[tokio::test]
506604
async fn get_certificate_details_ok() {
507605
let certificate_hash = "certificate-hash-123";

0 commit comments

Comments
 (0)