Skip to content

Commit a57a0b0

Browse files
committed
cooperate with default model in response
1 parent 5c31dcb commit a57a0b0

File tree

3 files changed

+43
-53
lines changed

3 files changed

+43
-53
lines changed

crates/chat-cli/src/api_client/mod.rs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -254,21 +254,24 @@ impl ApiClient {
254254
));
255255
}
256256

257-
// todo yifan: add default_model once API is ready
258257
let mut models = Vec::new();
259-
let default_model = None;
258+
let mut default_model = None;
260259
let request = self
261260
.client
262261
.list_available_models()
263262
.set_origin(Some(Cli))
264263
.set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone()));
265264
let mut paginator = request.into_paginator().send();
266265

267-
while let Some(models_output) = paginator.next().await {
268-
models.extend(models_output?.models().iter().cloned());
269-
// if default_model.is_none() && output.default_model().is_some() {
270-
// default_model = output.default_model().cloned();
271-
// }
266+
while let Some(result) = paginator.next().await {
267+
let models_output = result?;
268+
models.extend(models_output.models().iter().cloned());
269+
270+
if default_model.is_none() {
271+
if let Some(model) = models_output.default_model().cloned() {
272+
default_model = Some(model);
273+
}
274+
}
272275
}
273276

274277
Ok((models, default_model))

crates/chat-cli/src/cli/chat/cli/model.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,6 @@ pub async fn select_model(os: &mut Os, session: &mut ChatSession) -> Result<Opti
112112
/// Returns Claude 3.7 for: Amazon IDC users, FRA region users
113113
/// Returns Claude 4.0 for: Builder ID users, other regions
114114
pub async fn default_model_id(os: &Os) -> String {
115-
if let Ok((_, Some(default_model))) = os.client.list_available_models().await {
116-
return default_model.model_id().to_string();
117-
}
118115
// Check FRA region first
119116
if let Ok(Some(profile)) = os.database.get_auth_profile() {
120117
if profile.arn.split(':').nth(3) == Some("eu-central-1") {

crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 33 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -288,25 +288,40 @@ impl ChatArgs {
288288
};
289289

290290
// If modelId is specified, verify it exists before starting the chat
291-
let model_id: Option<String> = if let Some(requested_model_id) = self.model {
292-
let requested_model_id_lower = requested_model_id.to_lowercase();
293-
let (models, _) = os.client.list_available_models().await?;
294-
match models.iter().find(|opt| opt.model_id == requested_model_id_lower) {
295-
Some(opt) => Some(opt.model_id.clone()),
296-
None => {
297-
let available_names: Vec<String> = models
298-
.iter()
299-
.map(|opt| get_display_name(opt.model_id()).to_string())
300-
.collect();
301-
bail!(
302-
"Model '{}' does not exist. Available models: {}",
303-
requested_model_id,
304-
available_names.join(", ")
305-
);
306-
},
291+
// Otherwise, CLI will use a default model when starting chat
292+
let (models, default_model_opt) = os.client.list_available_models().await?;
293+
let model_id: Option<String> = if let Some(requested) = self.model.as_ref() {
294+
let requested_lower = requested.to_lowercase();
295+
if let Some(m) = models
296+
.iter()
297+
.find(|m| m.model_id.eq_ignore_ascii_case(&requested_lower))
298+
{
299+
Some(m.model_id.clone())
300+
} else {
301+
let available = models
302+
.iter()
303+
.map(|m| get_display_name(m.model_id()).to_string())
304+
.collect::<Vec<_>>()
305+
.join(", ");
306+
bail!("Model '{}' does not exist. Available models: {}", requested, available);
307307
}
308+
} else if let Some(saved) = os
309+
.database
310+
.settings
311+
.get_string(Setting::ChatDefaultModel)
312+
.and_then(|name| {
313+
models
314+
.iter()
315+
.find(|m| get_display_name(m.model_id()) == name)
316+
.map(|m| m.model_id.clone())
317+
})
318+
{
319+
Some(saved)
320+
} else if let Some(default_model) = default_model_opt {
321+
Some(default_model.model_id().to_owned())
308322
} else {
309-
None
323+
// should not use this fallback method when service return a required default model in response
324+
Some(default_model_id(os).await)
310325
};
311326

312327
let (prompt_request_sender, prompt_request_receiver) = std::sync::mpsc::channel::<Option<String>>();
@@ -556,29 +571,6 @@ impl ChatSession {
556571
tool_config: HashMap<String, ToolSpec>,
557572
interactive: bool,
558573
) -> Result<Self> {
559-
let (models, _default_model) = os.client.list_available_models().await?;
560-
561-
let valid_model_id = match model_id {
562-
Some(id) => id,
563-
None => {
564-
let from_settings = os
565-
.database
566-
.settings
567-
.get_string(Setting::ChatDefaultModel)
568-
.and_then(|model_id| {
569-
models
570-
.iter()
571-
.find(|opt| get_display_name(opt.model_id()) == model_id)
572-
.map(|opt| opt.model_id.to_owned())
573-
});
574-
575-
match from_settings {
576-
Some(id) => id,
577-
None => default_model_id(os).await.to_owned(),
578-
}
579-
},
580-
};
581-
582574
// Reload prior conversation
583575
let mut existing_conversation = false;
584576
let previous_conversation = std::env::current_dir()
@@ -617,9 +609,7 @@ impl ChatSession {
617609
cs.enforce_tool_use_history_invariants();
618610
cs
619611
},
620-
false => {
621-
ConversationState::new(conversation_id, agents, tool_config, tool_manager, Some(valid_model_id)).await
622-
},
612+
false => ConversationState::new(conversation_id, agents, tool_config, tool_manager, model_id).await,
623613
};
624614

625615
// Spawn a task for listening and broadcasting sigints.

0 commit comments

Comments
 (0)