3 changes: 3 additions & 0 deletions crates/chat-cli/src/api_client/clients/streaming_client.rs
@@ -291,6 +291,7 @@ mod tests {
.send_message(ConversationState {
conversation_id: None,
user_input_message: UserInputMessage {
images: None,
content: "Hello".into(),
user_input_message_context: None,
user_intent: None,
@@ -315,12 +316,14 @@ mod tests {
.send_message(ConversationState {
conversation_id: None,
user_input_message: UserInputMessage {
images: None,
content: "How about rustc?".into(),
user_input_message_context: None,
user_intent: None,
},
history: Some(vec![
ChatMessage::UserInputMessage(UserInputMessage {
images: None,
content: "What language is the linux kernel written in, and who wrote it?".into(),
user_input_message_context: None,
user_intent: None,
107 changes: 106 additions & 1 deletion crates/chat-cli/src/api_client/model.rs
@@ -1,4 +1,7 @@
use aws_smithy_types::Document;
use aws_smithy_types::{
Blob,
Document,
};
use serde::{
Deserialize,
Serialize,
@@ -565,17 +568,113 @@ impl From<GitState> for amzn_qdeveloper_streaming_client::types::GitState {
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageBlock {
pub format: ImageFormat,
pub source: ImageSource,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ImageFormat {
Gif,
Jpeg,
Png,
Webp,
}

impl std::str::FromStr for ImageFormat {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.trim().to_lowercase().as_str() {
"gif" => Ok(ImageFormat::Gif),
"jpeg" => Ok(ImageFormat::Jpeg),
"jpg" => Ok(ImageFormat::Jpeg),
"png" => Ok(ImageFormat::Png),
"webp" => Ok(ImageFormat::Webp),
_ => Err(format!("Failed to parse '{}' as ImageFormat", s)),
}
}
}

impl From<ImageFormat> for amzn_codewhisperer_streaming_client::types::ImageFormat {
fn from(value: ImageFormat) -> Self {
match value {
ImageFormat::Gif => Self::Gif,
ImageFormat::Jpeg => Self::Jpeg,
ImageFormat::Png => Self::Png,
ImageFormat::Webp => Self::Webp,
}
}
}
impl From<ImageFormat> for amzn_qdeveloper_streaming_client::types::ImageFormat {
fn from(value: ImageFormat) -> Self {
match value {
ImageFormat::Gif => Self::Gif,
ImageFormat::Jpeg => Self::Jpeg,
ImageFormat::Png => Self::Png,
ImageFormat::Webp => Self::Webp,
}
}
}

#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ImageSource {
Bytes(Vec<u8>),
#[non_exhaustive]
Unknown,
}

impl From<ImageSource> for amzn_codewhisperer_streaming_client::types::ImageSource {
fn from(value: ImageSource) -> Self {
match value {
ImageSource::Bytes(bytes) => Self::Bytes(Blob::new(bytes)),
ImageSource::Unknown => Self::Unknown,
}
}
}
impl From<ImageSource> for amzn_qdeveloper_streaming_client::types::ImageSource {
fn from(value: ImageSource) -> Self {
match value {
ImageSource::Bytes(bytes) => Self::Bytes(Blob::new(bytes)),
ImageSource::Unknown => Self::Unknown,
}
}
}

impl From<ImageBlock> for amzn_codewhisperer_streaming_client::types::ImageBlock {
fn from(value: ImageBlock) -> Self {
Self::builder()
.format(value.format.into())
.source(value.source.into())
.build()
.expect("Failed to build ImageBlock")
}
}
impl From<ImageBlock> for amzn_qdeveloper_streaming_client::types::ImageBlock {
fn from(value: ImageBlock) -> Self {
Self::builder()
.format(value.format.into())
.source(value.source.into())
.build()
.expect("Failed to build ImageBlock")
}
}

#[derive(Debug, Clone)]
pub struct UserInputMessage {
pub content: String,
pub user_input_message_context: Option<UserInputMessageContext>,
pub user_intent: Option<UserIntent>,
pub images: Option<Vec<ImageBlock>>,
}

impl From<UserInputMessage> for amzn_codewhisperer_streaming_client::types::UserInputMessage {
fn from(value: UserInputMessage) -> Self {
Self::builder()
.content(value.content)
.set_images(value.images.map(|images| images.into_iter().map(Into::into).collect()))
.set_user_input_message_context(value.user_input_message_context.map(Into::into))
.set_user_intent(value.user_intent.map(Into::into))
.origin(amzn_codewhisperer_streaming_client::types::Origin::Cli)
@@ -588,6 +687,7 @@ impl From<UserInputMessage> for amzn_qdeveloper_streaming_client::types::UserInp
fn from(value: UserInputMessage) -> Self {
Self::builder()
.content(value.content)
.set_images(value.images.map(|images| images.into_iter().map(Into::into).collect()))
.set_user_input_message_context(value.user_input_message_context.map(Into::into))
.set_user_intent(value.user_intent.map(Into::into))
.origin(amzn_qdeveloper_streaming_client::types::Origin::Cli)
@@ -654,6 +754,10 @@ mod tests {
#[test]
fn build_user_input_message() {
let user_input_message = UserInputMessage {
images: Some(vec![ImageBlock {
format: ImageFormat::Png,
source: ImageSource::Bytes(vec![1, 2, 3]),
}]),
content: "test content".to_string(),
user_input_message_context: Some(UserInputMessageContext {
env_state: Some(EnvState {
@@ -690,6 +794,7 @@
assert_eq!(format!("{codewhisper_input:?}"), format!("{qdeveloper_input:?}"));

let minimal_message = UserInputMessage {
images: None,
content: "test content".to_string(),
user_input_message_context: None,
user_intent: None,
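Note on the new types in `model.rs`: `ImageFormat`, `ImageSource`, and `ImageBlock` compose naturally when loading an image from disk. A minimal sketch, assuming the types are reachable at `crate::api_client::model`; the `read_image` helper and its error handling are illustrative, not part of this change:

```rust
use std::path::Path;
use std::str::FromStr;

use crate::api_client::model::{ImageBlock, ImageFormat, ImageSource};

/// Illustrative helper: build an ImageBlock from a file on disk.
fn read_image(path: &Path) -> std::io::Result<ImageBlock> {
    // The FromStr impl above is case-insensitive and accepts "jpg" as an alias for Jpeg.
    let ext = path.extension().and_then(|e| e.to_str()).unwrap_or_default();
    let format = ImageFormat::from_str(ext)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
    let bytes = std::fs::read(path)?;
    Ok(ImageBlock {
        format,
        source: ImageSource::Bytes(bytes),
    })
}
```

From there, the `From` impls above convert the block into either streaming client's `ImageBlock` with a plain `.into()`.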
5 changes: 5 additions & 0 deletions crates/chat-cli/src/cli/chat/consts.rs
@@ -17,3 +17,8 @@ pub const MAX_USER_MESSAGE_SIZE: usize = 600_000;
pub const CONTEXT_WINDOW_SIZE: usize = 200_000;

pub const MAX_CHARS: usize = TokenCounter::token_to_chars(CONTEXT_WINDOW_SIZE); // Character-based warning threshold

pub const MAX_NUMBER_OF_IMAGES_PER_REQUEST: usize = 10;

/// In bytes - 10 MB
pub const MAX_IMAGE_SIZE: usize = 10 * 1024 * 1024;
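A hedged sketch of how these limits could be enforced before images are attached to a request; `validate_images` is a hypothetical helper, not something this diff adds, and the constants are repeated locally only to keep the sketch self-contained:

```rust
/// Hypothetical pre-flight check mirroring the limits defined in consts.rs.
fn validate_images(images: &[Vec<u8>]) -> Result<(), String> {
    const MAX_NUMBER_OF_IMAGES_PER_REQUEST: usize = 10;
    const MAX_IMAGE_SIZE: usize = 10 * 1024 * 1024; // 10 MB, in bytes

    if images.len() > MAX_NUMBER_OF_IMAGES_PER_REQUEST {
        return Err(format!(
            "a request may include at most {} images, got {}",
            MAX_NUMBER_OF_IMAGES_PER_REQUEST,
            images.len()
        ));
    }
    if let Some(big) = images.iter().find(|i| i.len() > MAX_IMAGE_SIZE) {
        return Err(format!(
            "an image of {} bytes exceeds the {}-byte limit",
            big.len(),
            MAX_IMAGE_SIZE
        ));
    }
    Ok(())
}
```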
7 changes: 7 additions & 0 deletions crates/chat-cli/src/cli/chat/conversation_state.rs
@@ -42,6 +42,7 @@ use crate::api_client::model::{
AssistantResponseMessage,
ChatMessage,
ConversationState as FigConversationState,
ImageBlock,
Tool,
ToolInputSchema,
ToolResult,
@@ -294,6 +295,11 @@ impl ConversationState {
self.next_message = Some(UserMessage::new_tool_use_results(tool_results));
}

pub fn add_tool_results_with_images(&mut self, tool_results: Vec<ToolUseResult>, images: Vec<ImageBlock>) {
debug_assert!(self.next_message.is_none());
self.next_message = Some(UserMessage::new_tool_use_results_with_images(tool_results, images));
}

/// Sets the next user message with "cancelled" tool results.
pub fn abandon_tool_use(&mut self, tools_to_be_abandoned: Vec<QueuedTool>, deny_input: String) {
self.next_message = Some(UserMessage::new_cancelled_tool_uses(
@@ -415,6 +421,7 @@ impl ConversationState {
content: summary_content,
user_input_message_context: None,
user_intent: None,
images: None,
};

// If the last message contains tool uses, then add cancelled tool results to the summary
19 changes: 19 additions & 0 deletions crates/chat-cli/src/cli/chat/message.rs
@@ -17,6 +17,7 @@ use super::util::truncate_safe;
use crate::api_client::model::{
AssistantResponseMessage,
EnvState,
ImageBlock,
ToolResult,
ToolResultContentBlock,
ToolResultStatus,
@@ -33,6 +34,7 @@ pub struct UserMessage {
pub additional_context: String,
pub env_context: UserEnvContext,
pub content: UserMessageContent,
pub images: Option<Vec<ImageBlock>>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -56,6 +58,7 @@ impl UserMessage {
/// environment [UserEnvContext].
pub fn new_prompt(prompt: String) -> Self {
Self {
images: None,
additional_context: String::new(),
env_context: UserEnvContext::generate_new(),
content: UserMessageContent::Prompt { prompt },
@@ -64,6 +67,7 @@

pub fn new_cancelled_tool_uses<'a>(prompt: Option<String>, tool_use_ids: impl Iterator<Item = &'a str>) -> Self {
Self {
images: None,
additional_context: String::new(),
env_context: UserEnvContext::generate_new(),
content: UserMessageContent::CancelledToolUses {
@@ -88,13 +92,26 @@
content: UserMessageContent::ToolUseResults {
tool_use_results: results,
},
images: None,
}
}

pub fn new_tool_use_results_with_images(results: Vec<ToolUseResult>, images: Vec<ImageBlock>) -> Self {
Self {
additional_context: String::new(),
env_context: UserEnvContext::generate_new(),
content: UserMessageContent::ToolUseResults {
tool_use_results: results,
},
images: Some(images),
}
}

/// Converts this message into a [UserInputMessage] to be stored in the history of
/// [api_client::model::ConversationState].
pub fn into_history_entry(self) -> UserInputMessage {
UserInputMessage {
images: None,
content: self.prompt().unwrap_or_default().to_string(),
user_input_message_context: Some(UserInputMessageContext {
env_state: self.env_context.env_state,
@@ -122,6 +139,7 @@ impl UserMessage {
_ => String::new(),
};
UserInputMessage {
images: self.images,
content: format!("{} {}", self.additional_context, formatted_prompt)
.trim()
.to_string(),
@@ -232,6 +250,7 @@ impl From<InvokeOutput> for ToolUseResultBlock {
match value.output {
OutputKind::Text(text) => Self::Text(text),
OutputKind::Json(value) => Self::Json(value),
OutputKind::Images(_) => Self::Text("See images data supplied".to_string()),
}
}
}
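Taken together with `into_history_entry` above (which stores `images: None`), tool-use images ride along on the next outgoing request but are not persisted into conversation history. A minimal sketch of the new constructor; the module paths and the `ToolUseResult` import location are assumptions for illustration:

```rust
use crate::api_client::model::{ImageBlock, ImageFormat, ImageSource};
use crate::cli::chat::message::{ToolUseResult, UserMessage};

/// Illustrative: attach a PNG screenshot produced by a tool to its results.
fn tool_results_with_screenshot(results: Vec<ToolUseResult>, screenshot: Vec<u8>) -> UserMessage {
    let image = ImageBlock {
        format: ImageFormat::Png,
        source: ImageSource::Bytes(screenshot),
    };
    // The images are sent with the next UserInputMessage; into_history_entry()
    // drops them, so they do not accumulate in the stored history.
    UserMessage::new_tool_use_results_with_images(results, vec![image])
}
```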
39 changes: 29 additions & 10 deletions crates/chat-cli/src/cli/chat/mod.rs
@@ -156,6 +157,7 @@ use tool_manager::{
};
use tools::gh_issue::GhIssueContext;
use tools::{
OutputKind,
QueuedTool,
Tool,
ToolPermissions,
@@ -169,6 +170,7 @@ use tracing::{
warn,
};
use unicode_width::UnicodeWidthStr;
use util::images::RichImageBlock;
use util::{
animate_output,
play_notification_bell,
@@ -1283,14 +1285,6 @@ impl ChatContext {
// Otherwise continue with normal chat on 'n' or other responses
self.tool_use_status = ToolUseStatus::Idle;

if pending_tool_index.is_some() {
self.conversation_state.abandon_tool_use(tool_uses, user_input);
} else {
self.conversation_state.set_next_user_message(user_input).await;
}

let conv_state = self.conversation_state.as_sendable_conversation_state(true).await;

if self.interactive {
queue!(self.output, style::SetForegroundColor(Color::Magenta))?;
queue!(self.output, style::SetForegroundColor(Color::Reset))?;
@@ -1299,6 +1293,13 @@
self.spinner = Some(Spinner::new(Spinners::Dots, "Thinking...".to_owned()));
}

if pending_tool_index.is_some() {
self.conversation_state.abandon_tool_use(tool_uses, user_input);
} else {
self.conversation_state.set_next_user_message(user_input).await;
}

let conv_state = self.conversation_state.as_sendable_conversation_state(true).await;
self.send_tool_use_telemetry().await;

ChatState::HandleResponseStream(self.client.send_message(conv_state).await?)
@@ -2673,6 +2674,7 @@ impl ChatContext {

// Execute the requested tools.
let mut tool_results = vec![];
let mut image_blocks: Vec<RichImageBlock> = Vec::new();

for tool in tool_uses {
let mut tool_telemetry = self.tool_use_telemetry_events.entry(tool.id.clone());
@@ -2700,9 +2702,20 @@
});
}
let tool_time = format!("{}.{}", tool_time.as_secs(), tool_time.subsec_millis());

match invoke_result {
Ok(result) => {
match result.output {
OutputKind::Text(ref text) => {
debug!("Output is Text: {}", text);
},
OutputKind::Json(ref json) => {
debug!("Output is JSON: {}", json);
},
OutputKind::Images(ref image) => {
image_blocks.extend(image.clone());
},
}

debug!("tool result output: {:#?}", result);
execute!(
self.output,
@@ -2762,7 +2775,13 @@
}
}

self.conversation_state.add_tool_results(tool_results);
if !image_blocks.is_empty() {
let images = image_blocks.into_iter().map(|(block, _)| block).collect();
self.conversation_state
.add_tool_results_with_images(tool_results, images);
} else {
self.conversation_state.add_tool_results(tool_results);
}

self.send_tool_use_telemetry().await;
return Ok(ChatState::HandleResponseStream(
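The hunks above use `OutputKind::Images` and `util::images::RichImageBlock` without showing their definitions. From the call sites alone (`image.clone()` extending a `Vec<RichImageBlock>`, and the `(block, _)` destructuring before `add_tool_results_with_images`), a plausible shape is sketched below; treat it as an inference, not the actual definitions in `tools.rs` and `util/images.rs`:

```rust
use crate::api_client::model::ImageBlock;

// Inferred: an ImageBlock paired with secondary data (e.g. a description),
// which mod.rs discards via `(block, _)` before sending.
pub type RichImageBlock = (ImageBlock, String);

// Inferred from the match in mod.rs; the real enum may differ or carry more variants.
pub enum OutputKind {
    Text(String),
    Json(serde_json::Value),
    Images(Vec<RichImageBlock>),
}
```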