Skip to content

Commit 50c9473

Browse files
committed
feat(aggregator-client): add timeout support
1 parent f2f621a commit 50c9473

File tree

2 files changed

+68
-19
lines changed

2 files changed

+68
-19
lines changed

internal/mithril-aggregator-client/src/builder.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use anyhow::Context;
22
use reqwest::{IntoUrl, Url};
33
use slog::{Logger, o};
4+
use std::time::Duration;
45

56
use mithril_common::StdResult;
67
use mithril_common::api_version::APIVersionProvider;
@@ -11,6 +12,7 @@ use crate::client::AggregatorClient;
1112
pub struct AggregatorClientBuilder {
1213
aggregator_url_result: reqwest::Result<Url>,
1314
api_version_provider: Option<APIVersionProvider>,
15+
timeout_duration: Option<Duration>,
1416
logger: Option<Logger>,
1517
}
1618

@@ -22,6 +24,7 @@ impl AggregatorClientBuilder {
2224
Self {
2325
aggregator_url_result: aggregator_url.into_url(),
2426
api_version_provider: None,
27+
timeout_duration: None,
2528
logger: None,
2629
}
2730
}
@@ -38,6 +41,12 @@ impl AggregatorClientBuilder {
3841
self
3942
}
4043

44+
/// Set a timeout to enforce on each request
45+
pub fn with_timeout(mut self, timeout: Duration) -> Self {
46+
self.timeout_duration = Some(timeout);
47+
self
48+
}
49+
4150
/// Returns an [AggregatorClient] based on the builder configuration
4251
pub fn build(self) -> StdResult<AggregatorClient> {
4352
let aggregator_endpoint =
@@ -50,6 +59,7 @@ impl AggregatorClientBuilder {
5059
Ok(AggregatorClient {
5160
aggregator_endpoint,
5261
api_version_provider,
62+
timeout_duration: self.timeout_duration,
5363
client: reqwest::Client::new(),
5464
logger,
5565
})

internal/mithril-aggregator-client/src/client.rs

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use anyhow::{Context, anyhow};
22
use reqwest::{IntoUrl, Response, Url};
33
use semver::Version;
44
use slog::{Logger, error, warn};
5+
use std::time::Duration;
56

67
use mithril_common::MITHRIL_API_VERSION_HEADER;
78
use mithril_common::api_version::APIVersionProvider;
@@ -16,6 +17,7 @@ const API_VERSION_MISMATCH_WARNING_MESSAGE: &str = "OpenAPI version may be incom
1617
pub struct AggregatorClient {
1718
pub(super) aggregator_endpoint: Url,
1819
pub(super) api_version_provider: APIVersionProvider,
20+
pub(super) timeout_duration: Option<Duration>,
1921
pub(super) client: reqwest::Client,
2022
pub(super) logger: Logger,
2123
}
@@ -41,6 +43,10 @@ impl AggregatorClient {
4143
request_builder = request_builder.json(&body);
4244
}
4345

46+
if let Some(timeout) = self.timeout_duration {
47+
request_builder = request_builder.timeout(timeout);
48+
}
49+
4450
match request_builder.send().await {
4551
Ok(response) => {
4652
self.warn_if_api_version_mismatch(&response);
@@ -159,6 +165,15 @@ mod tests {
159165
chu: u8,
160166
}
161167

168+
impl TestBody {
169+
fn new<P: Into<String>>(pika: P, chu: u8) -> Self {
170+
Self {
171+
pika: pika.into(),
172+
chu,
173+
}
174+
}
175+
}
176+
162177
struct TestPostQuery {
163178
body: TestBody,
164179
}
@@ -226,6 +241,23 @@ mod tests {
226241

227242
client.send(TestGetQuery).await.expect("should not fail");
228243
}
244+
245+
#[tokio::test]
246+
async fn test_get_query_timeout() {
247+
let (server, mut client) = setup_server_and_client();
248+
client.timeout_duration = Some(Duration::from_millis(10));
249+
let _server_mock = server.mock(|when, then| {
250+
when.method(httpmock::Method::GET);
251+
then.delay(Duration::from_millis(100));
252+
});
253+
254+
let error = client.send(TestGetQuery).await.expect_err("should not fail");
255+
256+
assert!(
257+
matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
258+
"unexpected error type: {error:?}"
259+
);
260+
}
229261
}
230262

231263
mod post {
@@ -238,22 +270,13 @@ mod tests {
238270
when.method(httpmock::Method::POST)
239271
.path("/dummy-post-route")
240272
.header("content-type", "application/json")
241-
.body(
242-
serde_json::to_string(&TestBody {
243-
pika: "miaouss".to_string(),
244-
chu: 5,
245-
})
246-
.unwrap(),
247-
);
273+
.body(serde_json::to_string(&TestBody::new("miaouss", 5)).unwrap());
248274
then.status(201);
249275
});
250276

251277
let response = client
252278
.send(TestPostQuery {
253-
body: TestBody {
254-
pika: "miaouss".to_string(),
255-
chu: 5,
256-
},
279+
body: TestBody::new("miaouss", 5),
257280
})
258281
.await
259282
.unwrap();
@@ -274,14 +297,33 @@ mod tests {
274297

275298
client
276299
.send(TestPostQuery {
277-
body: TestBody {
278-
pika: "a".to_string(),
279-
chu: 3,
280-
},
300+
body: TestBody::new("miaouss", 3),
281301
})
282302
.await
283303
.expect("should not fail");
284304
}
305+
306+
#[tokio::test]
307+
async fn test_post_query_timeout() {
308+
let (server, mut client) = setup_server_and_client();
309+
client.timeout_duration = Some(Duration::from_millis(10));
310+
let _server_mock = server.mock(|when, then| {
311+
when.method(httpmock::Method::POST);
312+
then.delay(Duration::from_millis(100));
313+
});
314+
315+
let error = client
316+
.send(TestPostQuery {
317+
body: TestBody::new("miaouss", 3),
318+
})
319+
.await
320+
.expect_err("should not fail");
321+
322+
assert!(
323+
matches!(error, AggregatorClientError::RemoteServerUnreachable(_)),
324+
"unexpected error type: {error:?}"
325+
);
326+
}
285327
}
286328

287329
mod warn_if_api_version_mismatch {
@@ -481,10 +523,7 @@ mod tests {
481523

482524
client
483525
.send(TestPostQuery {
484-
body: TestBody {
485-
pika: "miaouss".to_string(),
486-
chu: 5,
487-
},
526+
body: TestBody::new("miaouss", 3),
488527
})
489528
.await
490529
.unwrap();

0 commit comments

Comments
 (0)