
Commit e7ea8af

kiran-garre committed
chore: Merging changes from main
2 parents: 81b1dac + 754c4d5

File tree

28 files changed: +1101 -294 lines


Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ authors = ["Amazon Q CLI Team ([email protected])", "Chay Nabors (nabochay@amazon
 edition = "2024"
 homepage = "https://aws.amazon.com/q/"
 publish = false
-version = "1.13.2"
+version = "1.13.3"
 license = "MIT OR Apache-2.0"

 [workspace.dependencies]

crates/chat-cli/src/api_client/error.rs

Lines changed: 15 additions & 0 deletions
@@ -1,6 +1,7 @@
 use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenError;
 use amzn_codewhisperer_client::operation::generate_completions::GenerateCompletionsError;
 use amzn_codewhisperer_client::operation::list_available_customizations::ListAvailableCustomizationsError;
+use amzn_codewhisperer_client::operation::list_available_models::ListAvailableModelsError;
 use amzn_codewhisperer_client::operation::list_available_profiles::ListAvailableProfilesError;
 use amzn_codewhisperer_client::operation::send_telemetry_event::SendTelemetryEventError;
 pub use amzn_codewhisperer_streaming_client::operation::generate_assistant_response::GenerateAssistantResponseError;
@@ -93,6 +94,12 @@ pub enum ApiClientError {
     // Credential errors
     #[error("failed to load credentials: {}", .0)]
     Credentials(CredentialsError),
+
+    #[error(transparent)]
+    ListAvailableModelsError(#[from] SdkError<ListAvailableModelsError, HttpResponse>),
+
+    #[error("No default model found in the ListAvailableModels API response")]
+    DefaultModelNotFound,
 }

 impl ApiClientError {
@@ -116,6 +123,8 @@ impl ApiClientError {
             Self::ModelOverloadedError { status_code, .. } => *status_code,
             Self::MonthlyLimitReached { status_code } => *status_code,
             Self::Credentials(_e) => None,
+            Self::ListAvailableModelsError(e) => sdk_status_code(e),
+            Self::DefaultModelNotFound => None,
         }
     }
 }
@@ -141,6 +150,8 @@ impl ReasonCode for ApiClientError {
             Self::ModelOverloadedError { .. } => "ModelOverloadedError".to_string(),
             Self::MonthlyLimitReached { .. } => "MonthlyLimitReached".to_string(),
             Self::Credentials(_) => "CredentialsError".to_string(),
+            Self::ListAvailableModelsError(e) => sdk_error_code(e),
+            Self::DefaultModelNotFound => "DefaultModelNotFound".to_string(),
         }
     }
 }
@@ -188,6 +199,10 @@ mod tests {
                 ListAvailableCustomizationsError::unhandled("<unhandled>"),
                 response(),
             )),
+            ApiClientError::ListAvailableModelsError(SdkError::service_error(
+                ListAvailableModelsError::unhandled("<unhandled>"),
+                response(),
+            )),
             ApiClientError::ListAvailableServices(SdkError::service_error(
                 ListCustomizationsError::unhandled("<unhandled>"),
                 response(),
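
Note on the new error variants: the thiserror-style #[error(transparent)] attribute forwards the wrapped error's message and source, and #[from] generates the From impl that lets `?` convert the SDK error into ApiClientError automatically. A minimal standalone sketch of that pattern, using illustrative type names that are not taken from this crate:

// Sketch only: `thiserror` crate assumed as a dependency; names are hypothetical.
use thiserror::Error;

#[derive(Debug, Error)]
#[error("list available models failed: {0}")]
struct ListAvailableModelsError(String);

#[derive(Debug, Error)]
enum ApiError {
    // `transparent` delegates Display/source to the inner error;
    // `#[from]` generates `From<ListAvailableModelsError> for ApiError`.
    #[error(transparent)]
    ListAvailableModels(#[from] ListAvailableModelsError),

    #[error("No default model found in the ListAvailableModels API response")]
    DefaultModelNotFound,
}

fn list_models(fail: bool) -> Result<Vec<String>, ApiError> {
    if fail {
        // The generated From impl converts the inner error with `.into()`
        // (or implicitly via `?` when propagating a Result).
        return Err(ListAvailableModelsError("throttled".into()).into());
    }
    Ok(vec!["model-1".into()])
}

fn main() {
    println!("{:?}", list_models(false));
    println!("{:?}", list_models(true));
}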

crates/chat-cli/src/api_client/mod.rs

Lines changed: 102 additions & 2 deletions
@@ -5,14 +5,16 @@ mod error;
 pub mod model;
 mod opt_out;
 pub mod profile;
+mod retry_classifier;
 pub mod send_message_output;
-
 use std::sync::Arc;
 use std::time::Duration;

 use amzn_codewhisperer_client::Client as CodewhispererClient;
 use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenOutput;
+use amzn_codewhisperer_client::types::Origin::Cli;
 use amzn_codewhisperer_client::types::{
+    Model,
     OptOutPreference,
     SubscriptionStatus,
     TelemetryEvent,
@@ -32,6 +34,7 @@ pub use error::ApiClientError;
 use parking_lot::Mutex;
 pub use profile::list_available_profiles;
 use serde_json::Map;
+use tokio::sync::RwLock;
 use tracing::{
     debug,
     error,
@@ -66,13 +69,28 @@ pub const X_AMZN_CODEWHISPERER_OPT_OUT_HEADER: &str = "x-amzn-codewhisperer-opto
 // TODO(bskiser): confirm timeout is updated to an appropriate value?
 const DEFAULT_TIMEOUT_DURATION: Duration = Duration::from_secs(60 * 5);

+#[derive(Clone, Debug)]
+pub struct ModelListResult {
+    pub models: Vec<Model>,
+    pub default_model: Model,
+}
+
+impl From<ModelListResult> for (Vec<Model>, Model) {
+    fn from(v: ModelListResult) -> Self {
+        (v.models, v.default_model)
+    }
+}
+
+type ModelCache = Arc<RwLock<Option<ModelListResult>>>;
+
 #[derive(Clone, Debug)]
 pub struct ApiClient {
     client: CodewhispererClient,
     streaming_client: Option<CodewhispererStreamingClient>,
     sigv4_streaming_client: Option<QDeveloperStreamingClient>,
     mock_client: Option<Arc<Mutex<std::vec::IntoIter<Vec<ChatResponseStream>>>>>,
     profile: Option<AuthProfile>,
+    model_cache: ModelCache,
 }

 impl ApiClient {
@@ -112,6 +130,7 @@ impl ApiClient {
             sigv4_streaming_client: None,
             mock_client: None,
             profile: None,
+            model_cache: Arc::new(RwLock::new(None)),
         };

         if let Ok(json) = env.get("Q_MOCK_CHAT_RESPONSE") {
@@ -146,6 +165,7 @@ impl ApiClient {
                 .interceptor(UserAgentOverrideInterceptor::new())
                 .app_name(app_name())
                 .endpoint_url(endpoint.url())
+                .retry_classifier(retry_classifier::QCliRetryClassifier::new())
                 .stalled_stream_protection(stalled_stream_protection_config())
                 .build(),
         ));
@@ -159,6 +179,7 @@ impl ApiClient {
                 .bearer_token_resolver(BearerResolver)
                 .app_name(app_name())
                 .endpoint_url(endpoint.url())
+                .retry_classifier(retry_classifier::QCliRetryClassifier::new())
                 .stalled_stream_protection(stalled_stream_protection_config())
                 .build(),
         ));
@@ -179,6 +200,7 @@ impl ApiClient {
             sigv4_streaming_client,
             mock_client: None,
             profile,
+            model_cache: Arc::new(RwLock::new(None)),
         })
     }

@@ -232,6 +254,82 @@ impl ApiClient {
         Ok(profiles)
     }

+    pub async fn list_available_models(&self) -> Result<ModelListResult, ApiClientError> {
+        if cfg!(test) {
+            let m = Model::builder()
+                .model_id("model-1")
+                .description("Test Model 1")
+                .build()
+                .unwrap();
+
+            return Ok(ModelListResult {
+                models: vec![m.clone()],
+                default_model: m,
+            });
+        }
+
+        let mut models = Vec::new();
+        let mut default_model = None;
+        let request = self
+            .client
+            .list_available_models()
+            .set_origin(Some(Cli))
+            .set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone()));
+        let mut paginator = request.into_paginator().send();
+
+        while let Some(result) = paginator.next().await {
+            let models_output = result?;
+            models.extend(models_output.models().iter().cloned());
+
+            if default_model.is_none() {
+                default_model = Some(models_output.default_model().clone());
+            }
+        }
+        let default_model = default_model.ok_or_else(|| ApiClientError::DefaultModelNotFound)?;
+        Ok(ModelListResult { models, default_model })
+    }
+
+    pub async fn list_available_models_cached(&self) -> Result<ModelListResult, ApiClientError> {
+        {
+            let cache = self.model_cache.read().await;
+            if let Some(cached) = cache.as_ref() {
+                tracing::debug!("Returning cached model list");
+                return Ok(cached.clone());
+            }
+        }
+
+        tracing::debug!("Cache miss, fetching models from list_available_models API");
+        let result = self.list_available_models().await?;
+        {
+            let mut cache = self.model_cache.write().await;
+            *cache = Some(result.clone());
+        }
+        Ok(result)
+    }
+
+    pub async fn invalidate_model_cache(&self) {
+        let mut cache = self.model_cache.write().await;
+        *cache = None;
+        tracing::info!("Model cache invalidated");
+    }
+
+    pub async fn get_available_models(&self, _region: &str) -> Result<ModelListResult, ApiClientError> {
+        let res = self.list_available_models_cached().await?;
+        // TODO: Once we have access to gpt-oss, add back.
+        // if region == "us-east-1" {
+        //     let gpt_oss = Model::builder()
+        //         .model_id("OPENAI_GPT_OSS_120B_1_0")
+        //         .model_name("openai-gpt-oss-120b-preview")
+        //         .token_limits(TokenLimits::builder().max_input_tokens(128_000).build())
+        //         .build()
+        //         .map_err(ApiClientError::from)?;
+
+        //     models.push(gpt_oss);
+        // }
+
+        Ok(res)
+    }
+
     pub async fn create_subscription_token(&self) -> Result<CreateSubscriptionTokenOutput, ApiClientError> {
         if cfg!(test) {
             return Ok(CreateSubscriptionTokenOutput::builder()
@@ -496,7 +594,9 @@ fn timeout_config(database: &Database) -> TimeoutConfig {
 }

 fn retry_config() -> RetryConfig {
-    RetryConfig::standard().with_max_attempts(1)
+    RetryConfig::adaptive()
+        .with_max_attempts(3)
+        .with_max_backoff(Duration::from_secs(10))
 }

 pub fn stalled_stream_protection_config() -> StalledStreamProtectionConfig {
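
Note on the caching change: list_available_models_cached uses a read-then-write tokio::sync::RwLock cache; a shared read lock serves the fast path, the read guard is dropped before the network fetch, and the fresh result is stored under a short write lock. A minimal standalone sketch of that pattern with illustrative stand-in types (not the crate's real ApiClient or Model):

// Sketch only: assumes `tokio` with the `macros` and `rt-multi-thread` features.
use std::sync::Arc;
use tokio::sync::RwLock;

// Illustrative stand-in for ModelListResult.
#[derive(Clone, Debug)]
struct Listing {
    models: Vec<String>,
    default_model: String,
}

#[derive(Clone)]
struct Client {
    cache: Arc<RwLock<Option<Listing>>>,
}

impl Client {
    async fn fetch(&self) -> Listing {
        // Stand-in for the paginated ListAvailableModels call.
        Listing {
            models: vec!["model-1".into()],
            default_model: "model-1".into(),
        }
    }

    async fn list_cached(&self) -> Listing {
        // Fast path: shared read lock, clone the cached value if present.
        {
            let cache = self.cache.read().await;
            if let Some(cached) = cache.as_ref() {
                return cached.clone();
            }
        } // read guard dropped before fetching or taking the write lock

        // Slow path: fetch, then populate the cache under a brief write lock.
        let fresh = self.fetch().await;
        *self.cache.write().await = Some(fresh.clone());
        fresh
    }
}

#[tokio::main]
async fn main() {
    let client = Client { cache: Arc::new(RwLock::new(None)) };
    let first = client.list_cached().await;  // cache miss: fetches and stores
    let second = client.list_cached().await; // cache hit: returns the clone
    println!("{first:?} / {second:?}");
}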

0 commit comments
