Merged
Commits (27)
1d6fb47
cooperate with list available models, remove model options
evanliu048 Jul 25, 2025
a282fcb
add model name mapping method
evanliu048 Jul 28, 2025
0a73763
Merge branch 'main' into list_model
evanliu048 Jul 28, 2025
5c31dcb
Merge branch 'main' into list_model
evanliu048 Jul 28, 2025
a57a0b0
cooperate with default model in response
evanliu048 Jul 28, 2025
3aff859
add cache in apiclient for model
evanliu048 Jul 29, 2025
3d3be47
update the model cache structure
evanliu048 Jul 31, 2025
f6d48b7
remove model id mapping in client
evanliu048 Aug 1, 2025
e746f9f
remove model id mapping in client
evanliu048 Aug 1, 2025
4f5d017
merge with main
evanliu048 Aug 5, 2025
b0d6177
merge main
evanliu048 Aug 6, 2025
4bcd58e
combine API response and OpenAI model
evanliu048 Aug 7, 2025
33c2537
change the default model to be a required field
evanliu048 Aug 7, 2025
cb27bf9
change model id in ConversationState into model info
evanliu048 Aug 7, 2025
73049e9
merge main
evanliu048 Aug 7, 2025
8853786
remove unused import
evanliu048 Aug 7, 2025
fbca06b
CI
evanliu048 Aug 7, 2025
2c959d0
add a fallback model when the API call fails
evanliu048 Aug 8, 2025
c4bfb19
resolve merge conflict
evanliu048 Aug 8, 2025
72a95a4
replace tuple with ModelListResult struct; cache struct; keep tuple A…
evanliu048 Aug 9, 2025
bff90b6
checking model name instead of model id
evanliu048 Aug 9, 2025
7ca9531
clippy
evanliu048 Aug 9, 2025
a92bb9c
merge main
evanliu048 Aug 11, 2025
a0af47e
add both model id and model info in conversation state
evanliu048 Aug 11, 2025
b5d07df
delete manual Deserialize for ModelInfo
evanliu048 Aug 12, 2025
3067317
delete manual Deserialize for ModelInfo
evanliu048 Aug 12, 2025
d42ca1d
add import
evanliu048 Aug 12, 2025
15 changes: 15 additions & 0 deletions crates/chat-cli/src/api_client/error.rs
@@ -1,6 +1,7 @@
use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenError;
use amzn_codewhisperer_client::operation::generate_completions::GenerateCompletionsError;
use amzn_codewhisperer_client::operation::list_available_customizations::ListAvailableCustomizationsError;
use amzn_codewhisperer_client::operation::list_available_models::ListAvailableModelsError;
use amzn_codewhisperer_client::operation::list_available_profiles::ListAvailableProfilesError;
use amzn_codewhisperer_client::operation::send_telemetry_event::SendTelemetryEventError;
pub use amzn_codewhisperer_streaming_client::operation::generate_assistant_response::GenerateAssistantResponseError;
@@ -93,6 +94,12 @@ pub enum ApiClientError {
// Credential errors
#[error("failed to load credentials: {}", .0)]
Credentials(CredentialsError),

#[error(transparent)]
ListAvailableModelsError(#[from] SdkError<ListAvailableModelsError, HttpResponse>),

#[error("No default model found in the ListAvailableModels API response")]
DefaultModelNotFound,
}

impl ApiClientError {
@@ -116,6 +123,8 @@ impl ApiClientError {
Self::ModelOverloadedError { status_code, .. } => *status_code,
Self::MonthlyLimitReached { status_code } => *status_code,
Self::Credentials(_e) => None,
Self::ListAvailableModelsError(e) => sdk_status_code(e),
Self::DefaultModelNotFound => None,
}
}
}
@@ -141,6 +150,8 @@ impl ReasonCode for ApiClientError {
Self::ModelOverloadedError { .. } => "ModelOverloadedError".to_string(),
Self::MonthlyLimitReached { .. } => "MonthlyLimitReached".to_string(),
Self::Credentials(_) => "CredentialsError".to_string(),
Self::ListAvailableModelsError(e) => sdk_error_code(e),
Self::DefaultModelNotFound => "DefaultModelNotFound".to_string(),
}
}
}
@@ -188,6 +199,10 @@ mod tests {
ListAvailableCustomizationsError::unhandled("<unhandled>"),
response(),
)),
ApiClientError::ListAvailableModelsError(SdkError::service_error(
ListAvailableModelsError::unhandled("<unhandled>"),
response(),
)),
ApiClientError::ListAvailableServices(SdkError::service_error(
ListCustomizationsError::unhandled("<unhandled>"),
response(),
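As a minimal sketch (not part of this diff), a caller might branch on the two new variants like this, assuming only the `ApiClientError` enum shown above; the logging choices are illustrative:

// Hypothetical caller, not the PR's code.
fn report(err: &ApiClientError) {
    match err {
        // Raised when no page of the paginated response carried a default model.
        ApiClientError::DefaultModelNotFound => {
            tracing::warn!("ListAvailableModels returned no default model");
        },
        // Wraps the raw SdkError via the #[from] conversion added above.
        ApiClientError::ListAvailableModelsError(e) => {
            tracing::warn!("ListAvailableModels failed: {e}");
        },
        other => tracing::error!("API error: {other}"),
    }
}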
96 changes: 96 additions & 0 deletions crates/chat-cli/src/api_client/mod.rs
@@ -12,7 +12,9 @@ use std::time::Duration;

use amzn_codewhisperer_client::Client as CodewhispererClient;
use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenOutput;
use amzn_codewhisperer_client::types::Origin::Cli;
use amzn_codewhisperer_client::types::{
Model,
OptOutPreference,
SubscriptionStatus,
TelemetryEvent,
@@ -32,6 +34,7 @@ pub use error::ApiClientError;
use parking_lot::Mutex;
pub use profile::list_available_profiles;
use serde_json::Map;
use tokio::sync::RwLock;
use tracing::{
debug,
error,
@@ -66,13 +69,28 @@ pub const X_AMZN_CODEWHISPERER_OPT_OUT_HEADER: &str = "x-amzn-codewhisperer-opto
// TODO(bskiser): confirm timeout is updated to an appropriate value?
const DEFAULT_TIMEOUT_DURATION: Duration = Duration::from_secs(60 * 5);

#[derive(Clone, Debug)]
pub struct ModelListResult {
pub models: Vec<Model>,
pub default_model: Model,
}

impl From<ModelListResult> for (Vec<Model>, Model) {
fn from(v: ModelListResult) -> Self {
(v.models, v.default_model)
}
}

type ModelCache = Arc<RwLock<Option<ModelListResult>>>;

#[derive(Clone, Debug)]
pub struct ApiClient {
client: CodewhispererClient,
streaming_client: Option<CodewhispererStreamingClient>,
sigv4_streaming_client: Option<QDeveloperStreamingClient>,
mock_client: Option<Arc<Mutex<std::vec::IntoIter<Vec<ChatResponseStream>>>>>,
profile: Option<AuthProfile>,
model_cache: ModelCache,
}

impl ApiClient {
@@ -112,6 +130,7 @@ impl ApiClient {
sigv4_streaming_client: None,
mock_client: None,
profile: None,
model_cache: Arc::new(RwLock::new(None)),
};

if let Ok(json) = env.get("Q_MOCK_CHAT_RESPONSE") {
@@ -181,6 +200,7 @@ impl ApiClient {
sigv4_streaming_client,
mock_client: None,
profile,
model_cache: Arc::new(RwLock::new(None)),
})
}

@@ -234,6 +254,82 @@ impl ApiClient {
Ok(profiles)
}

pub async fn list_available_models(&self) -> Result<ModelListResult, ApiClientError> {
if cfg!(test) {
let m = Model::builder()
.model_id("model-1")
.description("Test Model 1")
.build()
.unwrap();

return Ok(ModelListResult {
models: vec![m.clone()],
default_model: m,
});
}

let mut models = Vec::new();
let mut default_model = None;
let request = self
.client
.list_available_models()
.set_origin(Some(Cli))
.set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone()));
let mut paginator = request.into_paginator().send();

while let Some(result) = paginator.next().await {
let models_output = result?;
models.extend(models_output.models().iter().cloned());

if default_model.is_none() {
default_model = Some(models_output.default_model().clone());
}
}
let default_model = default_model.ok_or_else(|| ApiClientError::DefaultModelNotFound)?;
Ok(ModelListResult { models, default_model })
}

pub async fn list_available_models_cached(&self) -> Result<ModelListResult, ApiClientError> {
{
let cache = self.model_cache.read().await;
if let Some(cached) = cache.as_ref() {
tracing::debug!("Returning cached model list");
return Ok(cached.clone());
}
}

tracing::debug!("Cache miss, fetching models from list_available_models API");
let result = self.list_available_models().await?;
{
let mut cache = self.model_cache.write().await;
*cache = Some(result.clone());
}
Ok(result)
}

pub async fn invalidate_model_cache(&self) {
let mut cache = self.model_cache.write().await;
*cache = None;
tracing::info!("Model cache invalidated");
}

pub async fn get_available_models(&self, _region: &str) -> Result<ModelListResult, ApiClientError> {
let res = self.list_available_models_cached().await?;
// TODO: Once we have access to gpt-oss, add back.
// if region == "us-east-1" {
// let gpt_oss = Model::builder()
// .model_id("OPENAI_GPT_OSS_120B_1_0")
// .model_name("openai-gpt-oss-120b-preview")
// .token_limits(TokenLimits::builder().max_input_tokens(128_000).build())
// .build()
// .map_err(ApiClientError::from)?;

// models.push(gpt_oss);
// }

Ok(res)
}

pub async fn create_subscription_token(&self) -> Result<CreateSubscriptionTokenOutput, ApiClientError> {
if cfg!(test) {
return Ok(CreateSubscriptionTokenOutput::builder()
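A sketch of the intended call pattern, using only items from this file (`ApiClient`, `Model`, `ModelListResult`); the name-matching logic is illustrative, and `Model::model_name()` returning `Option<&str>` is an assumption based on the builder calls above:

// Hypothetical helper: resolve a model by name, falling back to the
// server-provided default.
async fn choose_model(client: &ApiClient, wanted: Option<&str>) -> Result<Model, ApiClientError> {
    // The first call hits the paginated API; later calls are served from the
    // Arc<RwLock<Option<ModelListResult>>> cache until invalidate_model_cache().
    let ModelListResult { models, default_model } = client.list_available_models_cached().await?;
    Ok(wanted
        .and_then(|name| models.iter().find(|m| m.model_name() == Some(name)).cloned())
        .unwrap_or(default_model))
}

The `From` impl above also keeps older tuple-based call sites working: `let (models, default): (Vec<Model>, Model) = result.into();`.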
1 change: 1 addition & 0 deletions crates/chat-cli/src/auth/builder_id.rs
@@ -480,6 +480,7 @@ impl BuilderIdToken {

/// Check if the token is for the internal amzn start URL (`https://amzn.awsapps.com/start`),
/// this implies the user will use midway for private specs
#[allow(dead_code)]
pub fn is_amzn_user(&self) -> bool {
matches!(&self.start_url, Some(url) if url == AMZN_START_URL)
}
2 changes: 1 addition & 1 deletion crates/chat-cli/src/cli/chat/cli/context.rs
@@ -222,7 +222,7 @@ impl ContextSubcommand {
execute!(session.stderr, style::Print(format!("{}\n\n", "▔".repeat(3))),)?;
}

-let context_files_max_size = calc_max_context_files_size(session.conversation.model.as_deref());
+let context_files_max_size = calc_max_context_files_size(session.conversation.model_info.as_ref());
let mut files_as_vec = profile_context_files
.iter()
.map(|(path, content, _)| (path.clone(), content.clone()))
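The call site now passes `Option<&ModelInfo>` instead of `Option<&str>`. A hypothetical shape for the helper, inferred only from this call site; the field name and sizing rule below are assumptions, not the crate's code:

// Signature inferred from the call site; the body is an illustrative sizing rule.
fn calc_max_context_files_size(model_info: Option<&ModelInfo>) -> usize {
    const DEFAULT_MAX: usize = 150_000; // assumed fallback size

    model_info
        .and_then(|m| m.context_window_tokens) // hypothetical field
        .map_or(DEFAULT_MAX, |tokens| tokens as usize / 2)
}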