diff --git a/crates/chat-cli/src/api_client/error.rs b/crates/chat-cli/src/api_client/error.rs
index 37420fb72e..88f4a1f70d 100644
--- a/crates/chat-cli/src/api_client/error.rs
+++ b/crates/chat-cli/src/api_client/error.rs
@@ -1,7 +1,6 @@
 use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenError;
 use amzn_codewhisperer_client::operation::generate_completions::GenerateCompletionsError;
 use amzn_codewhisperer_client::operation::list_available_customizations::ListAvailableCustomizationsError;
-use amzn_codewhisperer_client::operation::list_available_models::ListAvailableModelsError;
 use amzn_codewhisperer_client::operation::list_available_profiles::ListAvailableProfilesError;
 use amzn_codewhisperer_client::operation::send_telemetry_event::SendTelemetryEventError;
 pub use amzn_codewhisperer_streaming_client::operation::generate_assistant_response::GenerateAssistantResponseError;
@@ -94,12 +93,6 @@ pub enum ApiClientError {
     // Credential errors
     #[error("failed to load credentials: {}", .0)]
     Credentials(CredentialsError),
-
-    #[error(transparent)]
-    ListAvailableModelsError(#[from] SdkError<ListAvailableModelsError>),
-
-    #[error("No default model found in the ListAvailableModels API response")]
-    DefaultModelNotFound,
 }
 
 impl ApiClientError {
@@ -123,8 +116,6 @@ impl ApiClientError {
             Self::ModelOverloadedError { status_code, .. } => *status_code,
             Self::MonthlyLimitReached { status_code } => *status_code,
             Self::Credentials(_e) => None,
-            Self::ListAvailableModelsError(e) => sdk_status_code(e),
-            Self::DefaultModelNotFound => None,
         }
     }
 }
@@ -150,8 +141,6 @@ impl ReasonCode for ApiClientError {
             Self::ModelOverloadedError { .. } => "ModelOverloadedError".to_string(),
             Self::MonthlyLimitReached { .. } => "MonthlyLimitReached".to_string(),
             Self::Credentials(_) => "CredentialsError".to_string(),
-            Self::ListAvailableModelsError(e) => sdk_error_code(e),
-            Self::DefaultModelNotFound => "DefaultModelNotFound".to_string(),
         }
     }
 }
@@ -199,10 +188,6 @@ mod tests {
                 ListAvailableCustomizationsError::unhandled(""),
                 response(),
             )),
-            ApiClientError::ListAvailableModelsError(SdkError::service_error(
-                ListAvailableModelsError::unhandled(""),
-                response(),
-            )),
             ApiClientError::ListAvailableServices(SdkError::service_error(
                 ListCustomizationsError::unhandled(""),
                 response(),
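
Note for reviewers unfamiliar with the thiserror idiom the deleted variants used: #[error(transparent)] delegates Display and source() to the wrapped error, and #[from] is what let call sites bubble the SDK failure up with ?. A minimal self-contained sketch of the same shape (illustrative names only; io::Error stands in for SdkError<ListAvailableModelsError>):

use thiserror::Error;

#[derive(Debug, Error)]
enum DemoError {
    // `transparent` forwards Display/source to the inner error;
    // `#[from]` generates the From impl that `?` relies on.
    #[error(transparent)]
    Sdk(#[from] std::io::Error), // stand-in for the real SdkError type
    #[error("No default model found in the ListAvailableModels API response")]
    DefaultModelNotFound,
}

fn list_models() -> Result<(), DemoError> {
    // `?` converts io::Error into DemoError::Sdk automatically.
    Err(std::io::Error::other("throttled"))?;
    Ok(())
}

fn main() {
    println!("{}", list_models().unwrap_err()); // prints "throttled"
    println!("{}", DemoError::DefaultModelNotFound);
}
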
diff --git a/crates/chat-cli/src/api_client/mod.rs b/crates/chat-cli/src/api_client/mod.rs
index 4baf757554..d76c1b6944 100644
--- a/crates/chat-cli/src/api_client/mod.rs
+++ b/crates/chat-cli/src/api_client/mod.rs
@@ -12,9 +12,7 @@ use std::time::Duration;
 
 use amzn_codewhisperer_client::Client as CodewhispererClient;
 use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenOutput;
-use amzn_codewhisperer_client::types::Origin::Cli;
 use amzn_codewhisperer_client::types::{
-    Model,
     OptOutPreference,
     SubscriptionStatus,
     TelemetryEvent,
@@ -34,7 +32,6 @@ pub use error::ApiClientError;
 use parking_lot::Mutex;
 pub use profile::list_available_profiles;
 use serde_json::Map;
-use tokio::sync::RwLock;
 use tracing::{
     debug,
     error,
@@ -69,20 +66,6 @@ pub const X_AMZN_CODEWHISPERER_OPT_OUT_HEADER: &str = "x-amzn-codewhisperer-optout";
 // TODO(bskiser): confirm timeout is updated to an appropriate value?
 const DEFAULT_TIMEOUT_DURATION: Duration = Duration::from_secs(60 * 5);
 
-#[derive(Clone, Debug)]
-pub struct ModelListResult {
-    pub models: Vec<Model>,
-    pub default_model: Model,
-}
-
-impl From<ModelListResult> for (Vec<Model>, Model) {
-    fn from(v: ModelListResult) -> Self {
-        (v.models, v.default_model)
-    }
-}
-
-type ModelCache = Arc<RwLock<Option<ModelListResult>>>;
-
 #[derive(Clone, Debug)]
 pub struct ApiClient {
     client: CodewhispererClient,
@@ -90,7 +73,6 @@ pub struct ApiClient {
     sigv4_streaming_client: Option<StreamingClient>,
     mock_client: Option<Arc<Mutex<Vec<Vec<ChatResponseStream>>>>>,
     profile: Option<AuthProfile>,
-    model_cache: ModelCache,
 }
 
 impl ApiClient {
@@ -130,7 +112,6 @@ impl ApiClient {
             sigv4_streaming_client: None,
             mock_client: None,
             profile: None,
-            model_cache: Arc::new(RwLock::new(None)),
         };
 
         if let Ok(json) = env.get("Q_MOCK_CHAT_RESPONSE") {
@@ -200,7 +181,6 @@ impl ApiClient {
             sigv4_streaming_client,
             mock_client: None,
             profile,
-            model_cache: Arc::new(RwLock::new(None)),
         })
     }
 
@@ -254,82 +234,6 @@ impl ApiClient {
         Ok(profiles)
     }
 
-    pub async fn list_available_models(&self) -> Result<ModelListResult, ApiClientError> {
-        if cfg!(test) {
-            let m = Model::builder()
-                .model_id("model-1")
-                .description("Test Model 1")
-                .build()
-                .unwrap();
-
-            return Ok(ModelListResult {
-                models: vec![m.clone()],
-                default_model: m,
-            });
-        }
-
-        let mut models = Vec::new();
-        let mut default_model = None;
-        let request = self
-            .client
-            .list_available_models()
-            .set_origin(Some(Cli))
-            .set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone()));
-        let mut paginator = request.into_paginator().send();
-
-        while let Some(result) = paginator.next().await {
-            let models_output = result?;
-            models.extend(models_output.models().iter().cloned());
-
-            if default_model.is_none() {
-                default_model = Some(models_output.default_model().clone());
-            }
-        }
-        let default_model = default_model.ok_or_else(|| ApiClientError::DefaultModelNotFound)?;
-        Ok(ModelListResult { models, default_model })
-    }
-
-    pub async fn list_available_models_cached(&self) -> Result<ModelListResult, ApiClientError> {
-        {
-            let cache = self.model_cache.read().await;
-            if let Some(cached) = cache.as_ref() {
-                tracing::debug!("Returning cached model list");
-                return Ok(cached.clone());
-            }
-        }
-
-        tracing::debug!("Cache miss, fetching models from list_available_models API");
-        let result = self.list_available_models().await?;
-        {
-            let mut cache = self.model_cache.write().await;
-            *cache = Some(result.clone());
-        }
-        Ok(result)
-    }
-
-    pub async fn invalidate_model_cache(&self) {
-        let mut cache = self.model_cache.write().await;
-        *cache = None;
-        tracing::info!("Model cache invalidated");
-    }
-
-    pub async fn get_available_models(&self, _region: &str) -> Result<ModelListResult, ApiClientError> {
-        let res = self.list_available_models_cached().await?;
-        // TODO: Once we have access to gpt-oss, add back.
-        // if region == "us-east-1" {
-        //     let gpt_oss = Model::builder()
-        //         .model_id("OPENAI_GPT_OSS_120B_1_0")
-        //         .model_name("openai-gpt-oss-120b-preview")
-        //         .token_limits(TokenLimits::builder().max_input_tokens(128_000).build())
-        //         .build()
-        //         .map_err(ApiClientError::from)?;
-
-        //     models.push(gpt_oss);
-        // }
-
-        Ok(res)
-    }
-
     pub async fn create_subscription_token(&self) -> Result<CreateSubscriptionTokenOutput, ApiClientError> {
         if cfg!(test) {
             return Ok(CreateSubscriptionTokenOutput::builder()
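
The deleted cache followed the usual tokio RwLock shape: take a shared read lock for the fast path, drop it, and take the exclusive write lock only to store a freshly fetched value. A minimal standalone sketch of that pattern, with a String standing in for ModelListResult (illustrative types, not the crate's API):

use std::sync::Arc;
use tokio::sync::RwLock;

struct Cached {
    slot: Arc<RwLock<Option<String>>>,
}

impl Cached {
    async fn get_or_fetch(&self) -> String {
        // Fast path: shared read lock, released at the end of the `if let`.
        if let Some(v) = self.slot.read().await.as_ref() {
            return v.clone();
        }
        // Slow path: fetch first, then hold the write lock only to store.
        // Two tasks that miss concurrently may both fetch; the second write
        // simply overwrites the first, which is harmless for this use case.
        let fetched = String::from("expensive result");
        *self.slot.write().await = Some(fetched.clone());
        fetched
    }
}

#[tokio::main]
async fn main() {
    let cache = Cached { slot: Arc::new(RwLock::new(None)) };
    assert_eq!(cache.get_or_fetch().await, "expensive result");
    assert_eq!(cache.get_or_fetch().await, "expensive result"); // served from cache
}
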
diff --git a/crates/chat-cli/src/auth/builder_id.rs b/crates/chat-cli/src/auth/builder_id.rs
index 4261f7053a..0154f1cb98 100644
--- a/crates/chat-cli/src/auth/builder_id.rs
+++ b/crates/chat-cli/src/auth/builder_id.rs
@@ -480,7 +480,6 @@ impl BuilderIdToken {
     /// Check if the token is for the internal amzn start URL (`https://amzn.awsapps.com/start`),
     /// this implies the user will use midway for private specs
-    #[allow(dead_code)]
     pub fn is_amzn_user(&self) -> bool {
         matches!(&self.start_url, Some(url) if url == AMZN_START_URL)
     }
diff --git a/crates/chat-cli/src/cli/chat/cli/context.rs b/crates/chat-cli/src/cli/chat/cli/context.rs
index df008330cf..a1a115904e 100644
--- a/crates/chat-cli/src/cli/chat/cli/context.rs
+++ b/crates/chat-cli/src/cli/chat/cli/context.rs
@@ -222,7 +222,7 @@ impl ContextSubcommand {
                 execute!(session.stderr, style::Print(format!("{}\n\n", "▔".repeat(3))),)?;
             }
 
-            let context_files_max_size = calc_max_context_files_size(session.conversation.model_info.as_ref());
+            let context_files_max_size = calc_max_context_files_size(session.conversation.model.as_deref());
             let mut files_as_vec = profile_context_files
                 .iter()
                 .map(|(path, content, _)| (path.clone(), content.clone()))
diff --git a/crates/chat-cli/src/cli/chat/cli/model.rs b/crates/chat-cli/src/cli/chat/cli/model.rs
index 8d4dacd87d..e42e4eaf41 100644
--- a/crates/chat-cli/src/cli/chat/cli/model.rs
+++ b/crates/chat-cli/src/cli/chat/cli/model.rs
@@ -1,4 +1,3 @@
-use amzn_codewhisperer_client::types::Model;
 use clap::Args;
 use crossterm::style::{
     self,
@@ -9,12 +8,11 @@ use crossterm::{
     queue,
 };
 use dialoguer::Select;
-use serde::{
-    Deserialize,
-    Serialize,
-};
 
-use crate::api_client::Endpoint;
+use crate::auth::builder_id::{
+    BuilderIdToken,
+    TokenType,
+};
 use crate::cli::chat::{
     ChatError,
     ChatSession,
@@ -22,44 +20,34 @@ use crate::cli::chat::{
 };
 use crate::os::Os;
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct ModelInfo {
+pub struct ModelOption {
     /// Display name
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub model_name: Option<String>,
+    pub name: &'static str,
     /// Actual model id to send in the API
-    pub model_id: String,
+    pub model_id: &'static str,
     /// Size of the model's context window, in tokens
-    #[serde(default = "default_context_window")]
     pub context_window_tokens: usize,
 }
 
-impl ModelInfo {
-    pub fn from_api_model(model: &Model) -> Self {
-        let context_window_tokens = model
-            .token_limits()
-            .and_then(|limits| limits.max_input_tokens())
-            .map_or(default_context_window(), |tokens| tokens as usize);
-        Self {
-            model_id: model.model_id().to_string(),
-            model_name: model.model_name().map(|s| s.to_string()),
-            context_window_tokens,
-        }
-    }
-
-    /// Create a default model with only a valid model_id (to stay compatible with old stored model data).
-    pub fn from_id(model_id: String) -> Self {
-        Self {
-            model_id,
-            model_name: None,
-            context_window_tokens: 200_000,
-        }
-    }
+const MODEL_OPTIONS: [ModelOption; 2] = [
+    ModelOption {
+        name: "claude-4-sonnet",
+        model_id: "CLAUDE_SONNET_4_20250514_V1_0",
+        context_window_tokens: 200_000,
+    },
+    ModelOption {
+        name: "claude-3.7-sonnet",
+        model_id: "CLAUDE_3_7_SONNET_20250219_V1_0",
+        context_window_tokens: 200_000,
+    },
+];
+
+const GPT_OSS_120B: ModelOption = ModelOption {
+    name: "openai-gpt-oss-120b-preview",
+    model_id: "OPENAI_GPT_OSS_120B_1_0",
+    context_window_tokens: 128_000,
+};
 
-    pub fn display_name(&self) -> &str {
-        self.model_name.as_deref().unwrap_or(&self.model_id)
-    }
-}
 #[deny(missing_docs)]
 #[derive(Debug, PartialEq, Args)]
 pub struct ModelArgs;
@@ -74,30 +62,16 @@ impl ModelArgs {
 pub async fn select_model(os: &Os, session: &mut ChatSession) -> Result<Option<ChatState>, ChatError> {
     queue!(session.stderr, style::Print("\n"))?;
+    let active_model_id = session.conversation.model.as_deref();
+    let model_options = get_model_options(os).await?;
 
-    // Fetch available models from service
-    let (models, _default_model) = get_available_models(os).await?;
-
-    if models.is_empty() {
-        queue!(
-            session.stderr,
-            style::SetForegroundColor(Color::Red),
-            style::Print("No models available\n"),
-            style::ResetColor
-        )?;
-        return Ok(None);
-    }
-
-    let active_model_id = session.conversation.model_info.as_ref().map(|m| m.model_id.as_str());
-
-    let labels: Vec<String> = models
+    let labels: Vec<String> = model_options
         .iter()
-        .map(|model| {
-            let display_name = model.display_name();
-            if Some(model.model_id.as_str()) == active_model_id {
-                format!("{} (active)", display_name)
+        .map(|opt| {
+            if (opt.model_id.is_empty() && active_model_id.is_none()) || Some(opt.model_id) == active_model_id {
+                format!("{} (active)", opt.name)
             } else {
-                display_name.to_owned()
+                opt.name.to_owned()
             }
         })
        .collect();
@@ -123,14 +97,14 @@ pub async fn select_model(os: &Os, session: &mut ChatSession) -> Result<Option<ChatState>, ChatError> {
 
-pub async fn get_model_info(model_id: &str, os: &Os) -> Result<ModelInfo, ChatError> {
-    let (models, _) = get_available_models(os).await?;
-
-    models
-        .into_iter()
-        .find(|m| m.model_id == model_id)
-        .ok_or_else(|| ChatError::Custom(format!("Model '{}' not found", model_id).into()))
-}
-
-/// Get available models with caching support
-pub async fn get_available_models(os: &Os) -> Result<(Vec<ModelInfo>, ModelInfo), ChatError> {
-    let endpoint = Endpoint::configured_value(&os.database);
-    let region = endpoint.region().as_ref();
-
-    match os.client.get_available_models(region).await {
-        Ok(api_res) => {
-            let models: Vec<ModelInfo> = api_res.models.iter().map(ModelInfo::from_api_model).collect();
-            let default_model = ModelInfo::from_api_model(&api_res.default_model);
+/// Returns a default model id to use if none has been otherwise provided.
+///
+/// Returns Claude 3.7 for: Amazon IDC users, FRA region users
+/// Returns Claude 4.0 for: Builder ID users, other regions
+pub async fn default_model_id(os: &Os) -> &'static str {
+    // Check FRA region first
+    if let Ok(Some(profile)) = os.database.get_auth_profile() {
+        if profile.arn.split(':').nth(3) == Some("eu-central-1") {
+            return "CLAUDE_3_7_SONNET_20250219_V1_0";
+        }
+    }
 
-            tracing::debug!("Successfully fetched {} models from API", models.len());
-            Ok((models, default_model))
-        },
-        // In case of API throttling or other errors, fall back to hardcoded models
-        Err(e) => {
-            tracing::error!("Failed to fetch models from API: {}, using fallback list", e);
+    // Check if Amazon IDC user
+    if let Ok(Some(token)) = BuilderIdToken::load(&os.database).await {
+        if matches!(token.token_type(), TokenType::IamIdentityCenter) && token.is_amzn_user() {
+            return "CLAUDE_3_7_SONNET_20250219_V1_0";
+        }
+    }
 
-            let models = get_fallback_models();
-            let default_model = models[0].clone();
+    // Default to 4.0
+    "CLAUDE_SONNET_4_20250514_V1_0"
+}
 
-            Ok((models, default_model))
-        },
-    }
+/// Returns the available models for use.
+#[allow(unused_variables)]
+pub async fn get_model_options(os: &Os) -> Result<Vec<ModelOption>, ChatError> {
+    Ok(MODEL_OPTIONS.into_iter().collect::<Vec<_>>())
+    // TODO: Once we have access to gpt-oss, add back.
+    // let mut model_options = MODEL_OPTIONS.into_iter().collect::<Vec<_>>();
+    //
+    // // GPT OSS is only accessible in IAD.
+    // let endpoint = Endpoint::configured_value(&os.database);
+    // if endpoint.region().as_ref() != "us-east-1" {
+    //     return Ok(model_options);
+    // }
+    //
+    // model_options.push(GPT_OSS_120B);
+    // Ok(model_options)
 }
 
 /// Returns the context window length in tokens for the given model_id.
-/// Uses cached model data when available
-pub fn context_window_tokens(model_info: Option<&ModelInfo>) -> usize {
-    model_info.map_or_else(default_context_window, |m| m.context_window_tokens)
-}
+pub fn context_window_tokens(model_id: Option<&str>) -> usize {
+    const DEFAULT_CONTEXT_WINDOW_LENGTH: usize = 200_000;
 
-fn default_context_window() -> usize {
-    200_000
-}
+    let Some(model_id) = model_id else {
+        return DEFAULT_CONTEXT_WINDOW_LENGTH;
+    };
 
-fn get_fallback_models() -> Vec<ModelInfo> {
-    vec![
-        ModelInfo {
-            model_name: Some("claude-3.7-sonnet".to_string()),
-            model_id: "claude-3.7-sonnet".to_string(),
-            context_window_tokens: 200_000,
-        },
-        ModelInfo {
-            model_name: Some("claude-4-sonnet".to_string()),
-            model_id: "claude-4-sonnet".to_string(),
-            context_window_tokens: 200_000,
-        },
-    ]
+    MODEL_OPTIONS
+        .iter()
+        .chain(std::iter::once(&GPT_OSS_120B))
+        .find(|m| m.model_id == model_id)
+        .map_or(DEFAULT_CONTEXT_WINDOW_LENGTH, |m| m.context_window_tokens)
 }
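
default_model_id keys off the region segment of the stored profile ARN. An ARN is colon-delimited as arn:partition:service:region:account-id:resource, so split(':').nth(3) yields the region. A quick standalone check of that indexing (the ARN below is a made-up example):

fn arn_region(arn: &str) -> Option<&str> {
    // ["arn", partition, service, region, ...] => region is index 3.
    arn.split(':').nth(3)
}

fn main() {
    let arn = "arn:aws:codewhisperer:eu-central-1:123456789012:profile/EXAMPLE";
    assert_eq!(arn_region(arn), Some("eu-central-1"));
    assert_eq!(arn_region("not-an-arn"), None); // malformed input simply misses
}
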
diff --git a/crates/chat-cli/src/cli/chat/cli/usage.rs b/crates/chat-cli/src/cli/chat/cli/usage.rs
index eca538e2b6..6d62ffb4ec 100644
--- a/crates/chat-cli/src/cli/chat/cli/usage.rs
+++ b/crates/chat-cli/src/cli/chat/cli/usage.rs
@@ -62,7 +62,8 @@ impl UsageArgs {
         // set a max width for the progress bar for better aesthetic
         let progress_bar_width = std::cmp::min(window_width, 80);
 
-        let context_window_size = context_window_tokens(session.conversation.model_info.as_ref());
+        let context_window_size = context_window_tokens(session.conversation.model.as_deref());
+
         let context_width =
             ((context_token_count.value() as f64 / context_window_size as f64) * progress_bar_width as f64) as usize;
         let assistant_width =
diff --git a/crates/chat-cli/src/cli/chat/context.rs b/crates/chat-cli/src/cli/chat/context.rs
index 1fdcc5e8ee..993aa3cca9 100644
--- a/crates/chat-cli/src/cli/chat/context.rs
+++ b/crates/chat-cli/src/cli/chat/context.rs
@@ -23,7 +23,6 @@ use crate::cli::agent::hook::{
 };
 use crate::cli::chat::ChatError;
 use crate::cli::chat::cli::hooks::HookExecutor;
-use crate::cli::chat::cli::model::ModelInfo;
 use crate::os::Os;
 
 #[derive(Debug, Clone)]
@@ -256,9 +255,9 @@ impl ContextManager {
 }
 
 /// Calculates the maximum context files size to use for the given model id.
-pub fn calc_max_context_files_size(model: Option<&ModelInfo>) -> usize {
+pub fn calc_max_context_files_size(model_id: Option<&str>) -> usize {
     // Sets the max as 75% of the context window
-    context_window_tokens(model).saturating_mul(3) / 4
+    context_window_tokens(model_id).saturating_mul(3) / 4
 }
 
 /// Process a path, handling glob patterns and file types.
@@ -435,20 +434,9 @@ mod tests {
     #[test]
    fn test_calc_max_context_files_size() {
         assert_eq!(
-            calc_max_context_files_size(Some(&ModelInfo {
-                model_id: "CLAUDE_SONNET_4_20250514_V1_0".to_string(),
-                model_name: Some("Claude".to_string()),
-                context_window_tokens: 200_000,
-            })),
+            calc_max_context_files_size(Some("CLAUDE_SONNET_4_20250514_V1_0")),
             150_000
         );
-        assert_eq!(
-            calc_max_context_files_size(Some(&ModelInfo {
-                model_id: "OPENAI_GPT_OSS_120B_1_0".to_string(),
-                model_name: Some("GPT".to_string()),
-                context_window_tokens: 128_000,
-            })),
-            96_000
-        );
+        assert_eq!(calc_max_context_files_size(Some("OPENAI_GPT_OSS_120B_1_0")), 96_000);
     }
 }
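
The 75% cap above is plain integer math: multiply by 3 first, then divide by 4. The saturating_mul guards the intermediate product against overflow rather than wrapping or panicking; a standalone sketch of why that choice matters:

fn max_context_files_size(context_window_tokens: usize) -> usize {
    // 75% of the window. saturating_mul clamps at usize::MAX instead of
    // overflowing, trading a slightly-low answer for safety on absurd inputs.
    context_window_tokens.saturating_mul(3) / 4
}

fn main() {
    assert_eq!(max_context_files_size(200_000), 150_000);
    assert_eq!(max_context_files_size(128_000), 96_000);
    // With a plain `* 3` this input would overflow in debug builds.
    assert_eq!(max_context_files_size(usize::MAX), usize::MAX / 4);
}
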
diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs
index ca7b87d2c4..1c3a0f8e7b 100644
--- a/crates/chat-cli/src/cli/chat/conversation.rs
+++ b/crates/chat-cli/src/cli/chat/conversation.rs
@@ -66,10 +66,6 @@ use crate::cli::agent::hook::{
     HookTrigger,
 };
 use crate::cli::chat::ChatError;
-use crate::cli::chat::cli::model::{
-    ModelInfo,
-    get_model_info,
-};
 use crate::mcp_client::Prompt;
 use crate::os::Os;
 
@@ -112,13 +108,9 @@ pub struct ConversationState {
     latest_summary: Option<(String, RequestMetadata)>,
     #[serde(skip)]
     pub agents: Agents,
-    /// Unused, kept only to maintain deserialization backwards compatibility with <=v1.13.3
     /// Model explicitly selected by the user in this conversation state via `/model`.
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub model: Option<String>,
-    /// Model explicitly selected by the user in this conversation state via `/model`.
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub model_info: Option<ModelInfo>,
     /// Used to track agent vs user updates to file modifications.
     ///
     /// Maps from a file path to [FileLineTracker]
@@ -133,22 +125,9 @@ impl ConversationState {
         tool_config: HashMap<String, ToolSpec>,
         tool_manager: ToolManager,
         current_model_id: Option<String>,
-        os: &Os,
     ) -> Self {
-        let model = if let Some(model_id) = current_model_id {
-            match get_model_info(&model_id, os).await {
-                Ok(info) => Some(info),
-                Err(e) => {
-                    tracing::warn!("Failed to get model info for {}: {}, using default", model_id, e);
-                    Some(ModelInfo::from_id(model_id))
-                },
-            }
-        } else {
-            None
-        };
-
         let context_manager = if let Some(agent) = agents.get_active() {
-            ContextManager::from_agent(agent, calc_max_context_files_size(model.as_ref())).ok()
+            ContextManager::from_agent(agent, calc_max_context_files_size(current_model_id.as_deref())).ok()
         } else {
             None
         };
@@ -177,8 +156,7 @@ impl ConversationState {
             context_message_length: None,
             latest_summary: None,
             agents,
-            model: None,
-            model_info: model,
+            model: current_model_id,
             file_line_tracker: HashMap::new(),
         }
     }
@@ -462,7 +440,7 @@ impl ConversationState {
             context_messages,
             dropped_context_files,
             tools: &self.tools,
-            model_id: self.model_info.as_ref().map(|m| m.model_id.as_str()),
+            model_id: self.model.as_deref(),
         })
     }
 
@@ -560,7 +538,7 @@ impl ConversationState {
             conversation_id: Some(self.conversation_id.clone()),
             user_input_message: summary_message
                 .unwrap_or(UserMessage::new_prompt(summary_content, None)) // should not happen
-                .into_user_input_message(self.model_info.as_ref().map(|m| m.model_id.clone()), &tools),
+                .into_user_input_message(self.model.clone(), &tools),
             history: Some(flatten_history(history.iter())),
         })
     }
 
@@ -673,7 +651,7 @@ impl ConversationState {
     /// Get the current token warning level
     pub async fn get_token_warning_level(&mut self, os: &Os) -> Result<TokenWarningLevel, ChatError> {
         let total_chars = self.calculate_char_count(os).await?;
-        let max_chars = TokenCounter::token_to_chars(context_window_tokens(self.model_info.as_ref()));
+        let max_chars = TokenCounter::token_to_chars(context_window_tokens(self.model.as_deref()));
 
         Ok(if *total_chars >= max_chars {
             TokenWarningLevel::Critical
@@ -1098,7 +1076,6 @@ mod tests {
             tool_manager.load_tools(&mut os, &mut output).await.unwrap(),
             tool_manager,
             None,
-            &os,
         )
         .await;
 
@@ -1130,7 +1107,6 @@ mod tests {
             tool_config.clone(),
             tool_manager.clone(),
             None,
-            &os,
         )
         .await;
         conversation.set_next_user_message("start".to_string()).await;
@@ -1159,15 +1135,8 @@ mod tests {
         }
 
         // Build a long conversation history of user messages mixed in with tool results.
-        let mut conversation = ConversationState::new(
-            "fake_conv_id",
-            agents,
-            tool_config.clone(),
-            tool_manager.clone(),
-            None,
-            &os,
-        )
-        .await;
+        let mut conversation =
+            ConversationState::new("fake_conv_id", agents, tool_config.clone(), tool_manager.clone(), None).await;
         conversation.set_next_user_message("start".to_string()).await;
         for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) {
             let s = conversation
@@ -1219,7 +1188,6 @@ mod tests {
             tool_manager.load_tools(&mut os, &mut output).await.unwrap(),
             tool_manager,
             None,
-            &os,
         )
         .await;
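
Dropping the persisted model_info field leans on serde ignoring unknown fields by default (no deny_unknown_fields on ConversationState), so conversations saved by builds that still wrote model_info keep deserializing, while #[serde(default)] covers payloads that lack model. A simplified standalone sketch of both properties (not the real ConversationState):

use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
struct Conversation {
    conversation_id: String,
    /// Model explicitly selected by the user via `/model`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    model: Option<String>,
}

fn main() {
    // Old payload: carries the removed `model_info` field, which serde ignores.
    let old = r#"{"conversation_id":"abc","model_info":{"model_id":"m1"}}"#;
    let c: Conversation = serde_json::from_str(old).unwrap();
    assert_eq!(c.model, None);

    // Payload missing `model` entirely: deserializes as None, ensured by #[serde(default)].
    let newer = r#"{"conversation_id":"abc"}"#;
    assert!(serde_json::from_str::<Conversation>(newer).unwrap().model.is_none());
}
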
Available models: {}", + model_name, + available_names.join(", ") + ); + }, } } else { - Some(default_model_opt.model_id.clone()) + None }; let (prompt_request_sender, prompt_request_receiver) = std::sync::mpsc::channel::>(); @@ -591,6 +579,28 @@ impl ChatSession { tool_config: HashMap, interactive: bool, ) -> Result { + let model_options = get_model_options(os).await?; + let valid_model_id = match model_id { + Some(id) => id, + None => { + let from_settings = os + .database + .settings + .get_string(Setting::ChatDefaultModel) + .and_then(|model_name| { + model_options + .iter() + .find(|opt| opt.name == model_name) + .map(|opt| opt.model_id.to_owned()) + }); + + match from_settings { + Some(id) => id, + None => default_model_id(os).await.to_owned(), + } + }, + }; + // Reload prior conversation let mut existing_conversation = false; let previous_conversation = std::env::current_dir() @@ -629,7 +639,9 @@ impl ChatSession { cs.enforce_tool_use_history_invariants(); cs }, - false => ConversationState::new(conversation_id, agents, tool_config, tool_manager, model_id, os).await, + false => { + ConversationState::new(conversation_id, agents, tool_config, tool_manager, Some(valid_model_id)).await + }, }; // Spawn a task for listening and broadcasting sigints. @@ -1184,14 +1196,13 @@ impl ChatSession { } self.stderr.flush()?; - if let Some(ref model_info) = self.conversation.model_info { - let (models, _default_model) = get_available_models(os).await?; - if let Some(model_option) = models.iter().find(|option| option.model_id == model_info.model_id) { - let display_name = model_option.model_name.as_deref().unwrap_or(&model_option.model_id); + if let Some(ref id) = self.conversation.model { + let model_options = get_model_options(os).await?; + if let Some(model_option) = model_options.iter().find(|option| option.model_id == *id) { execute!( self.stderr, style::SetForegroundColor(Color::Cyan), - style::Print(format!("🤖 You are chatting with {}\n", display_name)), + style::Print(format!("🤖 You are chatting with {}\n", model_option.name)), style::SetForegroundColor(Color::Reset), style::Print("\n") )?; @@ -2374,14 +2385,11 @@ impl ChatSession { for tool_use in tool_uses { let tool_use_id = tool_use.id.clone(); let tool_use_name = tool_use.name.clone(); - let mut tool_telemetry = ToolUseEventBuilder::new( - conv_id.clone(), - tool_use.id.clone(), - self.conversation.model_info.as_ref().map(|m| m.model_id.clone()), - ) - .set_tool_use_id(tool_use_id.clone()) - .set_tool_name(tool_use.name.clone()) - .utterance_id(self.conversation.message_id().map(|s| s.to_string())); + let mut tool_telemetry = + ToolUseEventBuilder::new(conv_id.clone(), tool_use.id.clone(), self.conversation.model.clone()) + .set_tool_use_id(tool_use_id.clone()) + .set_tool_name(tool_use.name.clone()) + .utterance_id(self.conversation.message_id().map(|s| s.to_string())); match self.conversation.tool_manager.get_tool_from_tool_use(tool_use) { Ok(mut tool) => { // Apply non-Q-generated context to tools @@ -2473,7 +2481,6 @@ impl ChatSession { } async fn retry_model_overload(&mut self, os: &mut Os) -> Result { - os.client.invalidate_model_cache().await; match select_model(os, self).await { Ok(Some(_)) => (), Ok(None) => {