11use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
22use std:: sync:: Arc ;
33
4+ use bytes:: Bytes ;
45use reqwest:: { Response , StatusCode } ;
56use reqwest_retry:: { default_on_request_failure, default_on_request_success, Retryable } ;
7+ use tokio:: sync:: Mutex ;
68use tokio_retry:: strategy:: { jitter, ExponentialBackoff } ;
79use tokio_retry:: RetryIf ;
810use tracing:: { error, info} ;
911
12+ use crate :: adaptive_concurrency_control:: ConnectionPermit ;
1013use crate :: constants:: { CLIENT_RETRY_BASE_DELAY_MS , CLIENT_RETRY_MAX_ATTEMPTS } ;
1114use crate :: error:: CasClientError ;
1215use crate :: http_client:: request_id_from_response;
@@ -23,6 +26,7 @@ pub struct RetryWrapper {
2326 no_retry_on_429 : bool ,
2427 log_errors_as_info : bool ,
2528 api_tag : & ' static str ,
29+ connection_permit : Option < Mutex < Option < ConnectionPermit > > > ,
2630}
2731
2832impl RetryWrapper {
@@ -33,6 +37,7 @@ impl RetryWrapper {
3337 no_retry_on_429 : false ,
3438 log_errors_as_info : false ,
3539 api_tag,
40+ connection_permit : None ,
3641 }
3742 }
3843
@@ -56,6 +61,11 @@ impl RetryWrapper {
5661 self
5762 }
5863
64+ pub fn with_connection_permit ( mut self , permit : ConnectionPermit ) -> Self {
65+ self . connection_permit = Some ( Mutex :: new ( Some ( permit) ) ) ;
66+ self
67+ }
68+
5969 fn process_error_response ( & self , try_idx : usize , err : reqwest_middleware:: Error ) -> RetryableReqwestError {
6070 let api = & self . api_tag ;
6171
@@ -194,6 +204,12 @@ impl RetryWrapper {
194204 async move {
195205 let ( make_request, process_fn, try_count, self_) = retry_info. as_ref ( ) ;
196206
207+ if let Some ( p) = & self_. connection_permit {
208+ if let Some ( p) = p. lock ( ) . await . as_mut ( ) {
209+ p. transfer_starting ( )
210+ }
211+ }
212+
197213 let resp_result = make_request ( ) . await ;
198214 let try_idx = try_count. fetch_add ( 1 , Ordering :: Relaxed ) ;
199215
@@ -204,10 +220,35 @@ impl RetryWrapper {
204220 Ok ( resp) => self_. process_ok_response ( try_idx, resp) ,
205221 } ;
206222
207- match checked_result {
208- Ok ( ok_response) => process_fn ( ok_response) . await ,
209- Err ( e) => Err ( e) ,
223+ let ( n_bytes, processing_result) = match checked_result {
224+ Ok ( ok_response) => ( ok_response. content_length ( ) . unwrap_or ( 0 ) , process_fn ( ok_response) . await ) ,
225+ Err ( e) => ( 0 , Err ( e) ) ,
226+ } ;
227+
228+ // Now, possibly adjust the connection permit.
229+ if let Some ( permit_holder) = & self_. connection_permit {
230+ let mut maybe_permit = permit_holder. lock ( ) . await ;
231+
232+ match & processing_result {
233+ Ok ( _) => {
234+ if let Some ( permit) = maybe_permit. take ( ) {
235+ permit. report_completion ( n_bytes, true ) . await ;
236+ }
237+ } ,
238+ Err ( RetryableReqwestError :: FatalError ( _) ) => {
239+ if let Some ( permit) = maybe_permit. take ( ) {
240+ permit. report_completion ( 0 , false ) . await ;
241+ }
242+ } ,
243+ Err ( RetryableReqwestError :: RetryableError ( _) ) => {
244+ if let Some ( permit) = maybe_permit. as_ref ( ) {
245+ permit. report_retryable_failure ( ) . await ;
246+ }
247+ } ,
248+ }
210249 }
250+
251+ processing_result
211252 }
212253 } ,
213254 |err : & RetryableReqwestError | matches ! ( err, RetryableReqwestError :: RetryableError ( _) ) ,
@@ -278,6 +319,47 @@ impl RetryWrapper {
278319 . await
279320 }
280321
322+ /// Run a connection and process the result as bytes, retrying on transient errors or on issues not getting the
323+ /// full object.
324+ ///
325+ /// The `make_request` function returns a future that resolves to a Result<Response> object as is returned by the
326+ /// client middleware. For example, `|| client.clone().get(url).send()` returns a future (as `send()` is async)
327+ /// that will then be evaluatated to get the response.
328+ ///
329+ /// This functions acts just like the json() function on a client response, but retries the entire connection on
330+ /// transient errors.
331+ pub async fn run_and_extract_bytes < ReqFut , ReqFn > ( self , make_request : ReqFn ) -> Result < Bytes , CasClientError >
332+ where
333+ ReqFn : Fn ( ) -> ReqFut + Send + ' static ,
334+ ReqFut : std:: future:: Future < Output = Result < Response , reqwest_middleware:: Error > > + ' static ,
335+ {
336+ self . run_and_process ( make_request, |resp : Response | {
337+ async move {
338+ // Extract the bytes from the final result.
339+ let r: Result < Bytes , reqwest:: Error > = resp. bytes ( ) . await ;
340+
341+ match r {
342+ Ok ( v) => Ok ( v) ,
343+ Err ( e) => {
344+ #[ cfg( not( target_arch = "wasm32" ) ) ]
345+ let is_connect = e. is_connect ( ) ;
346+ #[ cfg( target_arch = "wasm32" ) ]
347+ let is_connect = false ;
348+
349+ if is_connect || e. is_decode ( ) || e. is_body ( ) || e. is_timeout ( ) {
350+ // We got an incomplete or corrupted response from the server, possibly due to a dropped
351+ // connection. Presumably this error is transient.
352+ Err ( RetryableReqwestError :: RetryableError ( e. into ( ) ) )
353+ } else {
354+ Err ( RetryableReqwestError :: FatalError ( e. into ( ) ) )
355+ }
356+ } ,
357+ }
358+ }
359+ } )
360+ . await
361+ }
362+
281363 /// Run a connection and process the result object, retrying on transient errors.
282364 ///
283365 /// The `make_request` function returns a future that resolves to a Result<Response> object as is returned by the
0 commit comments