diff --git a/crates/chat-cli/src/api_client/error.rs b/crates/chat-cli/src/api_client/error.rs
index 88f4a1f70d..37420fb72e 100644
--- a/crates/chat-cli/src/api_client/error.rs
+++ b/crates/chat-cli/src/api_client/error.rs
@@ -1,6 +1,7 @@
 use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenError;
 use amzn_codewhisperer_client::operation::generate_completions::GenerateCompletionsError;
 use amzn_codewhisperer_client::operation::list_available_customizations::ListAvailableCustomizationsError;
+use amzn_codewhisperer_client::operation::list_available_models::ListAvailableModelsError;
 use amzn_codewhisperer_client::operation::list_available_profiles::ListAvailableProfilesError;
 use amzn_codewhisperer_client::operation::send_telemetry_event::SendTelemetryEventError;
 pub use amzn_codewhisperer_streaming_client::operation::generate_assistant_response::GenerateAssistantResponseError;
@@ -93,6 +94,12 @@ pub enum ApiClientError {
     // Credential errors
     #[error("failed to load credentials: {}", .0)]
     Credentials(CredentialsError),
+
+    #[error(transparent)]
+    ListAvailableModelsError(#[from] SdkError<ListAvailableModelsError, HttpResponse>),
+
+    #[error("No default model found in the ListAvailableModels API response")]
+    DefaultModelNotFound,
 }
 
 impl ApiClientError {
@@ -116,6 +123,8 @@ impl ApiClientError {
             Self::ModelOverloadedError { status_code, .. } => *status_code,
             Self::MonthlyLimitReached { status_code } => *status_code,
             Self::Credentials(_e) => None,
+            Self::ListAvailableModelsError(e) => sdk_status_code(e),
+            Self::DefaultModelNotFound => None,
         }
     }
 }
@@ -141,6 +150,8 @@ impl ReasonCode for ApiClientError {
             Self::ModelOverloadedError { .. } => "ModelOverloadedError".to_string(),
             Self::MonthlyLimitReached { .. } => "MonthlyLimitReached".to_string(),
             Self::Credentials(_) => "CredentialsError".to_string(),
+            Self::ListAvailableModelsError(e) => sdk_error_code(e),
+            Self::DefaultModelNotFound => "DefaultModelNotFound".to_string(),
         }
     }
 }
@@ -188,6 +199,10 @@ mod tests {
                 ListAvailableCustomizationsError::unhandled(""),
                 response(),
             )),
