Skip to content

Commit 5ba7b8d

Browse files
authored
Implement Retry Classifier for Adaptive Rate Limiting #2432
1 parent 6e18c15 commit 5ba7b8d

File tree

2 files changed

+361
-2
lines changed

2 files changed

+361
-2
lines changed

crates/chat-cli/src/api_client/mod.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ mod error;
55
pub mod model;
66
mod opt_out;
77
pub mod profile;
8+
mod retry_classifier;
89
pub mod send_message_output;
9-
1010
use std::sync::Arc;
1111
use std::time::Duration;
1212

@@ -146,6 +146,7 @@ impl ApiClient {
146146
.interceptor(UserAgentOverrideInterceptor::new())
147147
.app_name(app_name())
148148
.endpoint_url(endpoint.url())
149+
.retry_classifier(retry_classifier::QCliRetryClassifier::new())
149150
.stalled_stream_protection(stalled_stream_protection_config())
150151
.build(),
151152
));
@@ -159,6 +160,7 @@ impl ApiClient {
159160
.bearer_token_resolver(BearerResolver)
160161
.app_name(app_name())
161162
.endpoint_url(endpoint.url())
163+
.retry_classifier(retry_classifier::QCliRetryClassifier::new())
162164
.stalled_stream_protection(stalled_stream_protection_config())
163165
.build(),
164166
));
@@ -496,7 +498,9 @@ fn timeout_config(database: &Database) -> TimeoutConfig {
496498
}
497499

498500
fn retry_config() -> RetryConfig {
499-
RetryConfig::standard().with_max_attempts(1)
501+
RetryConfig::adaptive()
502+
.with_max_attempts(3)
503+
.with_max_backoff(Duration::from_secs(10))
500504
}
501505

