|
| 1 | +use std::fmt; |
| 2 | + |
| 3 | +use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext; |
| 4 | +use aws_smithy_runtime_api::client::retries::classifiers::{ |
| 5 | + ClassifyRetry, |
| 6 | + RetryAction, |
| 7 | + RetryClassifierPriority, |
| 8 | +}; |
| 9 | +use tracing::debug; |
| 10 | + |
/// Substring searched for in a 429 response body to recognise an exhausted monthly request quota.
const MONTHLY_LIMIT_ERROR_MARKER: &str = "MONTHLY_REQUEST_COUNT";

/// Body text that marks a 500 response as a retryable high-load failure.
const HIGH_LOAD_ERROR_MESSAGE: &str =
    "Encountered unexpectedly high load when processing the request, please try again.";

/// Body text that marks a 429 response as a retryable "insufficient model capacity" failure.
const INSUFFICIENT_MODEL_CAPACITY_MESSAGE: &str =
    "I am experiencing high traffic, please try again shortly.";

/// HTTP status codes that are candidates for the service-overload classification.
///
/// * `429` (Too Many Requests) - only counts when the body carries the insufficient-capacity message.
/// * `500` (Internal Server Error) - only counts when the body carries the high-load message.
/// * `503` (Service Unavailable) - always counts.
const SERVICE_OVERLOAD_STATUS_CODES: &[u16] = &[429, 500, 503];
| 27 | + |
/// Stateless retry classifier for Q CLI requests.
///
/// It forbids retries once the monthly request limit is hit and reports service-overload
/// responses as throttling errors; every other response is left to the other classifiers.
#[derive(Default, Debug)]
pub struct QCliRetryClassifier;
| 30 | + |
| 31 | +impl QCliRetryClassifier { |
| 32 | + pub fn new() -> Self { |
| 33 | + Self |
| 34 | + } |
| 35 | + |
| 36 | + /// Return the priority of this retry classifier. |
| 37 | + /// |
| 38 | + /// We want this to run after the standard classifiers but with high priority |
| 39 | + /// to override their decisions for our specific error cases. |
| 40 | + /// |
| 41 | + /// # Returns |
| 42 | + /// A priority that runs after the transient error classifier but can override its decisions. |
| 43 | + pub fn priority() -> RetryClassifierPriority { |
| 44 | + RetryClassifierPriority::run_after(RetryClassifierPriority::transient_error_classifier()) |
| 45 | + } |
| 46 | + |
| 47 | + /// Check if the error indicates a monthly limit has been reached |
| 48 | + fn is_monthly_limit_error(ctx: &InterceptorContext) -> bool { |
| 49 | + let Some(resp) = ctx.response() else { |
| 50 | + return false; |
| 51 | + }; |
| 52 | + |
| 53 | + // Check status code first - monthly limit errors typically return 429 (Too Many Requests) |
| 54 | + let status_code = resp.status().as_u16(); |
| 55 | + if status_code != 429 { |
| 56 | + return false; |
| 57 | + } |
| 58 | + |
| 59 | + let Some(bytes) = resp.body().bytes() else { |
| 60 | + return false; |
| 61 | + }; |
| 62 | + |
| 63 | + let is_monthly_limit = match std::str::from_utf8(bytes) { |
| 64 | + Ok(body_str) => body_str.contains(MONTHLY_LIMIT_ERROR_MARKER), |
| 65 | + Err(_) => false, |
| 66 | + }; |
| 67 | + |
| 68 | + debug!( |
| 69 | + "QCliRetryClassifier: Monthly limit error detected: {}", |
| 70 | + is_monthly_limit |
| 71 | + ); |
| 72 | + is_monthly_limit |
| 73 | + } |
| 74 | + |
| 75 | + /// Check if the error indicates a model is unavailable due to high load |
| 76 | + fn is_service_overloaded_error(ctx: &InterceptorContext) -> bool { |
| 77 | + let Some(resp) = ctx.response() else { |
| 78 | + return false; |
| 79 | + }; |
| 80 | + |
| 81 | + let status_code = resp.status().as_u16(); |
| 82 | + |
| 83 | + // Fail fast: if status code is not in our list, return false immediately |
| 84 | + if !SERVICE_OVERLOAD_STATUS_CODES.contains(&status_code) { |
| 85 | + return false; |
| 86 | + } |
| 87 | + |
| 88 | + let is_overloaded = match status_code { |
| 89 | + 429 => { |
| 90 | + // For 429 errors, check if the response body contains the insufficient model capacity message |
| 91 | + let Some(bytes) = resp.body().bytes() else { |
| 92 | + return false; |
| 93 | + }; |
| 94 | + |
| 95 | + match std::str::from_utf8(bytes) { |
| 96 | + Ok(body_str) => body_str.contains(INSUFFICIENT_MODEL_CAPACITY_MESSAGE), |
| 97 | + Err(_) => false, |
| 98 | + } |
| 99 | + }, |
| 100 | + 500 => { |
| 101 | + // For 500 errors, check if the response body contains the specific high load message |
| 102 | + let Some(bytes) = resp.body().bytes() else { |
| 103 | + return false; |
| 104 | + }; |
| 105 | + |
| 106 | + match std::str::from_utf8(bytes) { |
| 107 | + Ok(body_str) => body_str.contains(HIGH_LOAD_ERROR_MESSAGE), |
| 108 | + Err(_) => false, |
| 109 | + } |
| 110 | + }, |
| 111 | + 503 => { |
| 112 | + // For 503 Service Unavailable, always retry (no additional checks needed) |
| 113 | + true |
| 114 | + }, |
| 115 | + _ => { |
| 116 | + // This shouldn't happen given our fail-fast check above, but handle gracefully |
| 117 | + false |
| 118 | + }, |
| 119 | + }; |
| 120 | + |
| 121 | + debug!( |
| 122 | + "QCliRetryClassifier: Service overloaded error detected (status {}): {}", |
| 123 | + status_code, is_overloaded |
| 124 | + ); |
| 125 | + is_overloaded |
| 126 | + } |
| 127 | +} |
| 128 | + |
| 129 | +impl ClassifyRetry for QCliRetryClassifier { |
| 130 | + fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction { |
| 131 | + // Check for monthly limit error first - this should never be retried |
| 132 | + if Self::is_monthly_limit_error(ctx) { |
| 133 | + return RetryAction::RetryForbidden; |
| 134 | + } |
| 135 | + |
| 136 | + // Check for service overloaded error - this should be treated as throttling |
| 137 | + if Self::is_service_overloaded_error(ctx) { |
| 138 | + return RetryAction::throttling_error(); |
| 139 | + } |
| 140 | + |
| 141 | + // No specific action for other errors |
| 142 | + RetryAction::NoActionIndicated |
| 143 | + } |
| 144 | + |
| 145 | + fn name(&self) -> &'static str { |
| 146 | + "Q CLI Custom Retry Classifier" |
| 147 | + } |
| 148 | + |
| 149 | + fn priority(&self) -> RetryClassifierPriority { |
| 150 | + Self::priority() |
| 151 | + } |
| 152 | +} |
| 153 | + |
| 154 | +impl fmt::Display for QCliRetryClassifier { |
| 155 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 156 | + write!(f, "QCliRetryClassifier") |
| 157 | + } |
| 158 | +} |
| 159 | + |
#[cfg(test)]
mod tests {
    use aws_smithy_runtime_api::client::interceptors::context::{
        Input,
        InterceptorContext,
    };
    use aws_smithy_types::body::SdkBody;
    use http::Response;

