Skip to content

Commit 33c2537

Browse files
committed
change the default model to be a required field
1 parent 4bcd58e commit 33c2537

File tree

6 files changed

+18
-55
lines changed

6 files changed

+18
-55
lines changed

crates/chat-cli/src/api_client/error.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ pub enum ApiClientError {
9797

9898
#[error(transparent)]
9999
ListAvailableModelsError(#[from] SdkError<ListAvailableModelsError, HttpResponse>),
100+
101+
#[error("No default model found in the ListAvailableModels API response")]
102+
DefaultModelNotFound,
100103
}
101104

102105
impl ApiClientError {
@@ -121,6 +124,7 @@ impl ApiClientError {
121124
Self::MonthlyLimitReached { status_code } => *status_code,
122125
Self::Credentials(_e) => None,
123126
Self::ListAvailableModelsError(e) => sdk_status_code(e),
127+
Self::DefaultModelNotFound => None,
124128
}
125129
}
126130
}
@@ -147,6 +151,7 @@ impl ReasonCode for ApiClientError {
147151
Self::MonthlyLimitReached { .. } => "MonthlyLimitReached".to_string(),
148152
Self::Credentials(_) => "CredentialsError".to_string(),
149153
Self::ListAvailableModelsError(e) => sdk_error_code(e),
154+
Self::DefaultModelNotFound => "DefaultModelNotFound".to_string(),
150155
}
151156
}
152157
}

crates/chat-cli/src/api_client/mod.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ pub const X_AMZN_CODEWHISPERER_OPT_OUT_HEADER: &str = "x-amzn-codewhisperer-opto
6969
// TODO(bskiser): confirm timeout is updated to an appropriate value?
7070
const DEFAULT_TIMEOUT_DURATION: Duration = Duration::from_secs(60 * 5);
7171

72-
type ModelListResult = (Vec<Model>, Option<Model>);
72+
type ModelListResult = (Vec<Model>, Model);
7373
type ModelCache = Arc<RwLock<Option<ModelListResult>>>;
7474

7575
#[derive(Clone, Debug)]
@@ -241,7 +241,7 @@ impl ApiClient {
241241
Ok(profiles)
242242
}
243243

