11// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
22// SPDX-License-Identifier: BSD-3-Clause
33
4- use anyhow:: { anyhow, Context } ;
4+ use anyhow:: { anyhow, bail , Context } ;
55use daphne:: {
66 constants:: DapMediaType ,
77 error:: aborts:: ProblemDetails ,
88 messages:: {
9- taskprov:: TaskprovAdvertisement , AggregateShareReq , AggregationJobInitReq ,
9+ taskprov:: TaskprovAdvertisement , AggregateShare , AggregateShareReq , AggregationJobInitReq ,
1010 AggregationJobResp ,
1111 } ,
1212 DapVersion ,
@@ -19,6 +19,7 @@ use url::Url;
1919use crate :: HttpClient ;
2020
2121use super :: response_to_anyhow;
22+ use std:: ops:: ControlFlow ;
2223
2324impl HttpClient {
2425 pub async fn submit_aggregation_job_init_req (
@@ -28,22 +29,23 @@ impl HttpClient {
2829 version : DapVersion ,
2930 opts : Options < ' _ > ,
3031 ) -> anyhow:: Result < AggregationJobResp > {
31- let resp = self
32- . put ( url)
33- . body ( agg_job_init_req. get_encoded_with_param ( & version) . unwrap ( ) )
34- . headers ( construct_request_headers (
35- DapMediaType :: AggregationJobInitReq
36- . as_str_for_version ( version)
37- . with_context ( || {
38- format ! ( "AggregationJobInitReq media type is not defined for {version}" )
39- } ) ?,
40- version,
41- opts,
42- ) ?)
43- . send ( )
44- . await
45- . context ( "sending AggregationJobInitReq" ) ?;
46- handle_response ( resp, & version) . await
32+ retry ( & version, || async {
33+ self . put ( url. clone ( ) )
34+ . body ( agg_job_init_req. get_encoded_with_param ( & version) . unwrap ( ) )
35+ . headers ( construct_request_headers (
36+ DapMediaType :: AggregationJobInitReq
37+ . as_str_for_version ( version)
38+ . with_context ( || {
39+ format ! ( "AggregationJobInitReq media type is not defined for {version}" )
40+ } ) ?,
41+ version,
42+ opts,
43+ ) ?)
44+ . send ( )
45+ . await
46+ . context ( "sending AggregationJobInitReq" )
47+ } )
48+ . await
4749 }
4850
4951 pub async fn poll_aggregation_job_init (
@@ -52,21 +54,22 @@ impl HttpClient {
5254 version : DapVersion ,
5355 opts : Options < ' _ > ,
5456 ) -> anyhow:: Result < AggregationJobResp > {
55- let resp = self
56- . get ( url)
57- . headers ( construct_request_headers (
58- DapMediaType :: AggregationJobInitReq
59- . as_str_for_version ( version)
60- . with_context ( || {
61- format ! ( "AggregationJobInitReq media type is not defined for {version}" )
62- } ) ?,
63- version,
64- opts,
65- ) ?)
66- . send ( )
67- . await
68- . context ( "polling aggregation job init req" ) ?;
69- handle_response ( resp, & version) . await
57+ retry ( & version, || async {
58+ self . get ( url. clone ( ) )
59+ . headers ( construct_request_headers (
60+ DapMediaType :: AggregationJobInitReq
61+ . as_str_for_version ( version)
62+ . with_context ( || {
63+ format ! ( "AggregationJobInitReq media type is not defined for {version}" )
64+ } ) ?,
65+ version,
66+ opts,
67+ ) ?)
68+ . send ( )
69+ . await
70+ . context ( "polling aggregation job init req" )
71+ } )
72+ . await
7073 }
7174
7275 pub async fn get_aggregate_share (
@@ -75,42 +78,28 @@ impl HttpClient {
7578 agg_share_req : AggregateShareReq ,
7679 version : DapVersion ,
7780 opts : Options < ' _ > ,
78- ) -> anyhow:: Result < ( ) > {
79- let resp = self
80- . post ( url)
81- . body ( agg_share_req. get_encoded_with_param ( & version) . unwrap ( ) )
82- . headers ( construct_request_headers (
83- DapMediaType :: AggregateShareReq
84- . as_str_for_version ( version)
85- . with_context ( || {
86- format ! ( "AggregateShareReq media type is not defined for {version}" )
87- } ) ?,
88- version,
89- opts,
90- ) ?)
91- . send ( )
92- . await
93- . context ( "sending AggregateShareReq" ) ?;
94- if resp. status ( ) == 400 {
95- let problem_details: ProblemDetails = serde_json:: from_slice (
96- & resp
97- . bytes ( )
98- . await
99- . context ( "transfering bytes for AggregateShareReq" ) ?,
100- )
101- . with_context ( || "400 Bad Request: failed to parse problem details document" ) ?;
102- Err ( anyhow ! ( "400 Bad Request: {problem_details:?}" ) )
103- } else if resp. status ( ) == 500 {
104- Err ( anyhow ! ( "500 Internal Server Error: {}" , resp. text( ) . await ?) )
105- } else if !resp. status ( ) . is_success ( ) {
106- Err ( response_to_anyhow ( resp) . await ) . context ( "while running an AggregateShareReq" )
107- } else {
108- Ok ( ( ) )
109- }
81+ ) -> anyhow:: Result < AggregateShare > {
82+ retry ( & ( ) , || async {
83+ self . post ( url. clone ( ) )
84+ . body ( agg_share_req. get_encoded_with_param ( & version) . unwrap ( ) )
85+ . headers ( construct_request_headers (
86+ DapMediaType :: AggregateShareReq
87+ . as_str_for_version ( version)
88+ . with_context ( || {
89+ format ! ( "AggregateShareReq media type is not defined for {version}" )
90+ } ) ?,
91+ version,
92+ opts,
93+ ) ?)
94+ . send ( )
95+ . await
96+ . context ( "sending AggregateShareReq" )
97+ } )
98+ . await
11099 }
111100}
112101
113- #[ derive( Default , Debug ) ]
102+ #[ derive( Default , Debug , Clone , Copy ) ]
114103pub struct Options < ' s > {
115104 pub taskprov_advertisement : Option < & ' s TaskprovAdvertisement > ,
116105 pub bearer_token : Option < & ' s BearerToken > ,
@@ -145,10 +134,35 @@ fn construct_request_headers(
145134 Ok ( headers)
146135}
147136
148- async fn handle_response < R , P > ( resp : reqwest:: Response , params : & P ) -> anyhow:: Result < R >
137+ async fn retry < F , Fut , R , P > ( params : & P , mut f : F ) -> anyhow:: Result < R >
138+ where
139+ F : FnMut ( ) -> Fut ,
140+ Fut : std:: future:: Future < Output = anyhow:: Result < reqwest:: Response > > ,
141+ R : ParameterizedDecode < P > ,
142+ {
143+ const RETRY_COUNT : usize = 5 ;
144+ for i in 1 ..=RETRY_COUNT {
145+ let resp = f ( ) . await ?;
146+ match handle_response ( resp, params) . await ? {
147+ ControlFlow :: Continue ( ( ) ) if i == RETRY_COUNT => bail ! ( "service unavailable" ) ,
148+ ControlFlow :: Continue ( ( ) ) => {
149+ tracing:: info!( "retrying...." ) ;
150+ }
151+ ControlFlow :: Break ( r) => return Ok ( r) ,
152+ }
153+ }
154+ unreachable ! ( )
155+ }
156+
157+ async fn handle_response < R , P > (
158+ resp : reqwest:: Response ,
159+ params : & P ,
160+ ) -> anyhow:: Result < ControlFlow < R > >
149161where
150162 R : ParameterizedDecode < P > ,
151163{
164+ let output_type = std:: any:: type_name :: < R > ( ) ;
165+
152166 if resp. status ( ) == 400 {
153167 let text = resp. text ( ) . await ?;
154168 let problem_details: ProblemDetails = serde_json:: from_str ( & text) . with_context ( || {
@@ -160,16 +174,19 @@ where
160174 "500 Internal Server Error: {}" ,
161175 resp. text( ) . await ?
162176 ) )
177+ } else if resp. status ( ) == 503 {
178+ return Ok ( ControlFlow :: Continue ( ( ) ) ) ;
163179 } else if !resp. status ( ) . is_success ( ) {
164- Err ( response_to_anyhow ( resp) . await ) . context ( "while running an AggregationJobInitReq" )
180+ Err ( response_to_anyhow ( resp) . await )
165181 } else {
166- R :: get_decoded_with_param (
167- params,
168- & resp
169- . bytes ( )
170- . await
171- . context ( "transfering bytes from the AggregateInitReq" ) ?,
172- )
173- . with_context ( || "failed to parse response to AggregateInitReq from Helper" )
182+ let bytes = resp
183+ . bytes ( )
184+ . await
185+ . with_context ( || format ! ( "transfering bytes from the {output_type}" ) ) ?;
186+
187+ R :: get_decoded_with_param ( params, & bytes)
188+ . with_context ( || format ! ( "failed to parse response to {output_type} from Helper" ) )
189+ . with_context ( || format ! ( "faulty bytes: {bytes:?}" ) )
190+ . map ( ControlFlow :: Break )
174191 }
175192}
0 commit comments