Skip to content

Commit f2b2e34

Browse files
committed
WIP
1 parent 13dd33c commit f2b2e34

File tree

3 files changed

+102
-76
lines changed

3 files changed

+102
-76
lines changed

crates/dapf/src/acceptance/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,12 +526,20 @@ impl Test {
526526
.await?;
527527
let duration = start.elapsed();
528528
info!("Finished submitting AggregationJobInitReq in {duration:#?}");
529+
let mut poll_count = 0;
529530
let ready = loop {
530531
agg_job_resp = match agg_job_resp {
531532
messages::AggregationJobResp::Ready { transitions } => {
532533
break ReadyAggregationJobResp { transitions }
533534
}
534535
messages::AggregationJobResp::Processing => {
536+
tokio::time::sleep(Duration::from_millis(if poll_count == 0 {
537+
20_000
538+
} else {
539+
poll_count * 100
540+
}))
541+
.await;
542+
poll_count += 1;
535543
self.http_client
536544
.poll_aggregation_job_init(
537545
self.helper_url
Lines changed: 93 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
22
// SPDX-License-Identifier: BSD-3-Clause
33

4-
use anyhow::{anyhow, Context};
4+
use anyhow::{anyhow, bail, Context};
55
use daphne::{
66
constants::DapMediaType,
77
error::aborts::ProblemDetails,
88
messages::{
9-
taskprov::TaskprovAdvertisement, AggregateShareReq, AggregationJobInitReq,
9+
taskprov::TaskprovAdvertisement, AggregateShare, AggregateShareReq, AggregationJobInitReq,
1010
AggregationJobResp,
1111
},
1212
DapVersion,
@@ -19,6 +19,7 @@ use url::Url;
1919
use crate::HttpClient;
2020

2121
use super::response_to_anyhow;
22+
use std::ops::ControlFlow;
2223

2324
impl HttpClient {
2425
pub async fn submit_aggregation_job_init_req(
@@ -28,22 +29,23 @@ impl HttpClient {
2829
version: DapVersion,
2930
opts: Options<'_>,
3031
) -> anyhow::Result<AggregationJobResp> {
31-
let resp = self
32-
.put(url)
33-
.body(agg_job_init_req.get_encoded_with_param(&version).unwrap())
34-
.headers(construct_request_headers(
35-
DapMediaType::AggregationJobInitReq
36-
.as_str_for_version(version)
37-
.with_context(|| {
38-
format!("AggregationJobInitReq media type is not defined for {version}")
39-
})?,
40-
version,
41-
opts,
42-
)?)
43-
.send()
44-
.await
45-
.context("sending AggregationJobInitReq")?;
46-
handle_response(resp, &version).await
32+
retry(&version, || async {
33+
self.put(url.clone())
34+
.body(agg_job_init_req.get_encoded_with_param(&version).unwrap())
35+
.headers(construct_request_headers(
36+
DapMediaType::AggregationJobInitReq
37+
.as_str_for_version(version)
38+
.with_context(|| {
39+
format!("AggregationJobInitReq media type is not defined for {version}")
40+
})?,
41+
version,
42+
opts,
43+
)?)
44+
.send()
45+
.await
46+
.context("sending AggregationJobInitReq")
47+
})
48+
.await
4749
}
4850

4951
pub async fn poll_aggregation_job_init(
@@ -52,21 +54,22 @@ impl HttpClient {
5254
version: DapVersion,
5355
opts: Options<'_>,
5456
) -> anyhow::Result<AggregationJobResp> {
55-
let resp = self
56-
.get(url)
57-
.headers(construct_request_headers(
58-
DapMediaType::AggregationJobInitReq
59-
.as_str_for_version(version)
60-
.with_context(|| {
61-
format!("AggregationJobInitReq media type is not defined for {version}")
62-
})?,
63-
version,
64-
opts,
65-
)?)
66-
.send()
67-
.await
68-
.context("polling aggregation job init req")?;
69-
handle_response(resp, &version).await
57+
retry(&version, || async {
58+
self.get(url.clone())
59+
.headers(construct_request_headers(
60+
DapMediaType::AggregationJobInitReq
61+
.as_str_for_version(version)
62+
.with_context(|| {
63+
format!("AggregationJobInitReq media type is not defined for {version}")
64+
})?,
65+
version,
66+
opts,
67+
)?)
68+
.send()
69+
.await
70+
.context("polling aggregation job init req")
71+
})
72+
.await
7073
}
7174

7275
pub async fn get_aggregate_share(
@@ -75,42 +78,28 @@ impl HttpClient {
7578
agg_share_req: AggregateShareReq,
7679
version: DapVersion,
7780
opts: Options<'_>,
78-
) -> anyhow::Result<()> {
79-
let resp = self
80-
.post(url)
81-
.body(agg_share_req.get_encoded_with_param(&version).unwrap())
82-
.headers(construct_request_headers(
83-
DapMediaType::AggregateShareReq
84-
.as_str_for_version(version)
85-
.with_context(|| {
86-
format!("AggregateShareReq media type is not defined for {version}")
87-
})?,
88-
version,
89-
opts,
90-
)?)
91-
.send()
92-
.await
93-
.context("sending AggregateShareReq")?;
94-
if resp.status() == 400 {
95-
let problem_details: ProblemDetails = serde_json::from_slice(
96-
&resp
97-
.bytes()
98-
.await
99-
.context("transfering bytes for AggregateShareReq")?,
100-
)
101-
.with_context(|| "400 Bad Request: failed to parse problem details document")?;
102-
Err(anyhow!("400 Bad Request: {problem_details:?}"))
103-
} else if resp.status() == 500 {
104-
Err(anyhow!("500 Internal Server Error: {}", resp.text().await?))
105-
} else if !resp.status().is_success() {
106-
Err(response_to_anyhow(resp).await).context("while running an AggregateShareReq")
107-
} else {
108-
Ok(())
109-
}
81+
) -> anyhow::Result<AggregateShare> {
82+
retry(&(), || async {
83+
self.post(url.clone())
84+
.body(agg_share_req.get_encoded_with_param(&version).unwrap())
85+
.headers(construct_request_headers(
86+
DapMediaType::AggregateShareReq
87+
.as_str_for_version(version)
88+
.with_context(|| {
89+
format!("AggregateShareReq media type is not defined for {version}")
90+
})?,
91+
version,
92+
opts,
93+
)?)
94+
.send()
95+
.await
96+
.context("sending AggregateShareReq")
97+
})
98+
.await
11099
}
111100
}
112101

113-
#[derive(Default, Debug)]
102+
#[derive(Default, Debug, Clone, Copy)]
114103
pub struct Options<'s> {
115104
pub taskprov_advertisement: Option<&'s TaskprovAdvertisement>,
116105
pub bearer_token: Option<&'s BearerToken>,
@@ -145,10 +134,35 @@ fn construct_request_headers(
145134
Ok(headers)
146135
}
147136

148-
async fn handle_response<R, P>(resp: reqwest::Response, params: &P) -> anyhow::Result<R>
137+
async fn retry<F, Fut, R, P>(params: &P, mut f: F) -> anyhow::Result<R>
138+
where
139+
F: FnMut() -> Fut,
140+
Fut: std::future::Future<Output = anyhow::Result<reqwest::Response>>,
141+
R: ParameterizedDecode<P>,
142+
{
143+
const RETRY_COUNT: usize = 5;
144+
for i in 1..=RETRY_COUNT {
145+
let resp = f().await?;
146+
match handle_response(resp, params).await? {
147+
ControlFlow::Continue(()) if i == RETRY_COUNT => bail!("service unavailable"),
148+
ControlFlow::Continue(()) => {
149+
tracing::info!("retrying....");
150+
}
151+
ControlFlow::Break(r) => return Ok(r),
152+
}
153+
}
154+
unreachable!()
155+
}
156+
157+
async fn handle_response<R, P>(
158+
resp: reqwest::Response,
159+
params: &P,
160+
) -> anyhow::Result<ControlFlow<R>>
149161
where
150162
R: ParameterizedDecode<P>,
151163
{
164+
let output_type = std::any::type_name::<R>();
165+
152166
if resp.status() == 400 {
153167
let text = resp.text().await?;
154168
let problem_details: ProblemDetails = serde_json::from_str(&text).with_context(|| {
@@ -160,16 +174,19 @@ where
160174
"500 Internal Server Error: {}",
161175
resp.text().await?
162176
))
177+
} else if resp.status() == 503 {
178+
return Ok(ControlFlow::Continue(()));
163179
} else if !resp.status().is_success() {
164-
Err(response_to_anyhow(resp).await).context("while running an AggregationJobInitReq")
180+
Err(response_to_anyhow(resp).await)
165181
} else {
166-
R::get_decoded_with_param(
167-
params,
168-
&resp
169-
.bytes()
170-
.await
171-
.context("transfering bytes from the AggregateInitReq")?,
172-
)
173-
.with_context(|| "failed to parse response to AggregateInitReq from Helper")
182+
let bytes = resp
183+
.bytes()
184+
.await
185+
.with_context(|| format!("transfering bytes from the {output_type}"))?;
186+
187+
R::get_decoded_with_param(params, &bytes)
188+
.with_context(|| format!("failed to parse response to {output_type} from Helper"))
189+
.with_context(|| format!("faulty bytes: {bytes:?}"))
190+
.map(ControlFlow::Break)
174191
}
175192
}

crates/daphne-worker-test/wrangler.aggregator.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ fallthrough = false
2525
name = "daphne-helper-aggregator"
2626

2727
[env.helper.vars]
28+
DAP_TRACING="debug"
2829
DAP_DEPLOYMENT = "dev"
2930
DAP_WORKER_MODE = "aggregator"
3031
DAP_DURABLE_HELPER_STATE_STORE_GC_AFTER_SECS = "30"

0 commit comments

Comments
 (0)