Skip to content

Commit 4b13864

Browse files
authored
Merge pull request #157 from aws/evannliu/model-selection
feat: model selection for chat
2 parents 11be917 + 8ed7d01 commit 4b13864

File tree

15 files changed: +329 additions, -8 deletions (lines changed)

crates/chat-cli/src/api_client/clients/client.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ impl Client {
7878
telemetry_event: TelemetryEvent,
7979
user_context: UserContext,
8080
telemetry_enabled: bool,
81+
model_id: Option<String>,
8182
) -> Result<(), ApiClientError> {
8283
match &self.inner {
8384
inner::Inner::Codewhisperer(client) => {
@@ -90,6 +91,7 @@ impl Client {
9091
false => OptOutPreference::OptOut,
9192
})
9293
.set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone()))
94+
.set_model_id(model_id)
9395
.send()
9496
.await?;
9597
Ok(())
@@ -159,6 +161,7 @@ mod tests {
159161
.build()
160162
.unwrap(),
161163
false,
164+
Some("model".to_owned()),
162165
)
163166
.await
164167
.unwrap();

crates/chat-cli/src/api_client/clients/streaming_client.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ impl StreamingClient {
109109

110110
match &self.inner {
111111
inner::Inner::Codewhisperer(client) => {
112+
let model_id_opt: Option<String> = user_input_message.model_id.clone();
112113
let conversation_state = amzn_codewhisperer_streaming_client::types::ConversationState::builder()
113114
.set_conversation_id(conversation_id)
114115
.current_message(
@@ -140,10 +141,22 @@ impl StreamingClient {
140141
&& err.meta().message() == Some("Input is too long."))
141142
});
142143

144+
let is_model_unavailable = model_id_opt.is_some()
145+
&& e.raw_response().is_some_and(|resp| resp.status().as_u16() == 500)
146+
&& e.as_service_error().is_some_and(|err| {
147+
err.meta().message()
148+
== Some("Encountered unexpectedly high load when processing the request, please try again.")
149+
});
143150
if is_quota_breach {
144151
Err(ApiClientError::QuotaBreach("quota has reached its limit"))
145152
} else if is_context_window_overflow {
146153
Err(ApiClientError::ContextWindowOverflow)
154+
} else if is_model_unavailable {
155+
let request_id = e
156+
.as_service_error()
157+
.and_then(|err| err.meta().request_id())
158+
.map(|s| s.to_string());
159+
Err(ApiClientError::ModelOverloadedError { request_id })
147160
} else {
148161
Err(e.into())
149162
}
@@ -235,6 +248,7 @@ mod tests {
235248
content: "Hello".into(),
236249
user_input_message_context: None,
237250
user_intent: None,
251+
model_id: Some("model".to_owned()),
238252
},
239253
history: None,
240254
})
@@ -261,13 +275,15 @@ mod tests {
261275
content: "How about rustc?".into(),
262276
user_input_message_context: None,
263277
user_intent: None,
278+
model_id: Some("model".to_owned()),
264279
},
265280
history: Some(vec![
266281
ChatMessage::UserInputMessage(UserInputMessage {
267282
images: None,
268283
content: "What language is the linux kernel written in, and who wrote it?".into(),
269284
user_input_message_context: None,
270285
user_intent: None,
286+
model_id: None,
271287
}),
272288
ChatMessage::AssistantResponseMessage(AssistantResponseMessage {
273289
content: "It is written in C by Linus Torvalds.".into(),

crates/chat-cli/src/api_client/error.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ pub enum ApiClientError {
6868

6969
#[error(transparent)]
7070
AuthError(#[from] AuthError),
71+
72+
#[error(
73+
"The model you've selected is temporarily unavailable. Please use '/model' to select a different model and try again."
74+
)]
75+
ModelOverloadedError { request_id: Option<String> },
7176
}
7277

7378
impl ReasonCode for ApiClientError {

crates/chat-cli/src/api_client/model.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,7 @@ pub struct UserInputMessage {
859859
pub user_input_message_context: Option<UserInputMessageContext>,
860860
pub user_intent: Option<UserIntent>,
861861
pub images: Option<Vec<ImageBlock>>,
862+
pub model_id: Option<String>,
862863
}
863864

864865
impl From<UserInputMessage> for amzn_codewhisperer_streaming_client::types::UserInputMessage {
@@ -868,6 +869,7 @@ impl From<UserInputMessage> for amzn_codewhisperer_streaming_client::types::User
868869
.set_images(value.images.map(|images| images.into_iter().map(Into::into).collect()))
869870
.set_user_input_message_context(value.user_input_message_context.map(Into::into))
870871
.set_user_intent(value.user_intent.map(Into::into))
872+
.set_model_id(value.model_id)
871873
.origin(amzn_codewhisperer_streaming_client::types::Origin::Cli)
872874
.build()
873875
.expect("Failed to build UserInputMessage")
@@ -881,6 +883,7 @@ impl From<UserInputMessage> for amzn_qdeveloper_streaming_client::types::UserInp
881883
.set_images(value.images.map(|images| images.into_iter().map(Into::into).collect()))
882884
.set_user_input_message_context(value.user_input_message_context.map(Into::into))
883885
.set_user_intent(value.user_intent.map(Into::into))
886+
.set_model_id(value.model_id)
884887
.origin(amzn_qdeveloper_streaming_client::types::Origin::Cli)
885888
.build()
886889
.expect("Failed to build UserInputMessage")
@@ -976,6 +979,7 @@ mod tests {
976979
})]),
977980
}),
978981
user_intent: Some(UserIntent::ApplyCommonBestPractices),
982+
model_id: Some("model id".to_string()),
979983
};
980984

