Skip to content

Commit ccc6b33

Browse files
authored
feat: add delay tracking interceptor for retry notifications (#2607)
1 parent 0fa746f commit ccc6b33

File tree

2 files changed

+96
-0
lines changed

2 files changed

+96
-0
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
use std::time::{
2+
Duration,
3+
Instant,
4+
};
5+
6+
use aws_smithy_runtime_api::box_error::BoxError;
7+
use aws_smithy_runtime_api::client::interceptors::Intercept;
8+
use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextRef;
9+
use aws_smithy_runtime_api::client::retries::RequestAttempts;
10+
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
11+
use aws_smithy_types::config_bag::{
12+
ConfigBag,
13+
Storable,
14+
StoreReplace,
15+
};
16+
use crossterm::style::Color;
17+
use crossterm::{
18+
execute,
19+
style,
20+
};
21+
22+
#[derive(Debug, Clone)]
23+
pub struct DelayTrackingInterceptor {
24+
minor_delay_threshold: Duration,
25+
major_delay_threshold: Duration,
26+
}
27+
28+
impl DelayTrackingInterceptor {
29+
pub fn new() -> Self {
30+
Self {
31+
minor_delay_threshold: Duration::from_secs(2),
32+
major_delay_threshold: Duration::from_secs(5),
33+
}
34+
}
35+
36+
fn print_warning(message: String) {
37+
let mut stderr = std::io::stderr();
38+
let _ = execute!(
39+
stderr,
40+
style::SetForegroundColor(Color::Yellow),
41+
style::Print("\nWARNING: "),
42+
style::SetForegroundColor(Color::Reset),
43+
style::Print(message),
44+
style::Print("\n")
45+
);
46+
}
47+
}
48+
49+
impl Intercept for DelayTrackingInterceptor {
50+
fn name(&self) -> &'static str {
51+
"DelayTrackingInterceptor"
52+
}
53+
54+
fn read_before_transmit(
55+
&self,
56+
_: &BeforeTransmitInterceptorContextRef<'_>,
57+
_: &RuntimeComponents,
58+
cfg: &mut ConfigBag,
59+
) -> Result<(), BoxError> {
60+
let attempt_number = cfg.load::<RequestAttempts>().map_or(1, |attempts| attempts.attempts());
61+
62+
let now = Instant::now();
63+
64+
if let Some(last_attempt_time) = cfg.load::<LastAttemptTime>() {
65+
let delay = now.duration_since(last_attempt_time.0);
66+
67+
if delay >= self.major_delay_threshold {
68+
Self::print_warning(format!(
69+
"Auto Retry #{} delayed by {:.1}s. Service is under heavy load - consider switching models.",
70+
attempt_number,
71+
delay.as_secs_f64()
72+
));
73+
} else if delay >= self.minor_delay_threshold {
74+
Self::print_warning(format!(
75+
"Auto Retry #{} delayed by {:.1}s due to transient issues.",
76+
attempt_number,
77+
delay.as_secs_f64()
78+
));
79+
}
80+
}
81+
82+
cfg.interceptor_state().store_put(LastAttemptTime(Instant::now()));
83+
Ok(())
84+
}
85+
}
86+
87+
#[derive(Debug, Clone)]
88+
struct LastAttemptTime(Instant);
89+
90+
impl Storable for LastAttemptTime {
91+
type Storer = StoreReplace<Self>;
92+
}

crates/chat-cli/src/api_client/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
mod credentials;
22
pub mod customization;
3+
mod delay_interceptor;
34
mod endpoints;
45
mod error;
56
pub mod model;
@@ -41,6 +42,7 @@ use tracing::{
4142
};
4243

4344
use crate::api_client::credentials::CredentialsChain;
45+
use crate::api_client::delay_interceptor::DelayTrackingInterceptor;
4446
use crate::api_client::model::{
4547
ChatResponseStream,
4648
ConversationState,
@@ -163,6 +165,7 @@ impl ApiClient {
163165
.http_client(crate::aws_common::http_client::client())
164166
.interceptor(OptOutInterceptor::new(database))
165167
.interceptor(UserAgentOverrideInterceptor::new())
168+
.interceptor(DelayTrackingInterceptor::new())
166169
.app_name(app_name())
167170
.endpoint_url(endpoint.url())
168171
.retry_classifier(retry_classifier::QCliRetryClassifier::new())
@@ -176,6 +179,7 @@ impl ApiClient {
176179
.http_client(crate::aws_common::http_client::client())
177180
.interceptor(OptOutInterceptor::new(database))
178181
.interceptor(UserAgentOverrideInterceptor::new())
182+
.interceptor(DelayTrackingInterceptor::new())
179183
.bearer_token_resolver(BearerResolver)
180184
.app_name(app_name())
181185
.endpoint_url(endpoint.url())

0 commit comments

Comments
 (0)