Skip to content

Commit 1cb013d

Browse files
authored
feat(auto-retry): preliminary support for auto retry on 429 (#721)
* feat(auto-retry): preliminary support for auto retry on 429 * style: simplification of async closure * fix: bring back status check, default options
1 parent 640fc21 commit 1cb013d

File tree

5 files changed

+65
-49
lines changed

5 files changed

+65
-49
lines changed

src/llm/anthropic.rs

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
1+
use crate::prelude::*;
2+
use base64::prelude::*;
3+
14
use crate::llm::{
25
LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, OutputFormat,
36
ToJsonSchemaOptions, detect_image_mime_type,
47
};
5-
use anyhow::{Context, Result, bail};
6-
use async_trait::async_trait;
7-
use base64::prelude::*;
8+
use anyhow::Context;
89
use json5;
9-
use serde_json::Value;
10-
11-
use crate::api_bail;
1210
use urlencoding::encode;
1311

1412
pub struct Client {
@@ -91,23 +89,26 @@ impl LlmGenerationClient for Client {
9189

9290
let encoded_api_key = encode(&self.api_key);
9391

94-
let resp = self
95-
.client
96-
.post(url)
97-
.header("x-api-key", encoded_api_key.as_ref())
98-
.header("anthropic-version", "2023-06-01")
99-
.json(&payload)
100-
.send()
101-
.await
102-
.context("HTTP error")?;
92+
let resp = retryable::run(
93+
|| {
94+
self.client
95+
.post(url)
96+
.header("x-api-key", encoded_api_key.as_ref())
97+
.header("anthropic-version", "2023-06-01")
98+
.json(&payload)
99+
.send()
100+
},
101+
&retryable::HEAVY_LOADED_OPTIONS,
102+
)
103+
.await?;
103104
if !resp.status().is_success() {
104105
bail!(
105106
"Anthropic API error: {:?}\n{}\n",
106107
resp.status(),
107108
resp.text().await?
108109
);
109110
}
110-
let mut resp_json: Value = resp.json().await.context("Invalid JSON")?;
111+
let mut resp_json: serde_json::Value = resp.json().await.context("Invalid JSON")?;
111112
if let Some(error) = resp_json.get("error") {
112113
bail!("Anthropic API error: {:?}", error);
113114
}
@@ -117,11 +118,11 @@ impl LlmGenerationClient for Client {
117118

118119
let resp_content = &resp_json["content"];
119120
let tool_name = "report_result";
120-
let mut extracted_json: Option<Value> = None;
121+
let mut extracted_json: Option<serde_json::Value> = None;
121122
if let Some(array) = resp_content.as_array() {
122123
for item in array {
123-
if item.get("type") == Some(&Value::String("tool_use".to_string()))
124-
&& item.get("name") == Some(&Value::String(tool_name.to_string()))
124+
if item.get("type") == Some(&serde_json::Value::String("tool_use".to_string()))
125+
&& item.get("name") == Some(&serde_json::Value::String(tool_name.to_string()))
125126
{
126127
if let Some(input) = item.get("input") {
127128
extracted_json = Some(input.clone());
@@ -136,7 +137,7 @@ impl LlmGenerationClient for Client {
136137
} else {
137138
// Fallback: try text if no tool output found
138139
match &mut resp_json["content"][0]["text"] {
139-
Value::String(s) => {
140+
serde_json::Value::String(s) => {
140141
// Try strict JSON parsing first
141142
match serde_json::from_str::<serde_json::Value>(s) {
142143
Ok(_) => std::mem::take(s),

src/llm/gemini.rs

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,11 @@ impl LlmGenerationClient for Client {
113113
}
114114

115115
let url = self.get_api_url(request.model, "generateContent");
116-
let resp = self
117-
.client
118-
.post(&url)
119-
.json(&payload)
120-
.send()
121-
.await
122-
.context("HTTP error")?;
116+
let resp = retryable::run(
117+
|| self.client.post(&url).json(&payload).send(),
118+
&retryable::HEAVY_LOADED_OPTIONS,
119+
)
120+
.await?;
123121
if !resp.status().is_success() {
124122
bail!(
125123
"Gemini API error: {:?}\n{}\n",
@@ -174,13 +172,11 @@ impl LlmEmbeddingClient for Client {
174172
if let Some(task_type) = request.task_type {
175173
payload["taskType"] = serde_json::Value::String(task_type.into());
176174
}
177-
let resp = self
178-
.client
179-
.post(&url)
180-
.json(&payload)
181-
.send()
182-
.await
183-
.context("HTTP error")?;
175+
let resp = retryable::run(
176+
|| self.client.post(&url).json(&payload).send(),
177+
&retryable::HEAVY_LOADED_OPTIONS,
178+
)
179+
.await?;
184180
if !resp.status().is_success() {
185181
bail!(
186182
"Gemini API error: {:?}\n{}\n",

src/llm/ollama.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,16 @@ impl LlmGenerationClient for Client {
6565
system: request.system_prompt.as_ref().map(|s| s.as_ref()),
6666
stream: Some(false),
6767
};
68-
let res = self
69-
.reqwest_client
70-
.post(self.generate_url.as_str())
71-
.json(&req)
72-
.send()
73-
.await?;
68+
let res = retryable::run(
69+
|| {
70+
self.reqwest_client
71+
.post(self.generate_url.as_str())
72+
.json(&req)
73+
.send()
74+
},
75+
&retryable::HEAVY_LOADED_OPTIONS,
76+
)
77+
.await?;
7478
if !res.status().is_success() {
7579
bail!(
7680
"Ollama API error: {:?}\n{}\n",

src/llm/voyage.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,17 @@ impl LlmEmbeddingClient for Client {
7575
payload["input_type"] = serde_json::Value::String(task_type.into());
7676
}
7777

78-
let resp = self
79-
.client
80-
.post(url)
81-
.header("Authorization", format!("Bearer {}", self.api_key))
82-
.json(&payload)
83-
.send()
84-
.await
85-
.context("HTTP error")?;
78+
let resp = retryable::run(
79+
|| {
80+
self.client
81+
.post(url)
82+
.header("Authorization", format!("Bearer {}", self.api_key))
83+
.json(&payload)
84+
.send()
85+
},
86+
&retryable::HEAVY_LOADED_OPTIONS,
87+
)
88+
.await?;
8689

8790
if !resp.status().is_success() {
8891
bail!(

src/utils/retryable.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ impl IsRetryable for Error {
2828
}
2929
}
3030

31+
impl IsRetryable for reqwest::Error {
32+
fn is_retryable(&self) -> bool {
33+
self.status() == Some(reqwest::StatusCode::TOO_MANY_REQUESTS)
34+
}
35+
}
36+
3137
impl Error {
3238
pub fn always_retryable(error: anyhow::Error) -> Self {
3339
Self {
@@ -77,13 +83,19 @@ pub struct RetryOptions {
7783
impl Default for RetryOptions {
7884
fn default() -> Self {
7985
Self {
80-
max_retries: Some(5),
86+
max_retries: Some(10),
8187
initial_backoff: Duration::from_millis(100),
8288
max_backoff: Duration::from_secs(10),
8389
}
8490
}
8591
}
8692

93+
pub static HEAVY_LOADED_OPTIONS: RetryOptions = RetryOptions {
94+
max_retries: Some(10),
95+
initial_backoff: Duration::from_secs(1),
96+
max_backoff: Duration::from_secs(60),
97+
};
98+
8799
pub async fn run<
88100
Ok,
89101
Err: std::fmt::Display + IsRetryable,

0 commit comments

Comments
 (0)