@@ -6,6 +6,9 @@ use crate::llm::{
66} ;
77use base64:: prelude:: * ;
88use google_cloud_aiplatform_v1 as vertexai;
9+ use google_cloud_gax:: exponential_backoff:: ExponentialBackoff ;
10+ use google_cloud_gax:: retry_policy:: { Aip194Strict , RetryPolicyExt } ;
11+ use google_cloud_gax:: retry_throttler:: { AdaptiveThrottler , SharedRetryThrottler } ;
912use serde_json:: Value ;
1013use urlencoding:: encode;
1114
@@ -237,6 +240,36 @@ pub struct VertexAiClient {
237240 config : super :: VertexAiConfig ,
238241}
239242
243+ #[ derive( Debug ) ]
244+ struct CustomizedGoogleCloudRetryPolicy ;
245+
246+ impl google_cloud_gax:: retry_policy:: RetryPolicy for CustomizedGoogleCloudRetryPolicy {
247+ fn on_error (
248+ & self ,
249+ state : & google_cloud_gax:: retry_state:: RetryState ,
250+ error : google_cloud_gax:: error:: Error ,
251+ ) -> google_cloud_gax:: retry_result:: RetryResult {
252+ use google_cloud_gax:: retry_result:: RetryResult ;
253+
254+ if !state. idempotent {
255+ return RetryResult :: Permanent ( error) ;
256+ }
257+ if let Some ( status) = error. status ( ) {
258+ if status. code == google_cloud_gax:: error:: rpc:: Code :: ResourceExhausted {
259+ return RetryResult :: Continue ( error) ;
260+ }
261+ } else if let Some ( code) = error. http_status_code ( )
262+ && code == reqwest:: StatusCode :: TOO_MANY_REQUESTS . as_u16 ( )
263+ {
264+ return RetryResult :: Continue ( error) ;
265+ }
266+ Aip194Strict . on_error ( state, error)
267+ }
268+ }
269+
270+ static SHARED_RETRY_THROTTLER : LazyLock < SharedRetryThrottler > =
271+ LazyLock :: new ( || Arc :: new ( Mutex :: new ( AdaptiveThrottler :: new ( 2.0 ) . unwrap ( ) ) ) ) ;
272+
240273impl VertexAiClient {
241274 pub async fn new (
242275 address : Option < String > ,
@@ -249,6 +282,11 @@ impl VertexAiClient {
249282 api_bail ! ( "VertexAi API config is required for VertexAi API type" ) ;
250283 } ;
251284 let client = vertexai:: client:: PredictionService :: builder ( )
285+ . with_retry_policy (
286+ CustomizedGoogleCloudRetryPolicy . with_time_limit ( retryable:: DEFAULT_RETRY_TIMEOUT ) ,
287+ )
288+ . with_backoff_policy ( ExponentialBackoff :: default ( ) )
289+ . with_retry_throttler ( SHARED_RETRY_THROTTLER . clone ( ) )
252290 . build ( )
253291 . await ?;
254292 Ok ( Self { client, config } )
0 commit comments