502506
pub fn stalled_stream_protection_config() -> StalledStreamProtectionConfig {
Lines changed: 355 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,355 @@
1+
use std::fmt;
2+
3+
use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
4+
use aws_smithy_runtime_api::client::retries::classifiers::{
5+
ClassifyRetry,
6+
RetryAction,
7+
RetryClassifierPriority,
8+
};
9+
use tracing::debug;
10+
11+
/// Error marker for monthly limit exceeded errors
12+
const MONTHLY_LIMIT_ERROR_MARKER: &str = "MONTHLY_REQUEST_COUNT";
13+
14+
/// Error message for high load conditions that should be retried
15+
const HIGH_LOAD_ERROR_MESSAGE: &str =
16+
"Encountered unexpectedly high load when processing the request, please try again.";
17+
18+
/// Error message for insufficient model capacity that should be retried
19+
const INSUFFICIENT_MODEL_CAPACITY_MESSAGE: &str = "I am experiencing high traffic, please try again shortly.";
20+
21+
/// Status codes that indicate service overload/unavailability and should be retried
22+
const SERVICE_OVERLOAD_STATUS_CODES: &[u16] = &[
23+
429, // Too Many Requests - throttling with insufficient model capacity
24+
500, // Internal Server Error - requires specific message check for high load conditions
25+
503, // Service Unavailable - server is temporarily overloaded or under maintenance
26+
];
27+
28+
#[derive(Debug, Default)]
29+
pub struct QCliRetryClassifier;
30+
31+
impl QCliRetryClassifier {
32+
pub fn new() -> Self {
33+
Self
34+
}
35+
36+
/// Return the priority of this retry classifier.
37+
///
38+
/// We want this to run after the standard classifiers but with high priority
39+
/// to override their decisions for our specific error cases.
40+
///
41+
/// # Returns
42+
/// A priority that runs after the transient error classifier but can override its decisions.
43+
pub fn priority() -> RetryClassifierPriority {
44+
RetryClassifierPriority::run_after(RetryClassifierPriority::transient_error_classifier())
45+
}
46+
47+
/// Check if the error indicates a monthly limit has been reached
48+
fn is_monthly_limit_error(ctx: &InterceptorContext) -> bool {
49+
let Some(resp) = ctx.response() else {
50+
return false;
51+
};
52+
53+
// Check status code first - monthly limit errors typically return 429 (Too Many Requests)
54+
let status_code = resp.status().as_u16();
55+
if status_code != 429 {
56+
return false;
57+
}
58+
59+
let Some(bytes) = resp.body().bytes() else {
60+
return false;
61+
};
62+
63+
let is_monthly_limit = match std::str::from_utf8(bytes) {
64+
Ok(body_str) => body_str.contains(MONTHLY_LIMIT_ERROR_MARKER),
65+
Err(_) => false,
66+
};
67+
68+
debug!(
69+
"QCliRetryClassifier: Monthly limit error detected: {}",
70+
is_monthly_limit
71+
);
72+
is_monthly_limit
73+
}
74+
75+
/// Check if the error indicates a model is unavailable due to high load
76+
fn is_service_overloaded_error(ctx: &InterceptorContext) -> bool {
77+
let Some(resp) = ctx.response() else {
78+
return false;
79+
};
80+
81+
let status_code = resp.status().as_u16();
82+
83+
// Fail fast: if status code is not in our list, return false immediately
84+
if !SERVICE_OVERLOAD_STATUS_CODES.contains(&status_code) {
85+
return false;
86+
}
87+
88+
let is_overloaded = match status_code {
89+
429 => {
90+
// For 429 errors, check if the response body contains the insufficient model capacity message
91+
let Some(bytes) = resp.body().bytes() else {
92+
return false;
93+
};
94+
95+
match std::str::from_utf8(bytes) {
96+
Ok(body_str) => body_str.contains(INSUFFICIENT_MODEL_CAPACITY_MESSAGE),
97+
Err(_) => false,
98+
}
99+
},
100+
500 => {
101+
// For 500 errors, check if the response body contains the specific high load message
102+
let Some(bytes) = resp.body().bytes() else {
103+
return false;
104+
};
105+
106+
match std::str::from_utf8(bytes) {
107+
Ok(body_str) => body_str.contains(HIGH_LOAD_ERROR_MESSAGE),
108+
Err(_) => false,
109+
}
110+
},
111+
503 => {
112+
// For 503 Service Unavailable, always retry (no additional checks needed)
113+
true
114+
},
115+
_ => {
116+
// This shouldn't happen given our fail-fast check above, but handle gracefully
117+
false
118+
},
119+
};
120+
121+
debug!(
122+
"QCliRetryClassifier: Service overloaded error detected (status {}): {}",
123+
status_code, is_overloaded
124+
);
125+
is_overloaded
126+
}
127+
}
128+
129+
impl ClassifyRetry for QCliRetryClassifier {
130+
fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
131+
// Check for monthly limit error first - this should never be retried
132+
if Self::is_monthly_limit_error(ctx) {
133+
return RetryAction::RetryForbidden;
134+
}
135+
136+
// Check for service overloaded error - this should be treated as throttling
137+
if Self::is_service_overloaded_error(ctx) {
138+
return RetryAction::throttling_error();
139+
}
140+
141+
// No specific action for other errors
142+
RetryAction::NoActionIndicated
143+
}
144+
145+
fn name(&self) -> &'static str {
146+
"Q CLI Custom Retry Classifier"
147+
}
148+
149+
fn priority(&self) -> RetryClassifierPriority {
150+
Self::priority()
151+
}
152+
}
153+
154+
impl fmt::Display for QCliRetryClassifier {
155+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
156+
write!(f, "QCliRetryClassifier")
157+
}
158+
}
159+
160+
#[cfg(test)]
161+
mod tests {
162+
use aws_smithy_runtime_api::client::interceptors::context::{
163+
Input,
164+
InterceptorContext,
165+
};
166+
use aws_smithy_types::body::SdkBody;
167+
use http::Response;
168+
169+
use super::*;
170+
171+
#[test]
172+
fn test_monthly_limit_error_classification() {
173+
let classifier = QCliRetryClassifier::new();
174+
let mut ctx = InterceptorContext::new(Input::doesnt_matter());
175+
176+
// Create a response with MONTHLY_REQUEST_COUNT in the body
177+
let response_body = r#"{"error": "MONTHLY_REQUEST_COUNT exceeded"}"#;
178+
let response = Response::builder()
179+
.status(429)
180+
.body(response_body)
181+
.unwrap()
182+
.map(SdkBody::from);
183+
184+
ctx.set_response(response.try_into().unwrap());
185+
186+
let result = classifier.classify_retry(&ctx);
187+
assert_eq!(result, RetryAction::RetryForbidden);
188+
}
189+
190+
#[test]
191+
fn test_insufficient_model_capacity_error_classification() {
192+
let classifier = QCliRetryClassifier::new();
193+
let mut ctx = InterceptorContext::new(Input::doesnt_matter());
194+
195+
// Create a 429 response with the insufficient model capacity message - should be treated as service
196+
// overloaded
197+
let response_body = r#"{"error": "I am experiencing high traffic, please try again shortly."}"#;
198+
let response = Response::builder()
199+
.status(429)
200+
.body(response_body)
201+
.unwrap()
202+
.map(SdkBody::from);
203+
204+
ctx.set_response(response.try_into().unwrap());
205+
206+
let result = classifier.classify_retry(&ctx);
207+
assert_eq!(result, RetryAction::throttling_error());
208+
}
209+
210+
#[test]
211+
fn test_429_error_without_insufficient_capacity_message_not_retried() {
212+
let classifier = QCliRetryClassifier::new();
213+
let mut ctx = InterceptorContext::new(Input::doesnt_matter());
214+
215+
// Create a 429 response without the specific insufficient model capacity message - should NOT be
216+
// retried
217+
let response_body = "Too Many Requests - some other error";
218+
let response = Response::builder()
219+
.status(429)
220+
.body(response_body)
221+
.unwrap()
222+
.map(SdkBody::from);
223+
224+
ctx.set_response(response.try_into().unwrap());
225+
226+
let result = classifier.classify_retry(&ctx);
227+
assert_eq!(result, RetryAction::NoActionIndicated);
228+
}
229+
230+
#[test]
231+
fn test_service_overloaded_error_classification() {
232+
let classifier = QCliRetryClassifier::new();
233+
let mut ctx = InterceptorContext::new(Input::doesnt_matter());
234+
235+
// Create a 500 response with the specific high load message - should be treated as service
236+
// overloaded
237+
let response_body =
238+
r#"{"error": "Encountered unexpectedly high load when processing the request, please try again."}"#;
239+
let response = Response::builder()
240+
.status(500)
241+
.body(response_body)
242+
.unwrap()
243+
.map(SdkBody::from);
244+
245+
ctx.set_response(response.try_into().unwrap());
246+
247+
let result = classifier.classify_retry(&ctx);
248+
assert_eq!(result, RetryAction::throttling_error());
249+
}
250+
251+
#[test]
252+
fn test_500_error_without_high_load_message_not_retried() {
253+
let classifier = QCliRetryClassifier::new();
254+
let mut ctx = InterceptorContext::new(Input::doesnt_matter());
255+
256+
// Create a 500 response without the specific high load message - should NOT be retried
257+
let response_body = "Internal Server Error - some other error";
258+
let response = Response::builder()
259+
.status(500)
260+
.body(response_body)
261+
.unwrap()
262+
.map(SdkBody::from);
263+
264+
ctx.set_response(response.try_into().unwrap());
265+
266+
let result = classifier.classify_retry(&ctx);
267+
assert_eq!(result, RetryAction::NoActionIndicated);
268+
}
269+
270+
#[test]
271+
fn test_service_unavailable_error_classification() {
272+
let classifier = QCliRetryClassifier::new();
273+
let mut ctx = InterceptorContext::new(Input::doesnt_matter());
274+
275+
// Create a 503 response - should be treated as service overloaded
276+
let response_body = "Service Unavailable";
277+
let response = Response::builder()
278+
.status(503)
279+
.body(response_body)
280+
.unwrap()
281+
.map(SdkBody::from);
282+
283+
ctx.set_response(response.try_into().unwrap());
284+
285+
let result = classifier.classify_retry(&ctx);
286+
assert_eq!(result, RetryAction::throttling_error());
287+
}
288+
289+
#[test]
290+
fn test_no_action_for_non_overload_errors() {
291+
let classifier = QCliRetryClassifier::new();
292+
let mut ctx = InterceptorContext::new(Input::doesnt_matter());
293+
294+
// Create a 400 response - should not be treated as service overloaded
295+
let response = Response::builder()
296+
.status(400)
297+
.body("Bad Request")
298+
.unwrap()
299+
.map(SdkBody::from);
300+
301+
ctx.set_response(response.try_into().unwrap());
302+
303+
let result = classifier.classify_retry(&ctx);
304+
assert_eq!(result, RetryAction::NoActionIndicated);
305+
}
306+
307+
#[test]
308+
fn test_fail_fast_for_non_service_overload_status_codes() {
309+
let classifier = QCliRetryClassifier::new();
310+
let mut ctx = InterceptorContext::new(Input::doesnt_matter());
311+
312+
// Test various status codes that are not in SERVICE_OVERLOAD_STATUS_CODES
313+
let test_cases = vec![
314+
(200, "OK"),
315+
(400, "Bad Request"),
316+
(401, "Unauthorized"),
317+
(403, "Forbidden"),
318+
(404, "Not Found"),
319+
(502, "Bad Gateway"),
320+
];
321+
322+
for (status_code, body) in test_cases {
323+
let response = Response::builder()
324+
.status(status_code)
325+
.body(body)
326+
.unwrap()
327+
.map(SdkBody::from);
328+
329+
ctx.set_response(response.try_into().unwrap());
330+
331+
let result = classifier.classify_retry(&ctx);
332+
assert_eq!(
333+
result,
334+
RetryAction::NoActionIndicated,
335+
"Status code {} should return NoActionIndicated",
336+
status_code
337+
);
338+
}
339+
}
340+
341+
#[test]
342+
fn test_classifier_priority() {
343+
let priority = QCliRetryClassifier::priority();
344+
let transient_priority = RetryClassifierPriority::transient_error_classifier();
345+
346+
// Our classifier should have higher priority than the transient error classifier
347+
assert!(priority > transient_priority);
348+
}
349+
350+
#[test]
351+
fn test_classifier_name() {
352+
let classifier = QCliRetryClassifier::new();
353+
assert_eq!(classifier.name(), "Q CLI Custom Retry Classifier");
354+
}
355+
}

0 commit comments

Comments
 (0)