Skip to content

Commit a402fdb

Browse files
committed
feat: prototype aggregator client query system
1 parent 0554dbb commit a402fdb

File tree

5 files changed

+319
-1
lines changed

5 files changed

+319
-1
lines changed
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
use anyhow::{Context, anyhow};
2+
use reqwest::{Response, Url};
3+
4+
use serde::de::DeserializeOwned;
5+
use slog::Logger;
6+
7+
use crate::AggregatorClientResult;
8+
use crate::error::AggregatorClientError;
9+
10+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
11+
pub(crate) enum QueryMethod {
12+
Get,
13+
Post,
14+
}
15+
16+
// Todo: wasm compatibility
17+
#[async_trait::async_trait]
18+
pub(crate) trait AggregatorQuery {
19+
type Response: DeserializeOwned;
20+
type Body: serde::Serialize + Sized;
21+
22+
fn method() -> QueryMethod;
23+
24+
fn route(&self) -> String;
25+
26+
fn body(&self) -> Option<Self::Body> {
27+
None
28+
}
29+
30+
async fn handle_response(
31+
&self,
32+
context: QueryContext,
33+
) -> AggregatorClientResult<Self::Response>;
34+
}
35+
36+
// internal to the crate
37+
pub(crate) struct QueryContext {
38+
pub(crate) response: Response,
39+
pub(crate) logger: Logger,
40+
}
41+
42+
impl QueryContext {
43+
pub(crate) async fn unhandled_status_code(self) -> AggregatorClientError {
44+
AggregatorClientError::from_response(self.response).await
45+
}
46+
}
47+
48+
pub struct AggregatorClient {
49+
aggregator_endpoint: Url,
50+
client: reqwest::Client,
51+
logger: Logger,
52+
}
53+
54+
impl AggregatorClient {
55+
pub async fn send<Q: AggregatorQuery>(&self, query: Q) -> AggregatorClientResult<Q::Response> {
56+
let mut request_builder = match Q::method() {
57+
QueryMethod::Get => self.client.get(self.join_aggregator_endpoint(&query.route())?),
58+
QueryMethod::Post => self.client.post(self.join_aggregator_endpoint(&query.route())?),
59+
};
60+
61+
if let Some(body) = query.body() {
62+
request_builder = request_builder.json(&body);
63+
}
64+
65+
match request_builder.send().await {
66+
Ok(response) => {
67+
// should we always warn?
68+
if !response.status().is_server_error() {
69+
// todo: import code
70+
// self.warn_if_api_version_mismatch(&response);
71+
}
72+
73+
let context = QueryContext {
74+
response,
75+
logger: self.logger.clone(),
76+
};
77+
query.handle_response(context).await
78+
}
79+
Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))),
80+
}
81+
}
82+
83+
fn join_aggregator_endpoint(&self, endpoint: &str) -> AggregatorClientResult<Url> {
84+
self.aggregator_endpoint
85+
.join(endpoint)
86+
.with_context(|| {
87+
format!(
88+
"Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'",
89+
self.aggregator_endpoint
90+
)
91+
})
92+
.map_err(AggregatorClientError::InvalidEndpoint)
93+
}
94+
}
95+
96+
#[cfg(test)]
97+
mod tests {
98+
use http::StatusCode;
99+
use httpmock::MockServer;
100+
101+
use crate::test::TestLogger;
102+
103+
use super::*;
104+
105+
#[derive(Debug, Eq, PartialEq, serde::Deserialize)]
106+
struct TestResponse {
107+
foo: String,
108+
bar: i32,
109+
}
110+
111+
struct TestGetQuery;
112+
113+
#[async_trait::async_trait]
114+
impl AggregatorQuery for TestGetQuery {
115+
type Response = TestResponse;
116+
type Body = ();
117+
118+
fn method() -> QueryMethod {
119+
QueryMethod::Get
120+
}
121+
122+
fn route(&self) -> String {
123+
"/dummy-route".to_string()
124+
}
125+
126+
async fn handle_response(
127+
&self,
128+
context: QueryContext,
129+
) -> AggregatorClientResult<Self::Response> {
130+
match context.response.status() {
131+
StatusCode::OK => context
132+
.response
133+
.json::<TestResponse>()
134+
.await
135+
.map_err(|err| AggregatorClientError::JsonParseFailed(anyhow!(err))),
136+
_ => Err(context.unhandled_status_code().await),
137+
}
138+
}
139+
}
140+
141+
#[tokio::test]
142+
async fn test_minimal_query() {
143+
let server = MockServer::start();
144+
server.mock(|when, then| {
145+
when.method(httpmock::Method::GET).path(TestGetQuery.route());
146+
then.status(200).body(r#"{"foo": "bar", "bar": 123}"#);
147+
});
148+
149+
let aggregator_endpoint = Url::parse(&server.url("/")).unwrap();
150+
let client = AggregatorClient {
151+
aggregator_endpoint,
152+
client: reqwest::Client::new(),
153+
logger: TestLogger::stdout(),
154+
};
155+
156+
let response = client.send(TestGetQuery).await.unwrap();
157+
158+
assert_eq!(
159+
response,
160+
TestResponse {
161+
foo: "bar".to_string(),
162+
bar: 123,
163+
}
164+
)
165+
}
166+
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
use anyhow::anyhow;
2+
use reqwest::{Response, StatusCode, header};
3+
use thiserror::Error;
4+
5+
use mithril_common::StdError;
6+
use mithril_common::entities::{ClientError, ServerError};
7+
8+
use crate::JSON_CONTENT_TYPE;
9+
10+
/// Error structure for the Aggregator Client.
11+
#[derive(Error, Debug)]
12+
pub enum AggregatorClientError {
13+
/// The aggregator host has returned a technical error.
14+
#[error("remote server technical error")]
15+
RemoteServerTechnical(#[source] StdError),
16+
17+
/// The aggregator host responded it cannot fulfill our request.
18+
#[error("remote server logical error")]
19+
RemoteServerLogical(#[source] StdError),
20+
21+
/// Could not reach aggregator.
22+
#[error("Remote server unreachable")]
23+
RemoteServerUnreachable(#[source] StdError),
24+
25+
/// Unhandled status code
26+
#[error("Unhandled status code: {0}, response text: {1}")]
27+
UnhandledStatusCode(StatusCode, String),
28+
29+
/// Could not parse response.
30+
#[error("Json parsing failed")]
31+
JsonParseFailed(#[source] StdError),
32+
33+
/// Failed to join the query endpoint to the aggregator url
34+
#[error("Invalid endpoint")]
35+
InvalidEndpoint(#[source] StdError),
36+
}
37+
38+
impl AggregatorClientError {
39+
/// Create an `AggregatorClientError` from a response.
40+
///
41+
/// This method is meant to be used after handling domain-specific cases leaving only
42+
/// 4xx or 5xx status codes.
43+
/// Otherwise, it will return an `UnhandledStatusCode` error.
44+
pub async fn from_response(response: Response) -> Self {
45+
let error_code = response.status();
46+
47+
if error_code.is_client_error() {
48+
let root_cause = Self::get_root_cause(response).await;
49+
Self::RemoteServerLogical(anyhow!(root_cause))
50+
} else if error_code.is_server_error() {
51+
let root_cause = Self::get_root_cause(response).await;
52+
Self::RemoteServerTechnical(anyhow!(root_cause))
53+
} else {
54+
let response_text = response.text().await.unwrap_or_default();
55+
Self::UnhandledStatusCode(error_code, response_text)
56+
}
57+
}
58+
59+
async fn get_root_cause(response: Response) -> String {
60+
let error_code = response.status();
61+
let canonical_reason = error_code.canonical_reason().unwrap_or_default().to_lowercase();
62+
let is_json = response
63+
.headers()
64+
.get(header::CONTENT_TYPE)
65+
.is_some_and(|ct| JSON_CONTENT_TYPE == ct);
66+
67+
if is_json {
68+
let json_value: serde_json::Value = response.json().await.unwrap_or_default();
69+
70+
if let Ok(client_error) = serde_json::from_value::<ClientError>(json_value.clone()) {
71+
format!(
72+
"{}: {}: {}",
73+
canonical_reason, client_error.label, client_error.message
74+
)
75+
} else if let Ok(server_error) =
76+
serde_json::from_value::<ServerError>(json_value.clone())
77+
{
78+
format!("{}: {}", canonical_reason, server_error.message)
79+
} else if json_value.is_null() {
80+
canonical_reason.to_string()
81+
} else {
82+
format!("{canonical_reason}: {json_value}")
83+
}
84+
} else {
85+
let response_text = response.text().await.unwrap_or_default();
86+
format!("{canonical_reason}: {response_text}")
87+
}
88+
}
89+
}

internal/mithril-aggregator-client/src/lib.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,16 @@
22
//! This crate provides a client to request data from a Mithril Aggregator.
33
//!
44
5+
mod client;
6+
mod error;
7+
pub mod query;
58
#[cfg(test)]
6-
mod test;
9+
mod test;
10+
11+
pub use client::AggregatorClient;
12+
pub use error::AggregatorClientError;
13+
14+
pub(crate) const JSON_CONTENT_TYPE: reqwest::header::HeaderValue =
15+
reqwest::header::HeaderValue::from_static("application/json");
16+
17+
pub type AggregatorClientResult<T> = Result<T, error::AggregatorClientError>;
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
use anyhow::anyhow;
2+
use reqwest::StatusCode;
3+
use slog::debug;
4+
5+
use mithril_common::messages::CertificateMessage;
6+
7+
use crate::AggregatorClientResult;
8+
use crate::client::{AggregatorQuery, QueryContext, QueryMethod};
9+
use crate::error::AggregatorClientError;
10+
11+
pub struct CertificateDetailsQuery {
12+
hash: String,
13+
}
14+
15+
impl CertificateDetailsQuery {
16+
pub fn new(hash: String) -> Self {
17+
Self { hash }
18+
}
19+
}
20+
21+
#[async_trait::async_trait]
22+
impl AggregatorQuery for CertificateDetailsQuery {
23+
type Response = Option<CertificateMessage>;
24+
type Body = ();
25+
26+
fn method() -> QueryMethod {
27+
QueryMethod::Get
28+
}
29+
30+
fn route(&self) -> String {
31+
format!("certificate/{}", self.hash)
32+
}
33+
34+
async fn handle_response(
35+
&self,
36+
context: QueryContext,
37+
) -> AggregatorClientResult<Self::Response> {
38+
debug!(context.logger, "Retrieve certificate details"; "certificate_hash" => %self.hash);
39+
40+
match context.response.status() {
41+
StatusCode::OK => match context.response.json::<CertificateMessage>().await {
42+
Ok(message) => Ok(Some(message)),
43+
Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))),
44+
},
45+
StatusCode::NOT_FOUND => Ok(None),
46+
_ => Err(context.unhandled_status_code().await),
47+
}
48+
}
49+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
mod certificate_details;
2+
3+
pub use certificate_details::*;

0 commit comments

Comments
 (0)