Skip to content

Commit 8c5a9ee

Browse files
committed
feat(aggregator-client): add ability to set additional headers
1 parent c8b90d2 commit 8c5a9ee

File tree

2 files changed

+68
-4
lines changed

2 files changed

+68
-4
lines changed

internal/mithril-aggregator-client/src/builder.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use anyhow::Context;
22
use reqwest::{IntoUrl, Url};
33
use slog::{Logger, o};
4+
use std::collections::HashMap;
45
use std::time::Duration;
56

67
use mithril_common::StdResult;
@@ -12,6 +13,7 @@ use crate::client::AggregatorClient;
1213
pub struct AggregatorClientBuilder {
1314
aggregator_url_result: reqwest::Result<Url>,
1415
api_version_provider: Option<APIVersionProvider>,
16+
additional_headers: Option<HashMap<String, String>>,
1517
timeout_duration: Option<Duration>,
1618
logger: Option<Logger>,
1719
}
@@ -24,6 +26,7 @@ impl AggregatorClientBuilder {
2426
Self {
2527
aggregator_url_result: aggregator_url.into_url(),
2628
api_version_provider: None,
29+
additional_headers: None,
2730
timeout_duration: None,
2831
logger: None,
2932
}
@@ -47,6 +50,12 @@ impl AggregatorClientBuilder {
4750
self
4851
}
4952

53+
/// Add a set of http headers that will be sent on client requests
54+
pub fn with_headers(mut self, custom_headers: HashMap<String, String>) -> Self {
55+
self.additional_headers = Some(custom_headers);
56+
self
57+
}
58+
5059
/// Returns an [AggregatorClient] based on the builder configuration
5160
pub fn build(self) -> StdResult<AggregatorClient> {
5261
let aggregator_endpoint =
@@ -55,10 +64,14 @@ impl AggregatorClientBuilder {
5564
)?);
5665
let logger = self.logger.unwrap_or_else(|| Logger::root(slog::Discard, o!()));
5766
let api_version_provider = self.api_version_provider.unwrap_or_default();
67+
let additional_headers = self.additional_headers.unwrap_or_default();
5868

5969
Ok(AggregatorClient {
6070
aggregator_endpoint,
6171
api_version_provider,
72+
additional_headers: (&additional_headers)
73+
.try_into()
74+
.with_context(|| format!("Invalid headers: '{additional_headers:?}'"))?,
6275
timeout_duration: self.timeout_duration,
6376
client: reqwest::Client::new(),
6477
logger,

internal/mithril-aggregator-client/src/client.rs

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use anyhow::{Context, anyhow};
2-
use reqwest::{IntoUrl, Response, Url};
2+
use reqwest::{IntoUrl, Response, Url, header::HeaderMap};
33
use semver::Version;
44
use slog::{Logger, error, warn};
55
use std::time::Duration;
@@ -18,6 +18,7 @@ const API_VERSION_MISMATCH_WARNING_MESSAGE: &str = "OpenAPI version may be incom
1818
pub struct AggregatorClient {
1919
pub(super) aggregator_endpoint: Url,
2020
pub(super) api_version_provider: APIVersionProvider,
21+
pub(super) additional_headers: HeaderMap,
2122
pub(super) timeout_duration: Option<Duration>,
2223
pub(super) client: reqwest::Client,
2324
pub(super) logger: Logger,
@@ -39,6 +40,7 @@ impl AggregatorClient {
3940
QueryMethod::Get => self.client.get(self.join_aggregator_endpoint(&query.route())?),
4041
QueryMethod::Post => self.client.post(self.join_aggregator_endpoint(&query.route())?),
4142
}
43+
.headers(self.additional_headers.clone())
4244
.header(MITHRIL_API_VERSION_HEADER, current_api_version.to_string());
4345

4446
if let Some(body) = query.body() {
@@ -244,6 +246,29 @@ mod tests {
244246
client.send(TestGetQuery).await.expect("should not fail");
245247
}
246248

249+
#[tokio::test]
250+
async fn test_get_query_send_additional_header_and_dont_override_mithril_api_version_header()
251+
{
252+
let (server, mut client) = setup_server_and_client();
253+
client.api_version_provider =
254+
APIVersionProvider::new_with_default_version(Version::parse("1.2.9").unwrap());
255+
client.additional_headers = {
256+
let mut headers = HeaderMap::new();
257+
headers.insert(MITHRIL_API_VERSION_HEADER, "9.4.5".parse().unwrap());
258+
headers.insert("foo", "bar".parse().unwrap());
259+
headers
260+
};
261+
262+
server.mock(|when, then| {
263+
when.method(httpmock::Method::GET)
264+
.header(MITHRIL_API_VERSION_HEADER, "1.2.9")
265+
.header("foo", "bar");
266+
then.status(200).body(r#"{"foo": "a", "bar": 1}"#);
267+
});
268+
269+
client.send(TestGetQuery).await.expect("should not fail");
270+
}
271+
247272
#[tokio::test]
248273
async fn test_get_query_timeout() {
249274
let (server, mut client) = setup_server_and_client();
@@ -276,14 +301,12 @@ mod tests {
276301
then.status(201);
277302
});
278303

279-
let response = client
304+
client
280305
.send(TestPostQuery {
281306
body: TestBody::new("miaouss", 5),
282307
})
283308
.await
284309
.unwrap();
285-
286-
assert_eq!(response, ())
287310
}
288311

289312
#[tokio::test]
@@ -305,6 +328,34 @@ mod tests {
305328
.expect("should not fail");
306329
}
307330

331+
#[tokio::test]
332+
async fn test_post_query_send_additional_header_and_dont_override_mithril_api_version_header()
333+
{
334+
let (server, mut client) = setup_server_and_client();
335+
client.api_version_provider =
336+
APIVersionProvider::new_with_default_version(Version::parse("1.2.9").unwrap());
337+
client.additional_headers = {
338+
let mut headers = HeaderMap::new();
339+
headers.insert(MITHRIL_API_VERSION_HEADER, "9.4.5".parse().unwrap());
340+
headers.insert("foo", "bar".parse().unwrap());
341+
headers
342+
};
343+
344+
server.mock(|when, then| {
345+
when.method(httpmock::Method::POST)
346+
.header(MITHRIL_API_VERSION_HEADER, "1.2.9")
347+
.header("foo", "bar");
348+
then.status(201).body(r#"{"foo": "a", "bar": 1}"#);
349+
});
350+
351+
client
352+
.send(TestPostQuery {
353+
body: TestBody::new("miaouss", 3),
354+
})
355+
.await
356+
.expect("should not fail");
357+
}
358+
308359
#[tokio::test]
309360
async fn test_post_query_timeout() {
310361
let (server, mut client) = setup_server_and_client();

0 commit comments

Comments
 (0)