Skip to content

Commit 1e20bb0

Browse files
committed
fix(vertexai-retry): retry properly for Vertex AI
1 parent a0a2abe commit 1e20bb0

File tree

5 files changed

+109
-57
lines changed

5 files changed

+109
-57
lines changed

Cargo.lock

Lines changed: 48 additions & 37 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -133,9 +133,10 @@ time = { version = "0.3", features = ["macros", "serde"] }
133133
numpy = "0.25.0"
134134
infer = "0.19.0"
135135
serde_with = { version = "3.14.0", features = ["base64"] }
136-
google-cloud-aiplatform-v1 = { version = "0.4.4", default-features = false, features = [
136+
google-cloud-aiplatform-v1 = { version = "1.0.0", default-features = false, features = [
137137
"prediction-service",
138138
] }
139+
google-cloud-gax = "1.0.0"
139140

140141
azure_identity = { version = "0.21.0", default-features = false, features = [
141142
"enable_reqwest_rustls",

src/execution/live_updater.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -144,7 +144,7 @@ impl SourceUpdateTask {
144144
async move {
145145
let mut change_stream = change_stream;
146146
let retry_options = retryable::RetryOptions {
147-
max_retries: None,
147+
retry_timeout: std::time::Duration::from_secs(365 * 24 * 60 * 60),
148148
initial_backoff: std::time::Duration::from_secs(5),
149149
max_backoff: std::time::Duration::from_secs(60),
150150
};

src/llm/gemini.rs

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,9 @@ use crate::llm::{
66
};
77
use base64::prelude::*;
88
use google_cloud_aiplatform_v1 as vertexai;
9+
use google_cloud_gax::exponential_backoff::ExponentialBackoff;
10+
use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicyExt};
11+
use google_cloud_gax::retry_throttler::{AdaptiveThrottler, SharedRetryThrottler};
912
use serde_json::Value;
1013
use urlencoding::encode;
1114

@@ -237,6 +240,36 @@ pub struct VertexAiClient {
237240
config: super::VertexAiConfig,
238241
}
239242

243+
/// Unit struct (no fields) that serves as the carrier for a custom
/// `google_cloud_gax::retry_policy::RetryPolicy` implementation in this file.
#[derive(Debug)]
244+
struct CustomizedGoogleCloudRetryPolicy;
245+
246+
/// Retry policy that treats rate-limit style failures as retryable and defers
/// every other decision to `Aip194Strict`.
impl google_cloud_gax::retry_policy::RetryPolicy for CustomizedGoogleCloudRetryPolicy {
247+
/// Decides whether a failed request should be retried.
///
/// Decision order:
/// 1. `state.idempotent` is false -> `RetryResult::Permanent`.
/// 2. The error carries a status whose code is `ResourceExhausted`
///    -> `RetryResult::Continue`.
/// 3. The error carries no status but its HTTP status code equals
///    `TOO_MANY_REQUESTS` -> `RetryResult::Continue`.
/// 4. Anything else -> whatever `Aip194Strict.on_error` returns.
fn on_error(
248+
&self,
249+
state: &google_cloud_gax::retry_state::RetryState,
250+
error: google_cloud_gax::error::Error,
251+
) -> google_cloud_gax::retry_result::RetryResult {
252+
use google_cloud_gax::retry_result::RetryResult;
253+
254+
// Non-idempotent requests are never retried by this policy.
if !state.idempotent {
255+
return RetryResult::Permanent(error);
256+
}
257+
// The HTTP status code is consulted only when `error.status()` is `None`
// (this is an `if let ... else if` chain, not two independent checks).
if let Some(status) = error.status() {
258+
if status.code == google_cloud_gax::error::rpc::Code::ResourceExhausted {
259+
return RetryResult::Continue(error);
260+
}
261+
} else if let Some(code) = error.http_status_code()
262+
&& code == reqwest::StatusCode::TOO_MANY_REQUESTS.as_u16()
263+
{
264+
return RetryResult::Continue(error);
265+
}
266+
// Not a rate-limit case handled above: fall back to the stock policy.
Aip194Strict.on_error(state, error)
267+
}
268+
}
269+
270+
/// Static retry throttler, constructed lazily on first access via `LazyLock`.
/// The value is an `Arc<Mutex<AdaptiveThrottler>>`, so `.clone()` yields another
/// handle to the same throttler rather than a copy.
///
/// NOTE(review): the meaning of the `2.0` argument is defined by
/// `AdaptiveThrottler::new` and is not visible here — confirm against the gax docs.
/// `unwrap()` panics on first access if that constructor returns an error.
static SHARED_RETRY_THROTTLER: LazyLock<SharedRetryThrottler> =
271+
LazyLock::new(|| Arc::new(Mutex::new(AdaptiveThrottler::new(2.0).unwrap())));
272+
240273
impl VertexAiClient {
241274
pub async fn new(
242275
address: Option<String>,
@@ -249,6 +282,11 @@ impl VertexAiClient {
249282
api_bail!("VertexAi API config is required for VertexAi API type");
250283
};
251284
let client = vertexai::client::PredictionService::builder()
285+
.with_retry_policy(
286+
CustomizedGoogleCloudRetryPolicy.with_time_limit(retryable::DEFAULT_RETRY_TIMEOUT),
287+
)
288+
.with_backoff_policy(ExponentialBackoff::default())
289+
.with_retry_throttler(SHARED_RETRY_THROTTLER.clone())
252290
.build()
253291
.await?;
254292
Ok(Self { client, config })

0 commit comments

Comments
 (0)