Skip to content

Commit 13833c5

Browse files
committed
Retry all requests in tests
1 parent d64d14a commit 13833c5

File tree

9 files changed

+360
-259
lines changed

9 files changed

+360
-259
lines changed

Cargo.lock

Lines changed: 14 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/dapf/src/functions/helper.rs

Lines changed: 64 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,23 @@
11
// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
22
// SPDX-License-Identifier: BSD-3-Clause
33

4-
use anyhow::{anyhow, Context as _};
4+
use anyhow::Context as _;
55
use daphne::{
66
constants::DapMediaType,
7-
error::aborts::ProblemDetails,
87
messages::{
9-
taskprov::TaskprovAdvertisement, AggregateShareReq, AggregationJobInitReq,
8+
taskprov::TaskprovAdvertisement, AggregateShare, AggregateShareReq, AggregationJobInitReq,
109
AggregationJobResp,
1110
},
1211
DapVersion,
1312
};
1413
use daphne_service_utils::{bearer_token::BearerToken, http_headers};
15-
use prio::codec::{ParameterizedDecode as _, ParameterizedEncode as _};
14+
use prio::codec::ParameterizedEncode as _;
1615
use reqwest::header;
1716
use url::Url;
1817

1918
use crate::HttpClient;
2019

21-
use super::response_to_anyhow;
20+
use super::retry_and_decode;
2221

