@@ -6,6 +6,10 @@ use crate::llm::{
66} ;
77use base64:: prelude:: * ;
88use google_cloud_aiplatform_v1 as vertexai;
9+ use google_cloud_gax:: exponential_backoff:: ExponentialBackoff ;
10+ use google_cloud_gax:: options:: RequestOptionsBuilder ;
11+ use google_cloud_gax:: retry_policy:: { Aip194Strict , RetryPolicyExt } ;
12+ use google_cloud_gax:: retry_throttler:: { AdaptiveThrottler , SharedRetryThrottler } ;
913use serde_json:: Value ;
1014use urlencoding:: encode;
1115
@@ -237,6 +241,33 @@ pub struct VertexAiClient {
237241 config : super :: VertexAiConfig ,
238242}
239243
244+ #[ derive( Debug ) ]
245+ struct CustomizedGoogleCloudRetryPolicy ;
246+
247+ impl google_cloud_gax:: retry_policy:: RetryPolicy for CustomizedGoogleCloudRetryPolicy {
248+ fn on_error (
249+ & self ,
250+ state : & google_cloud_gax:: retry_state:: RetryState ,
251+ error : google_cloud_gax:: error:: Error ,
252+ ) -> google_cloud_gax:: retry_result:: RetryResult {
253+ use google_cloud_gax:: retry_result:: RetryResult ;
254+
255+ if let Some ( status) = error. status ( ) {
256+ if status. code == google_cloud_gax:: error:: rpc:: Code :: ResourceExhausted {
257+ return RetryResult :: Continue ( error) ;
258+ }
259+ } else if let Some ( code) = error. http_status_code ( )
260+ && code == reqwest:: StatusCode :: TOO_MANY_REQUESTS . as_u16 ( )
261+ {
262+ return RetryResult :: Continue ( error) ;
263+ }
264+ Aip194Strict . on_error ( state, error)
265+ }
266+ }
267+
268+ static SHARED_RETRY_THROTTLER : LazyLock < SharedRetryThrottler > =
269+ LazyLock :: new ( || Arc :: new ( Mutex :: new ( AdaptiveThrottler :: new ( 2.0 ) . unwrap ( ) ) ) ) ;
270+
240271impl VertexAiClient {
241272 pub async fn new (
242273 address : Option < String > ,
@@ -249,6 +280,11 @@ impl VertexAiClient {
249280 api_bail ! ( "VertexAi API config is required for VertexAi API type" ) ;
250281 } ;
251282 let client = vertexai:: client:: PredictionService :: builder ( )
283+ . with_retry_policy (
284+ CustomizedGoogleCloudRetryPolicy . with_time_limit ( retryable:: DEFAULT_RETRY_TIMEOUT ) ,
285+ )
286+ . with_backoff_policy ( ExponentialBackoff :: default ( ) )
287+ . with_retry_throttler ( SHARED_RETRY_THROTTLER . clone ( ) )
252288 . build ( )
253289 . await ?;
254290 Ok ( Self { client, config } )
@@ -312,7 +348,8 @@ impl LlmGenerationClient for VertexAiClient {
312348 . client
313349 . generate_content ( )
314350 . set_model ( self . get_model_path ( request. model ) )
315- . set_contents ( contents) ;
351+ . set_contents ( contents)
352+ . with_idempotency ( true ) ;
316353 if let Some ( sys) = system_instruction {
317354 req = req. set_system_instruction ( sys) ;
318355 }
@@ -376,6 +413,7 @@ impl LlmEmbeddingClient for VertexAiClient {
376413 . set_endpoint ( self . get_model_path ( request. model ) )
377414 . set_instances ( instances)
378415 . set_parameters ( parameters)
416+ . with_idempotency ( true )
379417 . send ( )
380418 . await ?;
381419
0 commit comments