244-
pub async fn list_available_models(&self) -> Result<(Vec<Model>, Option<Model>), ApiClientError> {
244+
pub async fn list_available_models(&self) -> Result<(Vec<Model>, Model), ApiClientError> {
245245
if cfg!(test) {
246246
return Ok((
247247
vec![
@@ -251,13 +251,11 @@ impl ApiClient {
251251
.build()
252252
.unwrap(),
253253
],
254-
Some(
255-
Model::builder()
256-
.model_id("model-1")
257-
.description("Test Model 1")
258-
.build()
259-
.unwrap(),
260-
),
254+
Model::builder()
255+
.model_id("model-1")
256+
.description("Test Model 1")
257+
.build()
258+
.unwrap(),
261259
));
262260
}
263261

@@ -278,11 +276,11 @@ impl ApiClient {
278276
default_model = Some(models_output.default_model().clone());
279277
}
280278
}
281-
279+
let default_model = default_model.ok_or_else(|| ApiClientError::DefaultModelNotFound)?;
282280
Ok((models, default_model))
283281
}
284282

285-
pub async fn list_available_models_cached(&self) -> Result<(Vec<Model>, Option<Model>), ApiClientError> {
283+
pub async fn list_available_models_cached(&self) -> Result<(Vec<Model>, Model), ApiClientError> {
286284
{
287285
let cache = self.model_cache.read().await;
288286
if let Some(cached) = cache.as_ref() {
@@ -306,7 +304,7 @@ impl ApiClient {
306304
tracing::info!("Model cache invalidated");
307305
}
308306

309-
pub async fn get_available_models(&self, region: &str) -> Result<(Vec<Model>, Option<Model>), ApiClientError> {
307+
pub async fn get_available_models(&self, region: &str) -> Result<(Vec<Model>, Model), ApiClientError> {
310308
let (mut models, default_model) = self.list_available_models_cached().await?;
311309

312310
if region == "us-east-1" {

crates/chat-cli/src/auth/builder_id.rs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -477,12 +477,6 @@ impl BuilderIdToken {
477477
Some(_) => TokenType::IamIdentityCenter,
478478
}
479479
}
480-
481-
/// Check if the token is for the internal amzn start URL (`https://amzn.awsapps.com/start`),
482-
/// this implies the user will use midway for private specs
483-
pub fn is_amzn_user(&self) -> bool {
484-
matches!(&self.start_url, Some(url) if url == AMZN_START_URL)
485-
}
486480
}
487481

488482
pub enum PollCreateToken {

crates/chat-cli/src/auth/consts.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,5 @@ pub(crate) const CLIENT_TYPE: &str = "public";
2121
// The start URL for public builder ID users
2222
pub const START_URL: &str = "https://view.awsapps.com/start";
2323

24-
// The start URL for internal amzn users
25-
pub const AMZN_START_URL: &str = "https://amzn.awsapps.com/start";
26-
2724
pub(crate) const DEVICE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:device_code";
2825
pub(crate) const REFRESH_GRANT_TYPE: &str = "refresh_token";

crates/chat-cli/src/cli/chat/cli/model.rs

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,6 @@ use crossterm::{
1111
use dialoguer::Select;
1212

1313
use crate::api_client::Endpoint;
14-
use crate::auth::builder_id::{
15-
BuilderIdToken,
16-
TokenType,
17-
};
1814
use crate::cli::chat::{
1915
ChatError,
2016
ChatSession,
@@ -107,31 +103,8 @@ pub async fn select_model(os: &Os, session: &mut ChatSession) -> Result<Option<C
107103
}))
108104
}
109105

110-
/// Returns a default model id to use if none has been otherwise provided.
111-
///
112-
/// Returns Claude 3.7 for: Amazon IDC users, FRA region users
113-
/// Returns Claude 4.0 for: Builder ID users, other regions
114-
pub async fn default_model_id(os: &Os) -> String {
115-
// Check FRA region first
116-
if let Ok(Some(profile)) = os.database.get_auth_profile() {
117-
if profile.arn.split(':').nth(3) == Some("eu-central-1") {
118-
return "claude-3.7-sonnet".to_string();
119-
}
120-
}
121-
122-
// Check if Amazon IDC user
123-
if let Ok(Some(token)) = BuilderIdToken::load(&os.database).await {
124-
if matches!(token.token_type(), TokenType::IamIdentityCenter) && token.is_amzn_user() {
125-
return "claude-3.7-sonnet".to_string();
126-
}
127-
}
128-
129-
// Default to 4.0
130-
"claude-4-sonnet".to_string()
131-
}
132-
133106
/// Get available models with caching support
134-
pub async fn get_available_models(os: &Os) -> Result<(Vec<Model>, Option<Model>), ChatError> {
107+
pub async fn get_available_models(os: &Os) -> Result<(Vec<Model>, Model), ChatError> {
135108
let endpoint = Endpoint::configured_value(&os.database);
136109
let region = endpoint.region().as_ref();
137110

crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ use crate::auth::AuthError;
132132
use crate::auth::builder_id::is_idc_user;
133133
use crate::cli::agent::Agents;
134134
use crate::cli::chat::cli::SlashCommand;
135-
use crate::cli::chat::cli::model::default_model_id;
136135
use crate::cli::chat::cli::prompts::{
137136
GetPromptError,
138137
PromptsSubcommand,
@@ -313,7 +312,7 @@ impl ChatArgs {
313312

314313
// If modelId is specified, verify it exists before starting the chat
315314
// Otherwise, CLI will use a default model when starting chat
316-
let (models, default_model_opt) = os.client.list_available_models_cached().await?;
315+
let (models, default_model_opt) = get_available_models(os).await?;
317316
let model_id: Option<String> = if let Some(requested) = self.model.as_ref() {
318317
let requested_lower = requested.to_lowercase();
319318
if let Some(m) = models
@@ -337,11 +336,8 @@ impl ChatArgs {
337336
})
338337
{
339338
Some(saved)
340-
} else if let Some(default_model) = default_model_opt {
341-
Some(default_model.model_id().to_owned())
342339
} else {
343-
// should not use this fallback method when service return a required default model in response
344-
Some(default_model_id(os).await)
340+
Some(default_model_opt.model_id().to_owned())
345341
};
346342

347343
let (prompt_request_sender, prompt_request_receiver) = std::sync::mpsc::channel::<Option<String>>();

0 commit comments

Comments
 (0)