2322
impl HttpClient {
2423
pub async fn submit_aggregation_job_init_req(
@@ -28,45 +27,47 @@ impl HttpClient {
2827
version: DapVersion,
2928
opts: Options<'_>,
3029
) -> anyhow::Result<AggregationJobResp> {
31-
let resp = self
32-
.put(url)
33-
.body(agg_job_init_req.get_encoded_with_param(&version).unwrap())
34-
.headers(construct_request_headers(
35-
DapMediaType::AggregationJobInitReq
36-
.as_str_for_version(version)
37-
.with_context(|| {
38-
format!("AggregationJobInitReq media type is not defined for {version}")
39-
})?,
40-
version,
41-
opts,
42-
)?)
43-
.send()
44-
.await
45-
.context("sending AggregationJobInitReq")?;
46-
if resp.status() == 400 {
47-
let text = resp.text().await?;
48-
let problem_details: ProblemDetails =
49-
serde_json::from_str(&text).with_context(|| {
50-
format!("400 Bad Request: failed to parse problem details document: {text:?}")
51-
})?;
52-
Err(anyhow!("400 Bad Request: {problem_details:?}"))
53-
} else if resp.status() == 500 {
54-
Err(anyhow::anyhow!(
55-
"500 Internal Server Error: {}",
56-
resp.text().await?
57-
))
58-
} else if !resp.status().is_success() {
59-
Err(response_to_anyhow(resp).await).context("while running an AggregationJobInitReq")
60-
} else {
61-
AggregationJobResp::get_decoded_with_param(
62-
&version,
63-
&resp
64-
.bytes()
65-
.await
66-
.context("transfering bytes from the AggregateInitReq")?,
67-
)
68-
.with_context(|| "failed to parse response to AggregateInitReq from Helper")
69-
}
30+
retry_and_decode(&version, || async {
31+
self.put(url.clone())
32+
.body(agg_job_init_req.get_encoded_with_param(&version).unwrap())
33+
.headers(construct_request_headers(
34+
DapMediaType::AggregationJobInitReq
35+
.as_str_for_version(version)
36+
.with_context(|| {
37+
format!("AggregationJobInitReq media type is not defined for {version}")
38+
})?,
39+
version,
40+
opts,
41+
)?)
42+
.send()
43+
.await
44+
.context("sending AggregationJobInitReq")
45+
})
46+
.await
47+
}
48+
49+
pub async fn poll_aggregation_job_init(
50+
&self,
51+
url: Url,
52+
version: DapVersion,
53+
opts: Options<'_>,
54+
) -> anyhow::Result<AggregationJobResp> {
55+
retry_and_decode(&version, || async {
56+
self.get(url.clone())
57+
.headers(construct_request_headers(
58+
DapMediaType::AggregationJobInitReq
59+
.as_str_for_version(version)
60+
.with_context(|| {
61+
format!("AggregationJobInitReq media type is not defined for {version}")
62+
})?,
63+
version,
64+
opts,
65+
)?)
66+
.send()
67+
.await
68+
.context("polling aggregation job init req")
69+
})
70+
.await
7071
}
7172

7273
pub async fn get_aggregate_share(
@@ -75,42 +76,28 @@ impl HttpClient {
7576
agg_share_req: AggregateShareReq,
7677
version: DapVersion,
7778
opts: Options<'_>,
78-
) -> anyhow::Result<()> {
79-
let resp = self
80-
.post(url)
81-
.body(agg_share_req.get_encoded_with_param(&version).unwrap())
82-
.headers(construct_request_headers(
83-
DapMediaType::AggregateShareReq
84-
.as_str_for_version(version)
85-
.with_context(|| {
86-
format!("AggregateShareReq media type is not defined for {version}")
87-
})?,
88-
version,
89-
opts,
90-
)?)
91-
.send()
92-
.await
93-
.context("sending AggregateShareReq")?;
94-
if resp.status() == 400 {
95-
let problem_details: ProblemDetails = serde_json::from_slice(
96-
&resp
97-
.bytes()
98-
.await
99-
.context("transfering bytes for AggregateShareReq")?,
100-
)
101-
.with_context(|| "400 Bad Request: failed to parse problem details document")?;
102-
Err(anyhow!("400 Bad Request: {problem_details:?}"))
103-
} else if resp.status() == 500 {
104-
Err(anyhow!("500 Internal Server Error: {}", resp.text().await?))
105-
} else if !resp.status().is_success() {
106-
Err(response_to_anyhow(resp).await).context("while running an AggregateShareReq")
107-
} else {
108-
Ok(())
109-
}
79+
) -> anyhow::Result<AggregateShare> {
80+
retry_and_decode(&(), || async {
81+
self.post(url.clone())
82+
.body(agg_share_req.get_encoded_with_param(&version).unwrap())
83+
.headers(construct_request_headers(
84+
DapMediaType::AggregateShareReq
85+
.as_str_for_version(version)
86+
.with_context(|| {
87+
format!("AggregateShareReq media type is not defined for {version}")
88+
})?,
89+
version,
90+
opts,
91+
)?)
92+
.send()
93+
.await
94+
.context("sending AggregateShareReq")
95+
})
96+
.await
11097
}
11198
}
11299

113-
#[derive(Default, Debug)]
100+
#[derive(Default, Debug, Clone, Copy)]
114101
pub struct Options<'s> {
115102
pub taskprov_advertisement: Option<&'s TaskprovAdvertisement>,
116103
pub bearer_token: Option<&'s BearerToken>,

crates/dapf/src/functions/hpke.rs

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use x509_parser::pem::Pem;
1313

1414
use crate::HttpClient;
1515

16-
use super::response_to_anyhow;
16+
use super::retry;
1717

1818
impl HttpClient {
1919
pub async fn get_hpke_config(
@@ -22,36 +22,40 @@ impl HttpClient {
2222
certificate_file: Option<&Path>,
2323
) -> anyhow::Result<HpkeConfigList> {
2424
let url = base_url.join("hpke_config")?;
25-
let resp = self
26-
.get(url.as_str())
27-
.send()
28-
.await
29-
.with_context(|| "request failed")?;
30-
if !resp.status().is_success() {
31-
return Err(response_to_anyhow(resp).await);
32-
}
33-
let maybe_signature = resp.headers().get(http_headers::HPKE_SIGNATURE).cloned();
34-
let hpke_config_bytes = resp.bytes().await.context("failed to read hpke config")?;
35-
if let Some(cert_path) = certificate_file {
36-
let cert = std::fs::read_to_string(cert_path).context("reading the certificate")?;
37-
let Some(signature) = maybe_signature else {
38-
anyhow::bail!("Aggregator did not sign its response");
39-
};
40-
let signature_bytes =
41-
decode_base64url_vec(signature.as_bytes()).context("decoding the signature")?;
42-
let (cert_pem, _bytes_read) =
43-
Pem::read(Cursor::new(cert.as_bytes())).context("reading PEM certificate")?;
44-
let cert = EndEntityCert::try_from(cert_pem.contents.as_ref())
45-
.map_err(|e| anyhow!("{e:?}")) // webpki::Error does not implement std::error::Error
46-
.context("parsing PEM certificate")?;
25+
retry(
26+
|| async {
27+
self.get(url.as_str())
28+
.send()
29+
.await
30+
.with_context(|| "request failed")
31+
},
32+
|resp| async {
33+
let maybe_signature = resp.headers().get(http_headers::HPKE_SIGNATURE).cloned();
34+
let hpke_config_bytes = resp.bytes().await.context("failed to read hpke config")?;
35+
if let Some(cert_path) = certificate_file {
36+
let cert =
37+
std::fs::read_to_string(cert_path).context("reading the certificate")?;
38+
let Some(signature) = maybe_signature else {
39+
anyhow::bail!("Aggregator did not sign its response");
40+
};
41+
let signature_bytes = decode_base64url_vec(signature.as_bytes())
42+
.context("decoding the signature")?;
43+
let (cert_pem, _bytes_read) = Pem::read(Cursor::new(cert.as_bytes()))
44+
.context("reading PEM certificate")?;
45+
let cert = EndEntityCert::try_from(cert_pem.contents.as_ref())
46+
.map_err(|e| anyhow!("{e:?}")) // webpki::Error does not implement std::error::Error
47+
.context("parsing PEM certificate")?;
4748

48-
cert.verify_signature(
49-
&ECDSA_P256_SHA256,
50-
&hpke_config_bytes,
51-
signature_bytes.as_ref(),
52-
)
53-
.map_err(|e| anyhow!("signature not verified: {}", e.to_string()))?;
54-
}
55-
Ok(HpkeConfigList::get_decoded(&hpke_config_bytes)?)
49+
cert.verify_signature(
50+
&ECDSA_P256_SHA256,
51+
&hpke_config_bytes,
52+
signature_bytes.as_ref(),
53+
)
54+
.map_err(|e| anyhow!("signature not verified: {}", e.to_string()))?;
55+
}
56+
Ok(HpkeConfigList::get_decoded(&hpke_config_bytes)?)
57+
},
58+
)
59+
.await
5660
}
5761
}

0 commit comments

Comments
 (0)