981985
let codewhisper_input =
@@ -989,6 +993,7 @@ mod tests {
989993
content: "test content".to_string(),
990994
user_input_message_context: None,
991995
user_intent: None,
996+
model_id: Some("model id".to_string()),
992997
};
993998

994999
let codewhisper_minimal =

crates/chat-cli/src/cli/chat/command.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ pub enum Command {
5959
force: bool,
6060
},
6161
Mcp,
62+
Model,
6263
}
6364

6465
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -839,6 +840,7 @@ impl Command {
839840
Self::Save { path, force }
840841
},
841842
"mcp" => Self::Mcp,
843+
"model" => Self::Model,
842844
unknown_command => {
843845
let looks_like_path = {
844846
let after_slash_command_str = parts[1..].join(" ");

crates/chat-cli/src/cli/chat/conversation_state.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ pub struct ConversationState {
105105
latest_summary: Option<String>,
106106
#[serde(skip)]
107107
pub updates: Option<SharedWriter>,
108+
/// Model explicitly selected by the user in this conversation state via `/model`.
109+
#[serde(default, skip_serializing_if = "Option::is_none")]
110+
pub current_model_id: Option<String>,
108111
}
109112

110113
impl ConversationState {
@@ -115,6 +118,7 @@ impl ConversationState {
115118
profile: Option<String>,
116119
updates: Option<SharedWriter>,
117120
tool_manager: ToolManager,
121+
current_model_id: Option<String>,
118122
) -> Self {
119123
// Initialize context manager
120124
let context_manager = match ContextManager::new(ctx, None).await {
@@ -157,6 +161,7 @@ impl ConversationState {
157161
context_message_length: None,
158162
latest_summary: None,
159163
updates,
164+
current_model_id,
160165
}
161166
}
162167

@@ -528,6 +533,7 @@ impl ConversationState {
528533
context_messages,
529534
dropped_context_files,
530535
tools: &self.tools,
536+
model_id: self.current_model_id.as_deref(),
531537
}
532538
}
533539

@@ -599,6 +605,7 @@ impl ConversationState {
599605
user_input_message_context: None,
600606
user_intent: None,
601607
images: None,
608+
model_id: self.current_model_id.clone(),
602609
};
603610

604611
// If the last message contains tool uses, then add cancelled tool results to the summary
@@ -830,6 +837,7 @@ pub struct BackendConversationStateImpl<'a, T, U> {
830837
pub context_messages: U,
831838
pub dropped_context_files: Vec<(String, String)>,
832839
pub tools: &'a HashMap<ToolOrigin, Vec<Tool>>,
840+
pub model_id: Option<&'a str>,
833841
}
834842

835843
impl
@@ -846,6 +854,7 @@ impl
846854
.cloned()
847855
.map(UserMessage::into_user_input_message)
848856
.ok_or(eyre::eyre!("next user message is not set"))?;
857+
user_input_message.model_id = self.model_id.map(str::to_string);
849858
if let Some(ctx) = user_input_message.user_input_message_context.as_mut() {
850859
ctx.tools = Some(self.tools.values().flatten().cloned().collect::<Vec<_>>());
851860
}
@@ -1059,6 +1068,7 @@ mod tests {
10591068
None,
10601069
None,
10611070
tool_manager,
1071+
None,
10621072
)
10631073
.await;
10641074

@@ -1089,6 +1099,7 @@ mod tests {
10891099
None,
10901100
None,
10911101
tool_manager.clone(),
1102+
None,
10921103
)
10931104
.await;
10941105
conversation_state.set_next_user_message("start".to_string()).await;
@@ -1120,6 +1131,7 @@ mod tests {
11201131
None,
11211132
None,
11221133
tool_manager.clone(),
1134+
None,
11231135
)
11241136
.await;
11251137
conversation_state.set_next_user_message("start".to_string()).await;
@@ -1165,6 +1177,7 @@ mod tests {
11651177
None,
11661178
None,
11671179
tool_manager,
1180+
None,
11681181
)
11691182
.await;
11701183

@@ -1235,6 +1248,7 @@ mod tests {
12351248
None,
12361249
Some(SharedWriter::stdout()),
12371250
tool_manager,
1251+
None,
12381252
)
12391253
.await;
12401254

crates/chat-cli/src/cli/chat/message.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ impl UserMessage {
128128
..Default::default()
129129
}),
130130
user_intent: None,
131+
model_id: None,
131132
}
132133
}
133134

@@ -158,6 +159,7 @@ impl UserMessage {
158159
..Default::default()
159160
}),
160161
user_intent: None,
162+
model_id: None,
161163
}
162164
}
163165

0 commit comments

Comments
 (0)