
Commit 90476ed

Merge pull request #135 from aws/evannliu/model-selection
feat: model selection for chat merge for getting a bugbash build
2 parents f768fb4 + 296dd89 commit 90476ed

File tree

15 files changed, +455 -7 lines

crates/chat-cli/src/api_client/clients/client.rs

Lines changed: 3 additions & 0 deletions

@@ -78,6 +78,7 @@ impl Client {
telemetry_event: TelemetryEvent,
user_context: UserContext,
telemetry_enabled: bool,
+model_id: Option<String>,
) -> Result<(), ApiClientError> {
match &self.inner {
inner::Inner::Codewhisperer(client) => {
@@ -90,6 +91,7 @@ impl Client {
false => OptOutPreference::OptOut,
})
.set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone()))
+.set_model_id(model_id)
.send()
.await;
Ok(())
@@ -159,6 +161,7 @@ mod tests {
.build()
.unwrap(),
false,
+Some("model".to_owned()),
)
.await
.unwrap();

crates/chat-cli/src/api_client/clients/streaming_client.rs

Lines changed: 16 additions & 0 deletions

@@ -139,6 +139,7 @@ impl StreamingClient {

match &self.inner {
inner::Inner::Codewhisperer(client) => {
+let model_id_opt: Option<String> = user_input_message.model_id.clone();
let conversation_state = amzn_codewhisperer_streaming_client::types::ConversationState::builder()
.set_conversation_id(conversation_id)
.current_message(
@@ -170,10 +171,22 @@
&& err.meta().message() == Some("Input is too long."))
});

+let is_model_unavailable = model_id_opt.is_some()
+&& e.raw_response().is_some_and(|resp| resp.status().as_u16() == 500)
+&& e.as_service_error().is_some_and(|err| {
+err.meta().message()
+== Some("Encountered unexpectedly high load when processing the request, please try again.")
+});
if is_quota_breach {
Err(ApiClientError::QuotaBreach("quota has reached its limit"))
} else if is_context_window_overflow {
Err(ApiClientError::ContextWindowOverflow)
+} else if is_model_unavailable {
+let request_id = e
+.as_service_error()
+.and_then(|err| err.meta().request_id())
+.map(|s| s.to_string());
+Err(ApiClientError::ModelOverloadedError(request_id))
} else {
Err(e.into())
}
@@ -291,6 +304,7 @@ mod tests {
content: "Hello".into(),
user_input_message_context: None,
user_intent: None,
+model_id: Some("model".to_owned()),
},
history: None,
})
@@ -317,13 +331,15 @@
content: "How about rustc?".into(),
user_input_message_context: None,
user_intent: None,
+model_id: Some("model".to_owned()),
},
history: Some(vec![
ChatMessage::UserInputMessage(UserInputMessage {
images: None,
content: "What language is the linux kernel written in, and who wrote it?".into(),
user_input_message_context: None,
user_intent: None,
+model_id: None,
}),
ChatMessage::AssistantResponseMessage(AssistantResponseMessage {
content: "It is written in C by Linus Torvalds.".into(),

crates/chat-cli/src/api_client/error.rs

Lines changed: 5 additions & 0 deletions

@@ -65,6 +65,11 @@ pub enum ApiClientError {

#[error(transparent)]
AuthError(#[from] AuthError),
+
+#[error(
+"The model you've selected is temporarily unavailable. Please use '/model' to select a different model and try again."
+)]
+ModelOverloadedError(Option<String>),
}

#[cfg(test)]
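
For orientation, a minimal sketch of how a caller might surface the new variant. Only the `ApiClientError::ModelOverloadedError(Option<String>)` shape and its display text come from the diff above; the helper name and the `use` path are illustrative assumptions.

use crate::api_client::ApiClientError; // module path assumed for this sketch

/// Sketch: format the user-facing text for the model-overloaded case.
/// Returns None for every other error so callers can fall through.
fn describe_model_overload(err: &ApiClientError) -> Option<String> {
    match err {
        // The message itself comes from the #[error(...)] attribute above,
        // which already tells the user to run '/model'.
        ApiClientError::ModelOverloadedError(Some(request_id)) => {
            Some(format!("{err} (request id: {request_id})"))
        },
        ApiClientError::ModelOverloadedError(None) => Some(err.to_string()),
        _ => None,
    }
}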

crates/chat-cli/src/api_client/model.rs

Lines changed: 5 additions & 0 deletions

@@ -859,6 +859,7 @@ pub struct UserInputMessage {
pub user_input_message_context: Option<UserInputMessageContext>,
pub user_intent: Option<UserIntent>,
pub images: Option<Vec<ImageBlock>>,
+pub model_id: Option<String>,
}

impl From<UserInputMessage> for amzn_codewhisperer_streaming_client::types::UserInputMessage {
@@ -868,6 +869,7 @@ impl From<UserInputMessage> for amzn_codewhisperer_streaming_client::types::User
.set_images(value.images.map(|images| images.into_iter().map(Into::into).collect()))
.set_user_input_message_context(value.user_input_message_context.map(Into::into))
.set_user_intent(value.user_intent.map(Into::into))
+.set_model_id(value.model_id)
.origin(amzn_codewhisperer_streaming_client::types::Origin::Cli)
.build()
.expect("Failed to build UserInputMessage")
@@ -881,6 +883,7 @@ impl From<UserInputMessage> for amzn_qdeveloper_streaming_client::types::UserInp
.set_images(value.images.map(|images| images.into_iter().map(Into::into).collect()))
.set_user_input_message_context(value.user_input_message_context.map(Into::into))
.set_user_intent(value.user_intent.map(Into::into))
+.set_model_id(value.model_id)
.origin(amzn_qdeveloper_streaming_client::types::Origin::Cli)
.build()
.expect("Failed to build UserInputMessage")
@@ -976,6 +979,7 @@ mod tests {
})]),
}),
user_intent: Some(UserIntent::ApplyCommonBestPractices),
+model_id: Some("model id".to_string()),
};

let codewhisper_input =
@@ -989,6 +993,7 @@
content: "test content".to_string(),
user_input_message_context: None,
user_intent: None,
+model_id: Some("model id".to_string()),
};

let codewhisper_minimal =
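
To make the shape change concrete, a small construction sketch; the field names mirror the struct definition above, while the module path in the `use`, the function name, and the literal values are placeholders.

use crate::api_client::model::UserInputMessage; // path assumed for this sketch

fn build_message_with_model() -> amzn_codewhisperer_streaming_client::types::UserInputMessage {
    // All fields from the struct definition above, including the new model_id;
    // None would keep the previous behaviour (no explicit model selection).
    let message = UserInputMessage {
        content: "test content".to_string(),
        user_input_message_context: None,
        user_intent: None,
        images: None,
        model_id: Some("model id".to_string()),
    };
    // The From impl above forwards the field via .set_model_id(value.model_id),
    // so the conversion carries the selection onto the request type.
    message.into()
}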

crates/chat-cli/src/cli/chat/cli.rs

Lines changed: 151 additions & 0 deletions

@@ -0,0 +1,151 @@
+use std::collections::HashMap;
+
+use clap::{
+    Args,
+    Parser,
+    Subcommand,
+    ValueEnum,
+};
+
+#[derive(Debug, Clone, PartialEq, Eq, Default, Parser)]
+pub struct Chat {
+    /// (Deprecated, use --trust-all-tools) Enabling this flag allows the model to execute
+    /// all commands without first accepting them.
+    #[arg(short, long, hide = true)]
+    pub accept_all: bool,
+    /// Print the first response to STDOUT without interactive mode. This will fail if the
+    /// prompt requests permissions to use a tool, unless --trust-all-tools is also used.
+    #[arg(long)]
+    pub no_interactive: bool,
+    /// Resumes the previous conversation from this directory.
+    #[arg(short, long)]
+    pub resume: bool,
+    /// The first question to ask
+    pub input: Option<String>,
+    /// Context profile to use
+    #[arg(long = "profile")]
+    pub profile: Option<String>,
+    /// Current model to use
+    #[arg(long = "model")]
+    pub model: Option<String>,
+    /// Allows the model to use any tool to run commands without asking for confirmation.
+    #[arg(long)]
+    pub trust_all_tools: bool,
+    /// Trust only this set of tools. Example: trust some tools:
+    /// '--trust-tools=fs_read,fs_write', trust no tools: '--trust-tools='
+    #[arg(long, value_delimiter = ',', value_name = "TOOL_NAMES")]
+    pub trust_tools: Option<Vec<String>>,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Subcommand)]
+pub enum Mcp {
+    /// Add or replace a configured server
+    Add(McpAdd),
+    /// Remove a server from the MCP configuration
+    #[command(alias = "rm")]
+    Remove(McpRemove),
+    /// List configured servers
+    List(McpList),
+    /// Import a server configuration from another file
+    Import(McpImport),
+    /// Get the status of a configured server
+    Status {
+        #[arg(long)]
+        name: String,
+    },
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Args)]
+pub struct McpAdd {
+    /// Name for the server
+    #[arg(long)]
+    pub name: String,
+    /// The command used to launch the server
+    #[arg(long)]
+    pub command: String,
+    /// Where to add the server to.
+    #[arg(long, value_enum)]
+    pub scope: Option<Scope>,
+    /// Environment variables to use when launching the server
+    #[arg(long, value_parser = parse_env_vars)]
+    pub env: Vec<HashMap<String, String>>,
+    /// Server launch timeout, in milliseconds
+    #[arg(long)]
+    pub timeout: Option<u64>,
+    /// Overwrite an existing server with the same name
+    #[arg(long, default_value_t = false)]
+    pub force: bool,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Args)]
+pub struct McpRemove {
+    #[arg(long)]
+    pub name: String,
+    #[arg(long, value_enum)]
+    pub scope: Option<Scope>,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Args)]
+pub struct McpList {
+    #[arg(value_enum)]
+    pub scope: Option<Scope>,
+    #[arg(long, hide = true)]
+    pub profile: Option<String>,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Args)]
+pub struct McpImport {
+    #[arg(long)]
+    pub file: String,
+    #[arg(value_enum)]
+    pub scope: Option<Scope>,
+    /// Overwrite an existing server with the same name
+    #[arg(long, default_value_t = false)]
+    pub force: bool,
+}
+
+#[derive(Debug, Copy, Clone, PartialEq, Eq, ValueEnum)]
+pub enum Scope {
+    Workspace,
+    Global,
+}
+
+impl std::fmt::Display for Scope {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Scope::Workspace => write!(f, "workspace"),
+            Scope::Global => write!(f, "global"),
+        }
+    }
+}
+
+#[derive(Debug)]
+struct EnvVarParseError(String);
+
+impl std::fmt::Display for EnvVarParseError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "Failed to parse environment variables: {}", self.0)
+    }
+}
+
+impl std::error::Error for EnvVarParseError {}
+
+fn parse_env_vars(arg: &str) -> Result<HashMap<String, String>, EnvVarParseError> {
+    let mut vars = HashMap::new();
+
+    for pair in arg.split(",") {
+        match pair.split_once('=') {
+            Some((key, value)) => {
+                vars.insert(key.trim().to_string(), value.trim().to_string());
+            },
+            None => {
+                return Err(EnvVarParseError(format!(
+                    "Invalid environment variable '{}'. Expected 'name=value'",
+                    pair
+                )));
+            },
+        }
+    }
+
+    Ok(vars)
+}
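
Since cli.rs is a new file, a quick parsing sketch may help. It uses only the `Chat` struct and flags defined above via clap's derive API; the binary name, function name, and argument values are placeholders.

use clap::Parser;

fn parse_example() {
    // The first element is the binary name; the rest exercise the new --model
    // flag alongside an existing flag and the positional input.
    let chat = Chat::try_parse_from(["q", "--model", "my-model", "--trust-all-tools", "hello"])
        .expect("arguments should parse");
    assert_eq!(chat.model.as_deref(), Some("my-model"));
    assert!(chat.trust_all_tools);
    assert_eq!(chat.input.as_deref(), Some("hello"));
}

When --model is omitted, `chat.model` is None, which the conversation state below treats as the auto/default selection.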

crates/chat-cli/src/cli/chat/command.rs

Lines changed: 2 additions & 0 deletions

@@ -59,6 +59,7 @@ pub enum Command {
force: bool,
},
Mcp,
+Model,
}

#[derive(Debug, Clone, PartialEq, Eq)]
@@ -839,6 +840,7 @@ impl Command {
Self::Save { path, force }
},
"mcp" => Self::Mcp,
+"model" => Self::Model,
unknown_command => {
let looks_like_path = {
let after_slash_command_str = parts[1..].join(" ");

crates/chat-cli/src/cli/chat/conversation_state.rs

Lines changed: 15 additions & 0 deletions

@@ -105,6 +105,10 @@ pub struct ConversationState {
latest_summary: Option<String>,
#[serde(skip)]
pub updates: Option<SharedWriter>,
+/// Model explicitly selected by the user in this conversation state via `/model`. (`None` ==
+/// auto)
+#[serde(default, skip_serializing_if = "Option::is_none")]
+pub current_model_id: Option<String>,
}

impl ConversationState {
@@ -115,6 +119,7 @@
profile: Option<String>,
updates: Option<SharedWriter>,
tool_manager: ToolManager,
+current_model_id: Option<String>,
) -> Self {
// Initialize context manager
let context_manager = match ContextManager::new(ctx, None).await {
@@ -157,6 +162,7 @@
context_message_length: None,
latest_summary: None,
updates,
+current_model_id,
}
}

@@ -528,6 +534,7 @@
context_messages,
dropped_context_files,
tools: &self.tools,
+model_id: self.current_model_id.as_deref(),
}
}

@@ -599,6 +606,7 @@
user_input_message_context: None,
user_intent: None,
images: None,
+model_id: self.current_model_id.clone(),
};

// If the last message contains tool uses, then add cancelled tool results to the summary
@@ -830,6 +838,7 @@ pub struct BackendConversationStateImpl<'a, T, U> {
pub context_messages: U,
pub dropped_context_files: Vec<(String, String)>,
pub tools: &'a HashMap<ToolOrigin, Vec<Tool>>,
+pub model_id: Option<&'a str>,
}

impl
@@ -846,6 +855,7 @@ impl
.cloned()
.map(UserMessage::into_user_input_message)
.ok_or(eyre::eyre!("next user message is not set"))?;
+user_input_message.model_id = self.model_id.map(str::to_string);
if let Some(ctx) = user_input_message.user_input_message_context.as_mut() {
ctx.tools = Some(self.tools.values().flatten().cloned().collect::<Vec<_>>());
}
@@ -1059,6 +1069,7 @@ mod tests {
None,
None,
tool_manager,
+None,
)
.await;

@@ -1089,6 +1100,7 @@
None,
None,
tool_manager.clone(),
+None,
)
.await;
conversation_state.set_next_user_message("start".to_string()).await;
@@ -1120,6 +1132,7 @@
None,
None,
tool_manager.clone(),
+None,
)
.await;
conversation_state.set_next_user_message("start".to_string()).await;
@@ -1165,6 +1178,7 @@
None,
None,
tool_manager,
+None,
)
.await;

@@ -1235,6 +1249,7 @@
None,
Some(SharedWriter::stdout()),
tool_manager,
+None,
)
.await;

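One detail worth noting in the `current_model_id` hunk: the `#[serde(default, skip_serializing_if = "Option::is_none")]` attributes keep previously saved conversation state loadable and avoid writing the field when no model is selected. A self-contained stand-in (not the real `ConversationState`; serde_json is used only for illustration):

use serde::{Deserialize, Serialize};

// Stand-in struct carrying the same attributes as current_model_id above.
#[derive(Serialize, Deserialize)]
struct Saved {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    current_model_id: Option<String>,
}

fn main() {
    // Older state without the field still deserializes (defaults to None)...
    let old: Saved = serde_json::from_str("{}").unwrap();
    assert!(old.current_model_id.is_none());

    // ...and None is omitted on the way back out.
    let json = serde_json::to_string(&Saved { current_model_id: None }).unwrap();
    assert_eq!(json, "{}");
}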