+            ApiClientError::ListAvailableModelsError(SdkError::service_error(
+                ListAvailableModelsError::unhandled(""),
+                response(),
+            )),
             ApiClientError::ListAvailableServices(SdkError::service_error(
                 ListCustomizationsError::unhandled(""),
                 response(),
diff --git a/crates/chat-cli/src/api_client/mod.rs b/crates/chat-cli/src/api_client/mod.rs
index d76c1b6944..4baf757554 100644
--- a/crates/chat-cli/src/api_client/mod.rs
+++ b/crates/chat-cli/src/api_client/mod.rs
@@ -12,7 +12,9 @@ use std::time::Duration;
 
 use amzn_codewhisperer_client::Client as CodewhispererClient;
 use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenOutput;
+use amzn_codewhisperer_client::types::Origin::Cli;
 use amzn_codewhisperer_client::types::{
+    Model,
     OptOutPreference,
     SubscriptionStatus,
     TelemetryEvent,
@@ -32,6 +34,7 @@ pub use error::ApiClientError;
 use parking_lot::Mutex;
 pub use profile::list_available_profiles;
 use serde_json::Map;
+use tokio::sync::RwLock;
 use tracing::{
     debug,
     error,
@@ -66,6 +69,20 @@ pub const X_AMZN_CODEWHISPERER_OPT_OUT_HEADER: &str = "x-amzn-codewhisperer-optout";
 
 // TODO(bskiser): confirm timeout is updated to an appropriate value?
 const DEFAULT_TIMEOUT_DURATION: Duration = Duration::from_secs(60 * 5);
 
+#[derive(Clone, Debug)]
+pub struct ModelListResult {
+    pub models: Vec<Model>,
+    pub default_model: Model,
+}
+
+impl From<ModelListResult> for (Vec<Model>, Model) {
+    fn from(v: ModelListResult) -> Self {
+        (v.models, v.default_model)
+    }
+}
+
+type ModelCache = Arc<RwLock<Option<ModelListResult>>>;
+
 #[derive(Clone, Debug)]
 pub struct ApiClient {
     client: CodewhispererClient,
@@ -73,6 +90,7 @@ pub struct ApiClient {
     sigv4_streaming_client: Option<QDeveloperStreamingClient>,
     mock_client: Option<Arc<Mutex<std::vec::IntoIter<Vec<ChatResponseStream>>>>>,
     profile: Option<AuthProfile>,
+    model_cache: ModelCache,
 }
 
 impl ApiClient {
@@ -112,6 +130,7 @@ impl ApiClient {
             sigv4_streaming_client: None,
             mock_client: None,
             profile: None,
+            model_cache: Arc::new(RwLock::new(None)),
         };
 
         if let Ok(json) = env.get("Q_MOCK_CHAT_RESPONSE") {
@@ -181,6 +200,7 @@ impl ApiClient {
             sigv4_streaming_client,
             mock_client: None,
             profile,
+            model_cache: Arc::new(RwLock::new(None)),
         })
     }
 
@@ -234,6 +254,82 @@ impl ApiClient {
         Ok(profiles)
     }
 
+    pub async fn list_available_models(&self) -> Result<ModelListResult, ApiClientError> {
+        if cfg!(test) {
+            let m = Model::builder()
+                .model_id("model-1")
+                .description("Test Model 1")
+                .build()
+                .unwrap();
+
+            return Ok(ModelListResult {
+                models: vec![m.clone()],
+                default_model: m,
+            });
+        }
+
+        let mut models = Vec::new();
+        let mut default_model = None;
+        let request = self
+            .client
+            .list_available_models()
+            .set_origin(Some(Cli))
+            .set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone()));
+        let mut paginator = request.into_paginator().send();
+
+        while let Some(result) = paginator.next().await {
+            let models_output = result?;
+            models.extend(models_output.models().iter().cloned());
+
+            if default_model.is_none() {
+                default_model = Some(models_output.default_model().clone());
+            }
+        }
+        let default_model = default_model.ok_or_else(|| ApiClientError::DefaultModelNotFound)?;
+        Ok(ModelListResult { models, default_model })
+    }
+
+    pub async fn list_available_models_cached(&self) -> Result<ModelListResult, ApiClientError> {
+        {
+            let cache = self.model_cache.read().await;
+            if let Some(cached) = cache.as_ref() {
+                tracing::debug!("Returning cached model list");
+                return Ok(cached.clone());
+            }
+        }
+
+        tracing::debug!("Cache miss, fetching models from list_available_models API");
+        let result = self.list_available_models().await?;
+        {
+            let mut cache = self.model_cache.write().await;
+            *cache = Some(result.clone());
+        }
+        Ok(result)
+    }
+
+    pub async fn invalidate_model_cache(&self) {
+        let mut cache = self.model_cache.write().await;
+        *cache = None;
+        tracing::info!("Model cache invalidated");
+    }
+
+    pub async fn get_available_models(&self, _region: &str) -> Result<ModelListResult, ApiClientError> {
+        let res = self.list_available_models_cached().await?;
+        // TODO: Once we have access to gpt-oss, add back.
+        // if region == "us-east-1" {
+        //     let gpt_oss = Model::builder()
+        //         .model_id("OPENAI_GPT_OSS_120B_1_0")
+        //         .model_name("openai-gpt-oss-120b-preview")
+        //         .token_limits(TokenLimits::builder().max_input_tokens(128_000).build())
+        //         .build()
+        //         .map_err(ApiClientError::from)?;
+
+        //     models.push(gpt_oss);
+        // }
+
+        Ok(res)
+    }
+
     pub async fn create_subscription_token(&self) -> Result<CreateSubscriptionTokenOutput, ApiClientError> {
         if cfg!(test) {
             return Ok(CreateSubscriptionTokenOutput::builder()
diff --git a/crates/chat-cli/src/auth/builder_id.rs b/crates/chat-cli/src/auth/builder_id.rs
index 0154f1cb98..4261f7053a 100644
--- a/crates/chat-cli/src/auth/builder_id.rs
+++ b/crates/chat-cli/src/auth/builder_id.rs
@@ -480,6 +480,7 @@ impl BuilderIdToken {
     /// Check if the token is for the internal amzn start URL (`https://amzn.awsapps.com/start`),
     /// this implies the user will use midway for private specs
+    #[allow(dead_code)]
     pub fn is_amzn_user(&self) -> bool {
         matches!(&self.start_url, Some(url) if url == AMZN_START_URL)
     }
diff --git a/crates/chat-cli/src/cli/chat/cli/context.rs b/crates/chat-cli/src/cli/chat/cli/context.rs
index a1a115904e..df008330cf 100644
--- a/crates/chat-cli/src/cli/chat/cli/context.rs
+++ b/crates/chat-cli/src/cli/chat/cli/context.rs
@@ -222,7 +222,7 @@ impl ContextSubcommand {
                 execute!(session.stderr, style::Print(format!("{}\n\n", "▔".repeat(3))),)?;
             }
 
-            let context_files_max_size = calc_max_context_files_size(session.conversation.model.as_deref());
+            let context_files_max_size = calc_max_context_files_size(session.conversation.model_info.as_ref());
             let mut files_as_vec = profile_context_files
                 .iter()
                 .map(|(path, content, _)| (path.clone(), content.clone()))
diff --git a/crates/chat-cli/src/cli/chat/cli/model.rs b/crates/chat-cli/src/cli/chat/cli/model.rs
index e42e4eaf41..8d4dacd87d 100644
--- a/crates/chat-cli/src/cli/chat/cli/model.rs
+++ b/crates/chat-cli/src/cli/chat/cli/model.rs
@@ -1,3 +1,4 @@
+use amzn_codewhisperer_client::types::Model;
 use clap::Args;
 use crossterm::style::{
     self,
@@ -8,11 +9,12 @@ use crossterm::{
     queue,
 };
 use dialoguer::Select;
-
-use crate::auth::builder_id::{
-    BuilderIdToken,
-    TokenType,
+use serde::{
+    Deserialize,
+    Serialize,
 };
+
+use crate::api_client::Endpoint;
 use crate::cli::chat::{
     ChatError,
     ChatSession,
@@ -20,34 +22,44 @@ use crate::cli::chat::{
 };
 use crate::os::Os;
 
-pub struct ModelOption {
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ModelInfo {
     /// Display name
-    pub name: &'static str,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model_name: Option<String>,
     /// Actual model id to send in the API
-    pub model_id: &'static str,
+    pub model_id: String,
     /// Size of the model's context window, in tokens
+    #[serde(default = "default_context_window")]
     pub context_window_tokens: usize,
 }
 
-const MODEL_OPTIONS: [ModelOption; 2] = [
-    ModelOption {
-        name: "claude-4-sonnet",
-        model_id: "CLAUDE_SONNET_4_20250514_V1_0",
-        context_window_tokens: 200_000,
-    },
-    ModelOption {
-        name: "claude-3.7-sonnet",
-        model_id: "CLAUDE_3_7_SONNET_20250219_V1_0",
-        context_window_tokens: 200_000,
-    },
-];
-
-const GPT_OSS_120B: ModelOption = ModelOption {
-    name: "openai-gpt-oss-120b-preview",
-    model_id: "OPENAI_GPT_OSS_120B_1_0",
-    context_window_tokens: 128_000,
-};
+impl ModelInfo {
+    pub fn from_api_model(model: &Model) -> Self {
+        let context_window_tokens = model
+            .token_limits()
+            .and_then(|limits| limits.max_input_tokens())
+            .map_or(default_context_window(), |tokens| tokens as usize);
+        Self {
+            model_id: model.model_id().to_string(),
+            model_name: model.model_name().map(|s| s.to_string()),
+            context_window_tokens,
+        }
+    }
+
+    /// Create a default model with only a valid model_id (kept compatible with old stored model data)
+    pub fn from_id(model_id: String) -> Self {
+        Self {
+            model_id,
+            model_name: None,
+            context_window_tokens: 200_000,
+        }
+    }
+
+    pub fn display_name(&self) -> &str {
+        self.model_name.as_deref().unwrap_or(&self.model_id)
+    }
+}
 
 #[deny(missing_docs)]
 #[derive(Debug, PartialEq, Args)]
 pub struct ModelArgs;
@@ -62,16 +74,30 @@ impl ModelArgs {
 
 pub async fn select_model(os: &Os, session: &mut ChatSession) -> Result<Option<ChatState>, ChatError> {
     queue!(session.stderr, style::Print("\n"))?;
-    let active_model_id = session.conversation.model.as_deref();
-    let model_options = get_model_options(os).await?;
-    let labels: Vec<String> = model_options
+    // Fetch available models from the service
+    let (models, _default_model) = get_available_models(os).await?;
+
+    if models.is_empty() {
+        queue!(
+            session.stderr,
+            style::SetForegroundColor(Color::Red),
+            style::Print("No models available\n"),
+            style::ResetColor
+        )?;
+        return Ok(None);
+    }
+
+    let active_model_id = session.conversation.model_info.as_ref().map(|m| m.model_id.as_str());
+
+    let labels: Vec<String> = models
         .iter()
-        .map(|opt| {
-            if (opt.model_id.is_empty() && active_model_id.is_none()) || Some(opt.model_id) == active_model_id {
-                format!("{} (active)", opt.name)
+        .map(|model| {
+            let display_name = model.display_name();
+            if Some(model.model_id.as_str()) == active_model_id {
+                format!("{} (active)", display_name)
             } else {
-                opt.name.to_owned()
+                display_name.to_owned()
             }
         })
         .collect();
@@ -97,14 +123,14 @@ pub async fn select_model(os: &Os, session: &mut ChatSession) -> Result<Option<ChatState>, ChatError> {
-pub async fn default_model_id(os: &Os) -> &'static str {
-    // Check FRA region first
-    if let Ok(Some(profile)) = os.database.get_auth_profile() {
-        if profile.arn.split(':').nth(3) == Some("eu-central-1") {
-            return "CLAUDE_3_7_SONNET_20250219_V1_0";
-        }
-    }
-
-    // Check if Amazon IDC user
-    if let Ok(Some(token)) = BuilderIdToken::load(&os.database).await {
-        if matches!(token.token_type(), TokenType::IamIdentityCenter) && token.is_amzn_user() {
-            return "CLAUDE_3_7_SONNET_20250219_V1_0";
-        }
-    }
-
-    // Default to 4.0
-    "CLAUDE_SONNET_4_20250514_V1_0"
+pub async fn get_model_info(model_id: &str, os: &Os) -> Result<ModelInfo, ChatError> {
+    let (models, _) = get_available_models(os).await?;
+
+    models
+        .into_iter()
+        .find(|m| m.model_id == model_id)
+        .ok_or_else(|| ChatError::Custom(format!("Model '{}' not found", model_id).into()))
 }
 
-/// Returns the available models for use.
-#[allow(unused_variables)]
-pub async fn get_model_options(os: &Os) -> Result<Vec<ModelOption>, ChatError> {
-    Ok(MODEL_OPTIONS.into_iter().collect::<Vec<_>>())
-    // TODO: Once we have access to gpt-oss, add back.
-    // let mut model_options = MODEL_OPTIONS.into_iter().collect::<Vec<_>>();
-    //
-    // // GPT OSS is only accessible in IAD.
-    // let endpoint = Endpoint::configured_value(&os.database);
-    // if endpoint.region().as_ref() != "us-east-1" {
-    //     return Ok(model_options);
-    // }
-    //
-    // model_options.push(GPT_OSS_120B);
-    // Ok(model_options)
+/// Get available models with caching support
+pub async fn get_available_models(os: &Os) -> Result<(Vec<ModelInfo>, ModelInfo), ChatError> {
+    let endpoint = Endpoint::configured_value(&os.database);
+    let region = endpoint.region().as_ref();
+
+    match os.client.get_available_models(region).await {
+        Ok(api_res) => {
+            let models: Vec<ModelInfo> = api_res.models.iter().map(ModelInfo::from_api_model).collect();
+            let default_model = ModelInfo::from_api_model(&api_res.default_model);
+
+            tracing::debug!("Successfully fetched {} models from API", models.len());
+            Ok((models, default_model))
+        },
+        // In case of API throttling or other errors, fall back to hardcoded models
+        Err(e) => {
+            tracing::error!("Failed to fetch models from API: {}, using fallback list", e);
+
+            let models = get_fallback_models();
+            let default_model = models[0].clone();
+
+            Ok((models, default_model))
+        },
+    }
 }
 
 /// Returns the context window length in tokens for the given model_id.
-pub fn context_window_tokens(model_id: Option<&str>) -> usize {
-    const DEFAULT_CONTEXT_WINDOW_LENGTH: usize = 200_000;
+/// Uses cached model data when available
+pub fn context_window_tokens(model_info: Option<&ModelInfo>) -> usize {
+    model_info.map_or_else(default_context_window, |m| m.context_window_tokens)
+}
 
-    let Some(model_id) = model_id else {
-        return DEFAULT_CONTEXT_WINDOW_LENGTH;
-    };
+fn default_context_window() -> usize {
+    200_000
+}
 
-    MODEL_OPTIONS
-        .iter()
-        .chain(std::iter::once(&GPT_OSS_120B))
-        .find(|m| m.model_id == model_id)
-        .map_or(DEFAULT_CONTEXT_WINDOW_LENGTH, |m| m.context_window_tokens)
+fn get_fallback_models() -> Vec<ModelInfo> {
+    vec![
+        ModelInfo {
+            model_name: Some("claude-3.7-sonnet".to_string()),
+            model_id: "claude-3.7-sonnet".to_string(),
+            context_window_tokens: 200_000,
+        },
+        ModelInfo {
+            model_name: Some("claude-4-sonnet".to_string()),
+            model_id: "claude-4-sonnet".to_string(),
+            context_window_tokens: 200_000,
+        },
+    ]
 }
diff --git a/crates/chat-cli/src/cli/chat/cli/usage.rs b/crates/chat-cli/src/cli/chat/cli/usage.rs
index 6d62ffb4ec..eca538e2b6 100644
--- a/crates/chat-cli/src/cli/chat/cli/usage.rs
+++ b/crates/chat-cli/src/cli/chat/cli/usage.rs
@@ -62,8 +62,7 @@ impl UsageArgs {
         // set a max width for the progress bar for better aesthetic
         let progress_bar_width = std::cmp::min(window_width, 80);
 
-        let context_window_size = context_window_tokens(session.conversation.model.as_deref());
-
+        let context_window_size = context_window_tokens(session.conversation.model_info.as_ref());
         let context_width =
             ((context_token_count.value() as f64 / context_window_size as f64) * progress_bar_width as f64) as usize;
         let assistant_width =
diff --git a/crates/chat-cli/src/cli/chat/context.rs b/crates/chat-cli/src/cli/chat/context.rs
index 993aa3cca9..1fdcc5e8ee 100644
--- a/crates/chat-cli/src/cli/chat/context.rs
+++ b/crates/chat-cli/src/cli/chat/context.rs
@@ -23,6 +23,7 @@ use crate::cli::agent::hook::{
 };
 use crate::cli::chat::ChatError;
 use crate::cli::chat::cli::hooks::HookExecutor;
+use crate::cli::chat::cli::model::ModelInfo;
 use crate::os::Os;
 
 #[derive(Debug, Clone)]
@@ -255,9 +256,9 @@ impl ContextManager {
     }
 
 /// Calculates the maximum context files size to use for the given model id.
-pub fn calc_max_context_files_size(model_id: Option<&str>) -> usize {
+pub fn calc_max_context_files_size(model: Option<&ModelInfo>) -> usize {
     // Sets the max as 75% of the context window
-    context_window_tokens(model_id).saturating_mul(3) / 4
+    context_window_tokens(model).saturating_mul(3) / 4
 }
 
 /// Process a path, handling glob patterns and file types.
@@ -434,9 +435,20 @@ mod tests {
     #[test]
     fn test_calc_max_context_files_size() {
         assert_eq!(
-            calc_max_context_files_size(Some("CLAUDE_SONNET_4_20250514_V1_0")),
+            calc_max_context_files_size(Some(&ModelInfo {
+                model_id: "CLAUDE_SONNET_4_20250514_V1_0".to_string(),
+                model_name: Some("Claude".to_string()),
+                context_window_tokens: 200_000,
+            })),
             150_000
         );
-        assert_eq!(calc_max_context_files_size(Some("OPENAI_GPT_OSS_120B_1_0")), 96_000);
+        assert_eq!(
+            calc_max_context_files_size(Some(&ModelInfo {
+                model_id: "OPENAI_GPT_OSS_120B_1_0".to_string(),
+                model_name: Some("GPT".to_string()),
+                context_window_tokens: 128_000,
+            })),
+            96_000
+        );
     }
 }
diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs
index 1c3a0f8e7b..ca7b87d2c4 100644
--- a/crates/chat-cli/src/cli/chat/conversation.rs
+++ b/crates/chat-cli/src/cli/chat/conversation.rs
@@ -66,6 +66,10 @@ use crate::cli::agent::hook::{
     HookTrigger,
 };
 use crate::cli::chat::ChatError;
+use crate::cli::chat::cli::model::{
+    ModelInfo,
+    get_model_info,
+};
 use crate::mcp_client::Prompt;
 use crate::os::Os;
 
@@ -108,9 +112,13 @@ pub struct ConversationState {
     latest_summary: Option<(String, RequestMetadata)>,
     #[serde(skip)]
     pub agents: Agents,
+    /// Unused, kept only to maintain deserialization backwards compatibility with <=v1.13.3
     /// Model explicitly selected by the user in this conversation state via `/model`.
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub model: Option<String>,
+    /// Model explicitly selected by the user in this conversation state via `/model`.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub model_info: Option<ModelInfo>,
     /// Used to track agent vs user updates to file modifications.
     ///
     /// Maps from a file path to [FileLineTracker]
@@ -125,9 +133,22 @@ impl ConversationState {
         tool_config: HashMap<String, ToolSpec>,
         tool_manager: ToolManager,
         current_model_id: Option<String>,
+        os: &Os,
     ) -> Self {
+        let model = if let Some(model_id) = current_model_id {
+            match get_model_info(&model_id, os).await {
+                Ok(info) => Some(info),
+                Err(e) => {
+                    tracing::warn!("Failed to get model info for {}: {}, using default", model_id, e);
+                    Some(ModelInfo::from_id(model_id))
+                },
+            }
+        } else {
+            None
+        };
+
         let context_manager = if let Some(agent) = agents.get_active() {
-            ContextManager::from_agent(agent, calc_max_context_files_size(current_model_id.as_deref())).ok()
+            ContextManager::from_agent(agent, calc_max_context_files_size(model.as_ref())).ok()
         } else {
             None
         };
@@ -156,7 +177,8 @@ impl ConversationState {
             context_message_length: None,
             latest_summary: None,
             agents,
-            model: current_model_id,
+            model: None,
+            model_info: model,
             file_line_tracker: HashMap::new(),
         }
     }
@@ -440,7 +462,7 @@ impl ConversationState {
             context_messages,
             dropped_context_files,
             tools: &self.tools,
-            model_id: self.model.as_deref(),
+            model_id: self.model_info.as_ref().map(|m| m.model_id.as_str()),
         })
     }
 
@@ -538,7 +560,7 @@ impl ConversationState {
             conversation_id: Some(self.conversation_id.clone()),
             user_input_message: summary_message
                 .unwrap_or(UserMessage::new_prompt(summary_content, None)) // should not happen
-                .into_user_input_message(self.model.clone(), &tools),
+                .into_user_input_message(self.model_info.as_ref().map(|m| m.model_id.clone()), &tools),
             history: Some(flatten_history(history.iter())),
         })
     }
 
@@ -651,7 +673,7 @@ impl ConversationState {
     /// Get the current token warning level
     pub async fn get_token_warning_level(&mut self, os: &Os) -> Result<TokenWarningLevel, ChatError> {
         let total_chars = self.calculate_char_count(os).await?;
-        let max_chars = TokenCounter::token_to_chars(context_window_tokens(self.model.as_deref()));
+        let max_chars = TokenCounter::token_to_chars(context_window_tokens(self.model_info.as_ref()));
 
         Ok(if *total_chars >= max_chars {
             TokenWarningLevel::Critical
@@ -1076,6 +1098,7 @@ mod tests {
             tool_manager.load_tools(&mut os, &mut output).await.unwrap(),
             tool_manager,
             None,
+            &os,
         )
         .await;
 
@@ -1107,6 +1130,7 @@ mod tests {
             tool_config.clone(),
             tool_manager.clone(),
             None,
+            &os,
         )
         .await;
         conversation.set_next_user_message("start".to_string()).await;
@@ -1135,8 +1159,15 @@ mod tests {
         }
 
         // Build a long conversation history of user messages mixed in with tool results.
-        let mut conversation =
-            ConversationState::new("fake_conv_id", agents, tool_config.clone(), tool_manager.clone(), None).await;
+        let mut conversation = ConversationState::new(
+            "fake_conv_id",
+            agents,
+            tool_config.clone(),
+            tool_manager.clone(),
+            None,
+            &os,
+        )
+        .await;
         conversation.set_next_user_message("start".to_string()).await;
         for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) {
             let s = conversation
@@ -1188,6 +1219,7 @@ mod tests {
             tool_manager.load_tools(&mut os, &mut output).await.unwrap(),
             tool_manager,
             None,
+            &os,
         )
         .await;
diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs
index 2ead9b90d5..0941df73a3 100644
--- a/crates/chat-cli/src/cli/chat/mod.rs
+++ b/crates/chat-cli/src/cli/chat/mod.rs
@@ -18,7 +18,6 @@ mod token_counter;
 pub mod tool_manager;
 pub mod tools;
 pub mod util;
-
 use std::borrow::Cow;
 use std::collections::{
     HashMap,
@@ -44,7 +43,7 @@ use clap::{
 };
 use cli::compact::CompactStrategy;
 use cli::model::{
-    get_model_options,
+    get_available_models,
     select_model,
 };
 pub use conversation::ConversationState;
@@ -134,7 +133,6 @@ use crate::auth::AuthError;
 use crate::auth::builder_id::is_idc_user;
 use crate::cli::agent::Agents;
 use crate::cli::chat::cli::SlashCommand;
-use crate::cli::chat::cli::model::default_model_id;
 use crate::cli::chat::cli::prompts::{
     GetPromptError,
     PromptsSubcommand,
@@ -314,22 +312,36 @@ impl ChatArgs {
         };
 
         // If modelId is specified, verify it exists before starting the chat
-        let model_options = get_model_options(os).await?;
-        let model_id: Option<String> = if let Some(model_name) = self.model {
-            let model_name_lower = model_name.to_lowercase();
-            match model_options.iter().find(|opt| opt.name == model_name_lower) {
-                Some(opt) => Some((opt.model_id).to_string()),
-                None => {
-                    let available_names: Vec<&str> = model_options.iter().map(|opt| opt.name).collect();
-                    bail!(
-                        "Model '{}' does not exist. Available models: {}",
-                        model_name,
-                        available_names.join(", ")
-                    );
-                },
+        // Otherwise, the CLI will use a default model when starting chat
+        let (models, default_model_opt) = get_available_models(os).await?;
+        let model_id: Option<String> = if let Some(requested) = self.model.as_ref() {
+            let requested_lower = requested.to_lowercase();
+            if let Some(m) = models.iter().find(|m| {
+                m.model_name
+                    .as_deref()
+                    .is_some_and(|n| n.eq_ignore_ascii_case(&requested_lower))
+                    || m.model_id.eq_ignore_ascii_case(&requested_lower)
+            }) {
+                Some(m.model_id.clone())
+            } else {
+                let available = models
+                    .iter()
+                    .map(|m| m.model_name.as_deref().unwrap_or(&m.model_id))
+                    .collect::<Vec<_>>()
+                    .join(", ");
+                bail!("Model '{}' does not exist. Available models: {}", requested, available);
Available models: {}", requested, available); + } + } else if let Some(saved) = os.database.settings.get_string(Setting::ChatDefaultModel) { + if let Some(m) = models.iter().find(|m| { + m.model_name.as_deref().is_some_and(|n| n.eq_ignore_ascii_case(&saved)) + || m.model_id.eq_ignore_ascii_case(&saved) + }) { + Some(m.model_id.clone()) + } else { + Some(default_model_opt.model_id.clone()) } } else { - None + Some(default_model_opt.model_id.clone()) }; let (prompt_request_sender, prompt_request_receiver) = std::sync::mpsc::channel::>(); @@ -579,28 +591,6 @@ impl ChatSession { tool_config: HashMap, interactive: bool, ) -> Result { - let model_options = get_model_options(os).await?; - let valid_model_id = match model_id { - Some(id) => id, - None => { - let from_settings = os - .database - .settings - .get_string(Setting::ChatDefaultModel) - .and_then(|model_name| { - model_options - .iter() - .find(|opt| opt.name == model_name) - .map(|opt| opt.model_id.to_owned()) - }); - - match from_settings { - Some(id) => id, - None => default_model_id(os).await.to_owned(), - } - }, - }; - // Reload prior conversation let mut existing_conversation = false; let previous_conversation = std::env::current_dir() @@ -639,9 +629,7 @@ impl ChatSession { cs.enforce_tool_use_history_invariants(); cs }, - false => { - ConversationState::new(conversation_id, agents, tool_config, tool_manager, Some(valid_model_id)).await - }, + false => ConversationState::new(conversation_id, agents, tool_config, tool_manager, model_id, os).await, }; // Spawn a task for listening and broadcasting sigints. @@ -1196,13 +1184,14 @@ impl ChatSession { } self.stderr.flush()?; - if let Some(ref id) = self.conversation.model { - let model_options = get_model_options(os).await?; - if let Some(model_option) = model_options.iter().find(|option| option.model_id == *id) { + if let Some(ref model_info) = self.conversation.model_info { + let (models, _default_model) = get_available_models(os).await?; + if let Some(model_option) = models.iter().find(|option| option.model_id == model_info.model_id) { + let display_name = model_option.model_name.as_deref().unwrap_or(&model_option.model_id); execute!( self.stderr, style::SetForegroundColor(Color::Cyan), - style::Print(format!("🤖 You are chatting with {}\n", model_option.name)), + style::Print(format!("🤖 You are chatting with {}\n", display_name)), style::SetForegroundColor(Color::Reset), style::Print("\n") )?; @@ -2385,11 +2374,14 @@ impl ChatSession { for tool_use in tool_uses { let tool_use_id = tool_use.id.clone(); let tool_use_name = tool_use.name.clone(); - let mut tool_telemetry = - ToolUseEventBuilder::new(conv_id.clone(), tool_use.id.clone(), self.conversation.model.clone()) - .set_tool_use_id(tool_use_id.clone()) - .set_tool_name(tool_use.name.clone()) - .utterance_id(self.conversation.message_id().map(|s| s.to_string())); + let mut tool_telemetry = ToolUseEventBuilder::new( + conv_id.clone(), + tool_use.id.clone(), + self.conversation.model_info.as_ref().map(|m| m.model_id.clone()), + ) + .set_tool_use_id(tool_use_id.clone()) + .set_tool_name(tool_use.name.clone()) + .utterance_id(self.conversation.message_id().map(|s| s.to_string())); match self.conversation.tool_manager.get_tool_from_tool_use(tool_use) { Ok(mut tool) => { // Apply non-Q-generated context to tools @@ -2481,6 +2473,7 @@ impl ChatSession { } async fn retry_model_overload(&mut self, os: &mut Os) -> Result { + os.client.invalidate_model_cache().await; match select_model(os, self).await { Ok(Some(_)) => (), Ok(None) => {