    use super::*;

    /// Builds a fresh context holding a response with the given status and body, then returns
    /// what the classifier decides for it.
    fn classify(status: u16, body: &'static str) -> RetryAction {
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        let response = Response::builder()
            .status(status)
            .body(body)
            .unwrap()
            .map(SdkBody::from);

        ctx.set_response(response.try_into().unwrap());

        QCliRetryClassifier::new().classify_retry(&ctx)
    }

    #[test]
    fn test_monthly_limit_error_classification() {
        // A 429 whose body contains MONTHLY_REQUEST_COUNT must never be retried.
        let result = classify(429, r#"{"error": "MONTHLY_REQUEST_COUNT exceeded"}"#);
        assert_eq!(result, RetryAction::RetryForbidden);
    }

    #[test]
    fn test_monthly_limit_marker_ignored_for_non_429_status() {
        // The marker only matters on 429; a 500 carrying it (and no high-load message) is left alone.
        let result = classify(500, r#"{"error": "MONTHLY_REQUEST_COUNT exceeded"}"#);
        assert_eq!(result, RetryAction::NoActionIndicated);
    }

    #[test]
    fn test_insufficient_model_capacity_error_classification() {
        // A 429 with the insufficient model capacity message is treated as service overloaded.
        let result = classify(
            429,
            r#"{"error": "I am experiencing high traffic, please try again shortly."}"#,
        );
        assert_eq!(result, RetryAction::throttling_error());
    }

    #[test]
    fn test_429_error_without_insufficient_capacity_message_not_retried() {
        // A 429 without the specific insufficient model capacity message gets no action.
        let result = classify(429, "Too Many Requests - some other error");
        assert_eq!(result, RetryAction::NoActionIndicated);
    }

    #[test]
    fn test_service_overloaded_error_classification() {
        // A 500 with the specific high load message is treated as service overloaded.
        let result = classify(
            500,
            r#"{"error": "Encountered unexpectedly high load when processing the request, please try again."}"#,
        );
        assert_eq!(result, RetryAction::throttling_error());
    }

    #[test]
    fn test_500_error_without_high_load_message_not_retried() {
        // A 500 without the specific high load message gets no action.
        let result = classify(500, "Internal Server Error - some other error");
        assert_eq!(result, RetryAction::NoActionIndicated);
    }

    #[test]
    fn test_service_unavailable_error_classification() {
        // A 503 is treated as service overloaded regardless of its body.
        let result = classify(503, "Service Unavailable");
        assert_eq!(result, RetryAction::throttling_error());
    }

    #[test]
    fn test_no_action_for_non_overload_errors() {
        // A 400 is not a service-overload status.
        let result = classify(400, "Bad Request");
        assert_eq!(result, RetryAction::NoActionIndicated);
    }

    #[test]
    fn test_no_action_when_response_is_missing() {
        // With no response recorded on the context, neither check can match.
        let ctx = InterceptorContext::new(Input::doesnt_matter());
        let result = QCliRetryClassifier::new().classify_retry(&ctx);
        assert_eq!(result, RetryAction::NoActionIndicated);
    }

    #[test]
    fn test_fail_fast_for_non_service_overload_status_codes() {
        // Status codes that are not in SERVICE_OVERLOAD_STATUS_CODES.
        let test_cases = [
            (200, "OK"),
            (400, "Bad Request"),
            (401, "Unauthorized"),
            (403, "Forbidden"),
            (404, "Not Found"),
            (502, "Bad Gateway"),
        ];

        for (status_code, body) in test_cases {
            let result = classify(status_code, body);
            assert_eq!(
                result,
                RetryAction::NoActionIndicated,
                "Status code {} should return NoActionIndicated",
                status_code
            );
        }
    }

    #[test]
    fn test_classifier_priority() {
        let priority = QCliRetryClassifier::priority();
        let transient_priority = RetryClassifierPriority::transient_error_classifier();

        // Our classifier should have higher priority than the transient error classifier
        assert!(priority > transient_priority);
    }

    #[test]
    fn test_classifier_name() {
        let classifier = QCliRetryClassifier::new();
        assert_eq!(classifier.name(), "Q CLI Custom Retry Classifier");
    }
}
0 commit comments