diff --git a/crates/chat-cli/src/api_client/clients/streaming_client.rs b/crates/chat-cli/src/api_client/clients/streaming_client.rs index 2f262c08f0..df7064f732 100644 --- a/crates/chat-cli/src/api_client/clients/streaming_client.rs +++ b/crates/chat-cli/src/api_client/clients/streaming_client.rs @@ -291,6 +291,7 @@ mod tests { .send_message(ConversationState { conversation_id: None, user_input_message: UserInputMessage { + images: None, content: "Hello".into(), user_input_message_context: None, user_intent: None, @@ -315,12 +316,14 @@ mod tests { .send_message(ConversationState { conversation_id: None, user_input_message: UserInputMessage { + images: None, content: "How about rustc?".into(), user_input_message_context: None, user_intent: None, }, history: Some(vec![ ChatMessage::UserInputMessage(UserInputMessage { + images: None, content: "What language is the linux kernel written in, and who wrote it?".into(), user_input_message_context: None, user_intent: None, diff --git a/crates/chat-cli/src/api_client/model.rs b/crates/chat-cli/src/api_client/model.rs index ab44cdb119..16457d09e6 100644 --- a/crates/chat-cli/src/api_client/model.rs +++ b/crates/chat-cli/src/api_client/model.rs @@ -1,4 +1,7 @@ -use aws_smithy_types::Document; +use aws_smithy_types::{ + Blob, + Document, +}; use serde::{ Deserialize, Serialize, @@ -565,17 +568,113 @@ impl From for amzn_qdeveloper_streaming_client::types::GitState { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImageBlock { + pub format: ImageFormat, + pub source: ImageSource, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum ImageFormat { + Gif, + Jpeg, + Png, + Webp, +} + +impl std::str::FromStr for ImageFormat { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.trim().to_lowercase().as_str() { + "gif" => Ok(ImageFormat::Gif), + "jpeg" => Ok(ImageFormat::Jpeg), + "jpg" => Ok(ImageFormat::Jpeg), + "png" => Ok(ImageFormat::Png), + "webp" => Ok(ImageFormat::Webp), + _ => Err(format!("Failed to parse '{}' as ImageFormat", s)), + } + } +} + +impl From for amzn_codewhisperer_streaming_client::types::ImageFormat { + fn from(value: ImageFormat) -> Self { + match value { + ImageFormat::Gif => Self::Gif, + ImageFormat::Jpeg => Self::Jpeg, + ImageFormat::Png => Self::Png, + ImageFormat::Webp => Self::Webp, + } + } +} +impl From for amzn_qdeveloper_streaming_client::types::ImageFormat { + fn from(value: ImageFormat) -> Self { + match value { + ImageFormat::Gif => Self::Gif, + ImageFormat::Jpeg => Self::Jpeg, + ImageFormat::Png => Self::Png, + ImageFormat::Webp => Self::Webp, + } + } +} + +#[non_exhaustive] +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ImageSource { + Bytes(Vec), + #[non_exhaustive] + Unknown, +} + +impl From for amzn_codewhisperer_streaming_client::types::ImageSource { + fn from(value: ImageSource) -> Self { + match value { + ImageSource::Bytes(bytes) => Self::Bytes(Blob::new(bytes)), + ImageSource::Unknown => Self::Unknown, + } + } +} +impl From for amzn_qdeveloper_streaming_client::types::ImageSource { + fn from(value: ImageSource) -> Self { + match value { + ImageSource::Bytes(bytes) => Self::Bytes(Blob::new(bytes)), + ImageSource::Unknown => Self::Unknown, + } + } +} + +impl From for amzn_codewhisperer_streaming_client::types::ImageBlock { + fn from(value: ImageBlock) -> Self { + Self::builder() + .format(value.format.into()) + .source(value.source.into()) + .build() + .expect("Failed to build ImageBlock") + } +} +impl From for amzn_qdeveloper_streaming_client::types::ImageBlock { + fn from(value: ImageBlock) -> Self { + Self::builder() + .format(value.format.into()) + .source(value.source.into()) + .build() + .expect("Failed to build ImageBlock") + } +} + #[derive(Debug, Clone)] pub struct UserInputMessage { pub content: String, pub user_input_message_context: Option, pub user_intent: Option, + pub images: Option>, } impl From for amzn_codewhisperer_streaming_client::types::UserInputMessage { fn from(value: UserInputMessage) -> Self { Self::builder() .content(value.content) + .set_images(value.images.map(|images| images.into_iter().map(Into::into).collect())) .set_user_input_message_context(value.user_input_message_context.map(Into::into)) .set_user_intent(value.user_intent.map(Into::into)) .origin(amzn_codewhisperer_streaming_client::types::Origin::Cli) @@ -588,6 +687,7 @@ impl From for amzn_qdeveloper_streaming_client::types::UserInp fn from(value: UserInputMessage) -> Self { Self::builder() .content(value.content) + .set_images(value.images.map(|images| images.into_iter().map(Into::into).collect())) .set_user_input_message_context(value.user_input_message_context.map(Into::into)) .set_user_intent(value.user_intent.map(Into::into)) .origin(amzn_qdeveloper_streaming_client::types::Origin::Cli) @@ -654,6 +754,10 @@ mod tests { #[test] fn build_user_input_message() { let user_input_message = UserInputMessage { + images: Some(vec![ImageBlock { + format: ImageFormat::Png, + source: ImageSource::Bytes(vec![1, 2, 3]), + }]), content: "test content".to_string(), user_input_message_context: Some(UserInputMessageContext { env_state: Some(EnvState { @@ -690,6 +794,7 @@ mod tests { assert_eq!(format!("{codewhisper_input:?}"), format!("{qdeveloper_input:?}")); let minimal_message = UserInputMessage { + images: None, content: "test content".to_string(), user_input_message_context: None, user_intent: None, diff --git a/crates/chat-cli/src/cli/chat/consts.rs b/crates/chat-cli/src/cli/chat/consts.rs index 6850f7efab..2880cdc65e 100644 --- a/crates/chat-cli/src/cli/chat/consts.rs +++ b/crates/chat-cli/src/cli/chat/consts.rs @@ -17,3 +17,8 @@ pub const MAX_USER_MESSAGE_SIZE: usize = 600_000; pub const CONTEXT_WINDOW_SIZE: usize = 200_000; pub const MAX_CHARS: usize = TokenCounter::token_to_chars(CONTEXT_WINDOW_SIZE); // Character-based warning threshold + +pub const MAX_NUMBER_OF_IMAGES_PER_REQUEST: usize = 10; + +/// In bytes - 10 MB +pub const MAX_IMAGE_SIZE: usize = 10 * 1024 * 1024; diff --git a/crates/chat-cli/src/cli/chat/conversation_state.rs b/crates/chat-cli/src/cli/chat/conversation_state.rs index dca132cdce..f860832fd7 100644 --- a/crates/chat-cli/src/cli/chat/conversation_state.rs +++ b/crates/chat-cli/src/cli/chat/conversation_state.rs @@ -42,6 +42,7 @@ use crate::api_client::model::{ AssistantResponseMessage, ChatMessage, ConversationState as FigConversationState, + ImageBlock, Tool, ToolInputSchema, ToolResult, @@ -294,6 +295,11 @@ impl ConversationState { self.next_message = Some(UserMessage::new_tool_use_results(tool_results)); } + pub fn add_tool_results_with_images(&mut self, tool_results: Vec, images: Vec) { + debug_assert!(self.next_message.is_none()); + self.next_message = Some(UserMessage::new_tool_use_results_with_images(tool_results, images)); + } + /// Sets the next user message with "cancelled" tool results. pub fn abandon_tool_use(&mut self, tools_to_be_abandoned: Vec, deny_input: String) { self.next_message = Some(UserMessage::new_cancelled_tool_uses( @@ -415,6 +421,7 @@ impl ConversationState { content: summary_content, user_input_message_context: None, user_intent: None, + images: None, }; // If the last message contains tool uses, then add cancelled tool results to the summary diff --git a/crates/chat-cli/src/cli/chat/message.rs b/crates/chat-cli/src/cli/chat/message.rs index 9cf7b700fc..053c3f76c3 100644 --- a/crates/chat-cli/src/cli/chat/message.rs +++ b/crates/chat-cli/src/cli/chat/message.rs @@ -17,6 +17,7 @@ use super::util::truncate_safe; use crate::api_client::model::{ AssistantResponseMessage, EnvState, + ImageBlock, ToolResult, ToolResultContentBlock, ToolResultStatus, @@ -33,6 +34,7 @@ pub struct UserMessage { pub additional_context: String, pub env_context: UserEnvContext, pub content: UserMessageContent, + pub images: Option>, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -56,6 +58,7 @@ impl UserMessage { /// environment [UserEnvContext]. pub fn new_prompt(prompt: String) -> Self { Self { + images: None, additional_context: String::new(), env_context: UserEnvContext::generate_new(), content: UserMessageContent::Prompt { prompt }, @@ -64,6 +67,7 @@ impl UserMessage { pub fn new_cancelled_tool_uses<'a>(prompt: Option, tool_use_ids: impl Iterator) -> Self { Self { + images: None, additional_context: String::new(), env_context: UserEnvContext::generate_new(), content: UserMessageContent::CancelledToolUses { @@ -88,6 +92,18 @@ impl UserMessage { content: UserMessageContent::ToolUseResults { tool_use_results: results, }, + images: None, + } + } + + pub fn new_tool_use_results_with_images(results: Vec, images: Vec) -> Self { + Self { + additional_context: String::new(), + env_context: UserEnvContext::generate_new(), + content: UserMessageContent::ToolUseResults { + tool_use_results: results, + }, + images: Some(images), } } @@ -95,6 +111,7 @@ impl UserMessage { /// [api_client::model::ConversationState]. pub fn into_history_entry(self) -> UserInputMessage { UserInputMessage { + images: None, content: self.prompt().unwrap_or_default().to_string(), user_input_message_context: Some(UserInputMessageContext { env_state: self.env_context.env_state, @@ -122,6 +139,7 @@ impl UserMessage { _ => String::new(), }; UserInputMessage { + images: self.images, content: format!("{} {}", self.additional_context, formatted_prompt) .trim() .to_string(), @@ -232,6 +250,7 @@ impl From for ToolUseResultBlock { match value.output { OutputKind::Text(text) => Self::Text(text), OutputKind::Json(value) => Self::Json(value), + OutputKind::Images(_) => Self::Text("See images data supplied".to_string()), } } } diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 42febea965..b0fb7ee9b7 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -156,6 +156,7 @@ use tool_manager::{ }; use tools::gh_issue::GhIssueContext; use tools::{ + OutputKind, QueuedTool, Tool, ToolPermissions, @@ -169,6 +170,7 @@ use tracing::{ warn, }; use unicode_width::UnicodeWidthStr; +use util::images::RichImageBlock; use util::{ animate_output, play_notification_bell, @@ -1283,14 +1285,6 @@ impl ChatContext { // Otherwise continue with normal chat on 'n' or other responses self.tool_use_status = ToolUseStatus::Idle; - if pending_tool_index.is_some() { - self.conversation_state.abandon_tool_use(tool_uses, user_input); - } else { - self.conversation_state.set_next_user_message(user_input).await; - } - - let conv_state = self.conversation_state.as_sendable_conversation_state(true).await; - if self.interactive { queue!(self.output, style::SetForegroundColor(Color::Magenta))?; queue!(self.output, style::SetForegroundColor(Color::Reset))?; @@ -1299,6 +1293,13 @@ impl ChatContext { self.spinner = Some(Spinner::new(Spinners::Dots, "Thinking...".to_owned())); } + if pending_tool_index.is_some() { + self.conversation_state.abandon_tool_use(tool_uses, user_input); + } else { + self.conversation_state.set_next_user_message(user_input).await; + } + + let conv_state = self.conversation_state.as_sendable_conversation_state(true).await; self.send_tool_use_telemetry().await; ChatState::HandleResponseStream(self.client.send_message(conv_state).await?) @@ -2673,6 +2674,7 @@ impl ChatContext { // Execute the requested tools. let mut tool_results = vec![]; + let mut image_blocks: Vec = Vec::new(); for tool in tool_uses { let mut tool_telemetry = self.tool_use_telemetry_events.entry(tool.id.clone()); @@ -2700,9 +2702,20 @@ impl ChatContext { }); } let tool_time = format!("{}.{}", tool_time.as_secs(), tool_time.subsec_millis()); - match invoke_result { Ok(result) => { + match result.output { + OutputKind::Text(ref text) => { + debug!("Output is Text: {}", text); + }, + OutputKind::Json(ref json) => { + debug!("Output is JSON: {}", json); + }, + OutputKind::Images(ref image) => { + image_blocks.extend(image.clone()); + }, + } + debug!("tool result output: {:#?}", result); execute!( self.output, @@ -2762,7 +2775,13 @@ impl ChatContext { } } - self.conversation_state.add_tool_results(tool_results); + if !image_blocks.is_empty() { + let images = image_blocks.into_iter().map(|(block, _)| block).collect(); + self.conversation_state + .add_tool_results_with_images(tool_results, images); + } else { + self.conversation_state.add_tool_results(tool_results); + } self.send_tool_use_telemetry().await; return Ok(ChatState::HandleResponseStream( diff --git a/crates/chat-cli/src/cli/chat/tools/fs_read.rs b/crates/chat-cli/src/cli/chat/tools/fs_read.rs index b2904af339..99a0f7f43f 100644 --- a/crates/chat-cli/src/cli/chat/tools/fs_read.rs +++ b/crates/chat-cli/src/cli/chat/tools/fs_read.rs @@ -28,6 +28,11 @@ use super::{ format_path, sanitize_path_tool_arg, }; +use crate::cli::chat::util::images::{ + handle_images_from_paths, + is_supported_image_type, + pre_process, +}; use crate::platform::Context; #[derive(Debug, Clone, Deserialize)] @@ -36,6 +41,7 @@ pub enum FsRead { Line(FsLine), Directory(FsDirectory), Search(FsSearch), + Image(FsImage), } impl FsRead { @@ -44,6 +50,7 @@ impl FsRead { FsRead::Line(fs_line) => fs_line.validate(ctx).await, FsRead::Directory(fs_directory) => fs_directory.validate(ctx).await, FsRead::Search(fs_search) => fs_search.validate(ctx).await, + FsRead::Image(fs_image) => fs_image.validate(ctx).await, } } @@ -52,6 +59,7 @@ impl FsRead { FsRead::Line(fs_line) => fs_line.queue_description(ctx, updates).await, FsRead::Directory(fs_directory) => fs_directory.queue_description(updates), FsRead::Search(fs_search) => fs_search.queue_description(updates), + FsRead::Image(fs_image) => fs_image.queue_description(updates), } } @@ -60,7 +68,54 @@ impl FsRead { FsRead::Line(fs_line) => fs_line.invoke(ctx, updates).await, FsRead::Directory(fs_directory) => fs_directory.invoke(ctx, updates).await, FsRead::Search(fs_search) => fs_search.invoke(ctx, updates).await, + FsRead::Image(fs_image) => fs_image.invoke(ctx, updates).await, + } + } +} + +/// Read images from given paths. +#[derive(Debug, Clone, Deserialize)] +pub struct FsImage { + pub image_paths: Vec, +} + +impl FsImage { + pub async fn validate(&mut self, ctx: &Context) -> Result<()> { + for path in &self.image_paths { + let path = sanitize_path_tool_arg(ctx, path); + if let Some(path) = path.to_str() { + let processed_path = pre_process(ctx, path); + if !is_supported_image_type(&processed_path) { + bail!("'{}' is not a supported image type", &processed_path); + } + let is_file = ctx.fs().symlink_metadata(&processed_path).await?.is_file(); + if !is_file { + bail!("'{}' is not a file", &processed_path); + } + } else { + bail!("Unable to parse path"); + } } + Ok(()) + } + + pub async fn invoke(&self, ctx: &Context, updates: &mut impl Write) -> Result { + let pre_processed_paths: Vec = self.image_paths.iter().map(|path| pre_process(ctx, path)).collect(); + let valid_images = handle_images_from_paths(updates, &pre_processed_paths); + Ok(InvokeOutput { + output: OutputKind::Images(valid_images), + }) + } + + pub fn queue_description(&self, updates: &mut impl Write) -> Result<()> { + queue!( + updates, + style::Print("Reading images: \n"), + style::SetForegroundColor(Color::Green), + style::Print(&self.image_paths.join("\n")), + style::ResetColor, + )?; + Ok(()) } } diff --git a/crates/chat-cli/src/cli/chat/tools/mod.rs b/crates/chat-cli/src/cli/chat/tools/mod.rs index 247b6122ec..e558e10bea 100644 --- a/crates/chat-cli/src/cli/chat/tools/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/mod.rs @@ -32,6 +32,7 @@ use thinking::Thinking; use use_aws::UseAws; use super::consts::MAX_TOOL_RESPONSE_SIZE; +use super::util::images::RichImageBlocks; use crate::platform::Context; /// Represents an executable tool use. @@ -246,6 +247,7 @@ impl InvokeOutput { match &self.output { OutputKind::Text(s) => s.as_str(), OutputKind::Json(j) => j.as_str().unwrap_or_default(), + OutputKind::Images(_) => "", } } } @@ -255,6 +257,7 @@ impl InvokeOutput { pub enum OutputKind { Text(String), Json(serde_json::Value), + Images(RichImageBlocks), } impl Default for OutputKind { @@ -320,7 +323,7 @@ pub fn document_to_serde_value(value: Document) -> serde_json::Value { /// /// Required since path arguments are defined by the model. #[allow(dead_code)] -fn sanitize_path_tool_arg(ctx: &Context, path: impl AsRef) -> PathBuf { +pub fn sanitize_path_tool_arg(ctx: &Context, path: impl AsRef) -> PathBuf { let mut res = PathBuf::new(); // Expand `~` only if it is the first part. let mut path = path.as_ref().components(); diff --git a/crates/chat-cli/src/cli/chat/tools/tool_index.json b/crates/chat-cli/src/cli/chat/tools/tool_index.json index b3ff683bc1..b4b5279a69 100644 --- a/crates/chat-cli/src/cli/chat/tools/tool_index.json +++ b/crates/chat-cli/src/cli/chat/tools/tool_index.json @@ -19,7 +19,7 @@ }, "fs_read": { "name": "fs_read", - "description": "Tool for reading files (for example, `cat -n`) and directories (for example, `ls -la`). The behavior of this tool is determined by the `mode` parameter. The available modes are:\n- line: Show lines in a file, given by an optional `start_line` and optional `end_line`.\n- directory: List directory contents. Content is returned in the \"long format\" of ls (that is, `ls -la`).\n- search: Search for a pattern in a file. The pattern is a string. The matching is case insensitive.\n\nExample Usage:\n1. Read all lines from a file: command=\"line\", path=\"/path/to/file.txt\"\n2. Read the last 5 lines from a file: command=\"line\", path=\"/path/to/file.txt\", start_line=-5\n3. List the files in the home directory: command=\"line\", path=\"~\"\n4. Recursively list files in a directory to a max depth of 2: command=\"line\", path=\"/path/to/directory\", depth=2\n5. Search for all instances of \"test\" in a file: command=\"search\", path=\"/path/to/file.txt\", pattern=\"test\"\n", + "description": "Tool for reading files (for example, `cat -n`), directories (for example, `ls -la`) and images. If user has supplied paths that appear to be leading to images, you should use this tool right away using Image mode. The behavior of this tool is determined by the `mode` parameter. The available modes are:\n- line: Show lines in a file, given by an optional `start_line` and optional `end_line`.\n- directory: List directory contents. Content is returned in the \"long format\" of ls (that is, `ls -la`).\n- search: Search for a pattern in a file. The pattern is a string. The matching is case insensitive.\n\nExample Usage:\n1. Read all lines from a file: command=\"line\", path=\"/path/to/file.txt\"\n2. Read the last 5 lines from a file: command=\"line\", path=\"/path/to/file.txt\", start_line=-5\n3. List the files in the home directory: command=\"line\", path=\"~\"\n4. Recursively list files in a directory to a max depth of 2: command=\"line\", path=\"/path/to/directory\", depth=2\n5. Search for all instances of \"test\" in a file: command=\"search\", path=\"/path/to/file.txt\", pattern=\"test\"\n", "input_schema": { "type": "object", "properties": { @@ -27,10 +27,22 @@ "description": "Path to the file or directory. The path should be absolute, or otherwise start with ~ for the user's home.", "type": "string" }, + "image_paths": { + "description": "List of paths to the images. This is currently supported by the Image mode.", + "type": "array", + "items": { + "type": "string" + } + }, "mode": { "type": "string", - "enum": ["Line", "Directory", "Search"], - "description": "The mode to run in: `Line`, `Directory`, `Search`. `Line` and `Search` are only for text files, and `Directory` is only for directories." + "enum": [ + "Line", + "Directory", + "Search", + "Image" + ], + "description": "The mode to run in: `Line`, `Directory`, `Search`. `Line` and `Search` are only for text files, and `Directory` is only for directories. `Image` is for image files, in this mode `image_paths` is required." }, "start_line": { "type": "integer", diff --git a/crates/chat-cli/src/cli/chat/util/images.rs b/crates/chat-cli/src/cli/chat/util/images.rs new file mode 100644 index 0000000000..e3aa8ca9ab --- /dev/null +++ b/crates/chat-cli/src/cli/chat/util/images.rs @@ -0,0 +1,301 @@ +use std::fs; +use std::io::Write; +use std::path::Path; +use std::str::FromStr; + +use crossterm::execute; +use crossterm::style::{ + self, + Color, +}; +use serde::{ + Deserialize, + Serialize, +}; + +use crate::api_client::model::{ + ImageBlock, + ImageFormat, + ImageSource, +}; +use crate::cli::chat::consts::{ + MAX_IMAGE_SIZE, + MAX_NUMBER_OF_IMAGES_PER_REQUEST, +}; +use crate::platform::{ + self, + Context, +}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ImageMetadata { + pub filepath: String, + /// The size of the image in bytes + pub size: u64, + pub filename: String, +} + +pub type RichImageBlocks = Vec; +pub type RichImageBlock = (ImageBlock, ImageMetadata); + +/// Macos screenshots insert a NNBSP character rather than a space between the timestamp and AM/PM +/// part. An example of a screenshot name is: /path-to/Screenshot 2025-03-13 at 1.46.32 PM.png +/// +/// However, the model will just treat it as a normal space and return the wrong path string to the +/// `fs_read` tool. This will lead to file-not-found errors. +pub fn pre_process(ctx: &Context, path: &str) -> String { + if ctx.platform().os() == platform::Os::Mac && path.contains("Screenshot") { + let mac_screenshot_regex = + regex::Regex::new(r"Screenshot \d{4}-\d{2}-\d{2} at \d{1,2}\.\d{2}\.\d{2} [AP]M").unwrap(); + if mac_screenshot_regex.is_match(path) { + if let Some(pos) = path.find(" at ") { + let mut new_path = String::new(); + new_path.push_str(&path[..pos + 4]); + new_path.push_str(&path[pos + 4..].replace(" ", "\u{202F}")); + return new_path; + } + } + } + + path.to_string() +} + +pub fn handle_images_from_paths(output: &mut impl Write, paths: &[String]) -> RichImageBlocks { + let mut extracted_images = Vec::new(); + let mut seen_args = std::collections::HashSet::new(); + + for path in paths.iter() { + if seen_args.contains(path) { + continue; + } + seen_args.insert(path); + if is_supported_image_type(path) { + if let Some(image_block) = get_image_block_from_file_path(path) { + let filename = Path::new(path) + .file_name() + .unwrap_or_default() + .to_string_lossy() + .to_string(); + + let image_size = fs::metadata(path).map(|m| m.len()).unwrap_or_default(); + + extracted_images.push((image_block, ImageMetadata { + filename, + filepath: path.to_string(), + size: image_size, + })); + } + } + } + + let (mut valid_images, images_exceeding_size_limit): (RichImageBlocks, RichImageBlocks) = extracted_images + .into_iter() + .partition(|(_, metadata)| metadata.size as usize <= MAX_IMAGE_SIZE); + + if valid_images.len() > MAX_NUMBER_OF_IMAGES_PER_REQUEST { + execute!( + &mut *output, + style::SetForegroundColor(Color::DarkYellow), + style::Print(format!( + "\nMore than {} images detected. Extra ones will be dropped.\n", + MAX_NUMBER_OF_IMAGES_PER_REQUEST + )), + style::SetForegroundColor(Color::Reset) + ) + .ok(); + valid_images.truncate(MAX_NUMBER_OF_IMAGES_PER_REQUEST); + } + + if !images_exceeding_size_limit.is_empty() { + execute!( + &mut *output, + style::SetForegroundColor(Color::DarkYellow), + style::Print(format!( + "\nThe following images are dropped due to exceeding size limit ({}MB):\n", + MAX_IMAGE_SIZE / (1024 * 1024) + )), + style::SetForegroundColor(Color::Reset) + ) + .ok(); + for (_, metadata) in &images_exceeding_size_limit { + let image_size_str = if metadata.size > 1024 * 1024 { + format!("{:.2} MB", metadata.size as f64 / (1024.0 * 1024.0)) + } else if metadata.size > 1024 { + format!("{:.2} KB", metadata.size as f64 / 1024.0) + } else { + format!("{} bytes", metadata.size) + }; + execute!( + &mut *output, + style::SetForegroundColor(Color::DarkYellow), + style::Print(format!(" - {} ({})\n", metadata.filename, image_size_str)), + style::SetForegroundColor(Color::Reset) + ) + .ok(); + } + } + valid_images +} + +/// This function checks if the file path has a supported image type +/// and returns true if it does, otherwise false. +/// Supported image types are: jpg, jpeg, png, gif, webp +/// +/// # Arguments +/// +/// * `maybe_file_path` - A string slice that may or may not be a valid file path +/// +/// # Returns +/// +/// * `true` if the file path has a supported image type +/// * `false` otherwise +pub fn is_supported_image_type(maybe_file_path: &str) -> bool { + let supported_image_types = ["jpg", "jpeg", "png", "gif", "webp"]; + if let Some(extension) = maybe_file_path.split('.').last() { + return supported_image_types.contains(&extension.trim().to_lowercase().as_str()); + } + false +} + +pub fn get_image_block_from_file_path(maybe_file_path: &str) -> Option { + if !is_supported_image_type(maybe_file_path) { + return None; + } + + let file_path = Path::new(maybe_file_path); + if !file_path.exists() { + return None; + } + + let image_bytes = fs::read(file_path); + if image_bytes.is_err() { + return None; + } + + let image_format = ImageFormat::from_str(file_path.extension()?.to_str()?.to_lowercase().as_str()); + + if image_format.is_err() { + return None; + } + + let image_bytes = image_bytes.unwrap(); + let image_block = ImageBlock { + format: image_format.unwrap(), + source: ImageSource::Bytes(image_bytes), + }; + Some(image_block) +} + +#[cfg(test)] +mod tests { + + use std::str::FromStr; + use std::sync::Arc; + + use bstr::ByteSlice; + + use super::*; + use crate::cli::chat::util::shared_writer::{ + SharedWriter, + TestWriterWithSink, + }; + + #[test] + fn test_is_supported_image_type() { + let test_cases = vec![ + ("image.jpg", true), + ("image.jpeg", true), + ("image.png", true), + ("image.gif", true), + ("image.webp", true), + ("image.txt", false), + ("image", false), + ]; + + for (path, expected) in test_cases { + assert_eq!(is_supported_image_type(path), expected, "Failed for path: {}", path); + } + } + + #[test] + fn test_get_image_format_from_ext() { + assert_eq!(ImageFormat::from_str("jpg"), Ok(ImageFormat::Jpeg)); + assert_eq!(ImageFormat::from_str("JPEG"), Ok(ImageFormat::Jpeg)); + assert_eq!(ImageFormat::from_str("png"), Ok(ImageFormat::Png)); + assert_eq!(ImageFormat::from_str("gif"), Ok(ImageFormat::Gif)); + assert_eq!(ImageFormat::from_str("webp"), Ok(ImageFormat::Webp)); + assert_eq!( + ImageFormat::from_str("txt"), + Err("Failed to parse 'txt' as ImageFormat".to_string()) + ); + } + + #[test] + fn test_handle_images_from_paths() { + let temp_dir = tempfile::tempdir().unwrap(); + let image_path = temp_dir.path().join("test_image.jpg"); + std::fs::write(&image_path, b"fake_image_data").unwrap(); + + let mut output = SharedWriter::stdout(); + + let images = handle_images_from_paths(&mut output, &[image_path.to_string_lossy().to_string()]); + + assert_eq!(images.len(), 1); + assert_eq!(images[0].1.filename, "test_image.jpg"); + assert_eq!(images[0].1.filepath, image_path.to_string_lossy()); + } + + #[test] + fn test_get_image_block_from_file_path() { + let temp_dir = tempfile::tempdir().unwrap(); + let image_path = temp_dir.path().join("test_image.png"); + std::fs::write(&image_path, b"fake_image_data").unwrap(); + + let image_block = get_image_block_from_file_path(&image_path.to_string_lossy()); + assert!(image_block.is_some()); + let image_block = image_block.unwrap(); + assert_eq!(image_block.format, ImageFormat::Png); + if let ImageSource::Bytes(bytes) = image_block.source { + assert_eq!(bytes, b"fake_image_data"); + } else { + panic!("Expected ImageSource::Bytes"); + } + } + + #[test] + fn test_handle_images_size_limit_exceeded() { + let temp_dir = tempfile::tempdir().unwrap(); + let large_image_path = temp_dir.path().join("large_image.jpg"); + let large_image_size = MAX_IMAGE_SIZE as usize + 1; + std::fs::write(&large_image_path, vec![0; large_image_size]).unwrap(); + let buf = Arc::new(std::sync::Mutex::new(Vec::::new())); + let test_writer = TestWriterWithSink { sink: buf.clone() }; + let mut output = SharedWriter::new(test_writer.clone()); + + let images = handle_images_from_paths(&mut output, &[large_image_path.to_string_lossy().to_string()]); + let content = test_writer.get_content(); + let output_str = content.to_str_lossy(); + print!("{}", output_str); + assert!(output_str.contains("The following images are dropped due to exceeding size limit (10MB):")); + assert!(output_str.contains("- large_image.jpg (10.00 MB)")); + assert!(images.is_empty()); + } + + #[test] + fn test_handle_images_number_exceeded() { + let temp_dir = tempfile::tempdir().unwrap(); + + let mut paths = vec![]; + for i in 0..(MAX_NUMBER_OF_IMAGES_PER_REQUEST + 2) { + let image_path = temp_dir.path().join(format!("image_{}.jpg", i)); + paths.push(image_path.to_string_lossy().to_string()); + std::fs::write(&image_path, b"fake_image_data").unwrap(); + } + + let mut output = SharedWriter::stdout(); + + let images = handle_images_from_paths(&mut output, &paths); + + assert_eq!(images.len(), MAX_NUMBER_OF_IMAGES_PER_REQUEST); + } +} diff --git a/crates/chat-cli/src/cli/chat/util/mod.rs b/crates/chat-cli/src/cli/chat/util/mod.rs index 187f7fdbeb..0543879b69 100644 --- a/crates/chat-cli/src/cli/chat/util/mod.rs +++ b/crates/chat-cli/src/cli/chat/util/mod.rs @@ -1,3 +1,4 @@ +pub mod images; pub mod issue; pub mod shared_writer; pub mod ui; diff --git a/crates/chat-cli/src/platform/mod.rs b/crates/chat-cli/src/platform/mod.rs index 55c55f06e8..b62fdaa101 100644 --- a/crates/chat-cli/src/platform/mod.rs +++ b/crates/chat-cli/src/platform/mod.rs @@ -4,6 +4,7 @@ pub mod diagnostics; mod env; mod fs; +mod os; mod providers; mod sysinfo; @@ -11,6 +12,10 @@ use std::sync::Arc; pub use env::Env; pub use fs::Fs; +pub use os::{ + Os, + Platform, +}; pub use providers::{ EnvProvider, FsProvider, @@ -28,6 +33,7 @@ pub struct Context { fs: Fs, env: Env, sysinfo: SysInfo, + platform: Platform, } impl Context { @@ -38,11 +44,13 @@ impl Context { fs: Fs::new(), env: Env::new(), sysinfo: SysInfo::new(), + platform: Platform::new(), }), false => Arc::new_cyclic(|_| Self { fs: Default::default(), env: Default::default(), sysinfo: SysInfo::default(), + platform: Platform::new(), }), } } @@ -62,6 +70,10 @@ impl Context { pub fn sysinfo(&self) -> &SysInfo { &self.sysinfo } + + pub fn platform(&self) -> &Platform { + &self.platform + } } #[derive(Default, Debug)] @@ -69,6 +81,7 @@ pub struct ContextBuilder { fs: Option, env: Option, sysinfo: Option, + platform: Option, } impl ContextBuilder { @@ -81,7 +94,13 @@ impl ContextBuilder { let fs = self.fs.unwrap_or_default(); let env = self.env.unwrap_or_default(); let sysinfo = self.sysinfo.unwrap_or_default(); - Arc::new_cyclic(|_| Context { fs, env, sysinfo }) + let platform = self.platform.unwrap_or_default(); + Arc::new_cyclic(|_| Context { + fs, + env, + sysinfo, + platform, + }) } /// Builds an immutable [Context] using fake implementations for each field by default. @@ -89,7 +108,13 @@ impl ContextBuilder { let fs = self.fs.unwrap_or_default(); let env = self.env.unwrap_or_default(); let sysinfo = self.sysinfo.unwrap_or_default(); - Arc::new_cyclic(|_| Context { fs, env, sysinfo }) + let platform = self.platform.unwrap_or_default(); + Arc::new_cyclic(|_| Context { + fs, + env, + sysinfo, + platform, + }) } pub fn with_env(mut self, env: Env) -> Self { diff --git a/crates/chat-cli/src/platform/os.rs b/crates/chat-cli/src/platform/os.rs new file mode 100644 index 0000000000..87609c4332 --- /dev/null +++ b/crates/chat-cli/src/platform/os.rs @@ -0,0 +1,113 @@ +use std::fmt; + +use serde::Serialize; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)] +#[non_exhaustive] +pub enum Os { + Mac, + Linux, + Windows, +} + +impl Os { + pub fn current() -> Self { + #[cfg(target_os = "macos")] + { + return Self::Mac; + } + + #[cfg(target_os = "linux")] + { + return Self::Linux; + } + + #[cfg(target_os = "windows")] + { + return Self::Windows; + } + + #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))] + { + compile_error!("unsupported platform"); + } + + // This line should never be reached due to the compile_error above, + // but it's needed to satisfy the compiler + #[allow(unreachable_code)] + { + panic!("unsupported platform"); + } + } + + pub fn all() -> &'static [Self] { + &[Self::Mac, Self::Linux, Self::Windows] + } + + pub fn as_str(&self) -> &'static str { + match self { + Self::Mac => "macos", + Self::Linux => "linux", + Self::Windows => "windows", + } + } +} + +impl fmt::Display for Os { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +#[derive(Default, Debug, Clone)] +pub struct Platform(inner::Inner); + +mod inner { + use super::*; + + #[derive(Default, Debug, Clone)] + pub(super) enum Inner { + #[default] + Real, + Fake(Os), + } +} + +impl Platform { + /// Returns a real implementation of [Platform]. + pub fn new() -> Self { + Self(inner::Inner::Real) + } + + /// Returns a new fake [Platform]. + pub fn new_fake(os: Os) -> Self { + Self(inner::Inner::Fake(os)) + } + + /// Returns the current [Os]. + pub fn os(&self) -> Os { + use inner::Inner; + match &self.0 { + Inner::Real => Os::current(), + Inner::Fake(os) => *os, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_platform() { + let platform = Platform::default(); + + for os in Os::all() { + let platform = Platform::new_fake(*os); + assert_eq!(&platform.os(), os); + + let _ = os.as_str(); + println!("{os:?} {os}"); + } + } +} diff --git a/crates/fig_api_client/src/clients/streaming_client.rs b/crates/fig_api_client/src/clients/streaming_client.rs index 9b491c3651..5541b24a20 100644 --- a/crates/fig_api_client/src/clients/streaming_client.rs +++ b/crates/fig_api_client/src/clients/streaming_client.rs @@ -282,6 +282,7 @@ mod tests { .send_message(ConversationState { conversation_id: None, user_input_message: UserInputMessage { + images: None, content: "Hello".into(), user_input_message_context: None, user_intent: None, @@ -306,12 +307,14 @@ mod tests { .send_message(ConversationState { conversation_id: None, user_input_message: UserInputMessage { + images: None, content: "How about rustc?".into(), user_input_message_context: None, user_intent: None, }, history: Some(vec![ ChatMessage::UserInputMessage(UserInputMessage { + images: None, content: "What language is the linux kernel written in, and who wrote it?".into(), user_input_message_context: None, user_intent: None, diff --git a/crates/fig_api_client/src/model.rs b/crates/fig_api_client/src/model.rs index 5381b5581f..1a54c0f37e 100644 --- a/crates/fig_api_client/src/model.rs +++ b/crates/fig_api_client/src/model.rs @@ -1,4 +1,7 @@ -use aws_smithy_types::Document; +use aws_smithy_types::{ + Blob, + Document, +}; use serde::{ Deserialize, Serialize, @@ -642,12 +645,14 @@ pub struct UserInputMessage { pub content: String, pub user_input_message_context: Option, pub user_intent: Option, + pub images: Option>, } impl From for amzn_codewhisperer_streaming_client::types::UserInputMessage { fn from(value: UserInputMessage) -> Self { Self::builder() .content(value.content) + .set_images(value.images.map(|images| images.into_iter().map(Into::into).collect())) .set_user_input_message_context(value.user_input_message_context.map(Into::into)) .set_user_intent(value.user_intent.map(Into::into)) .origin(amzn_codewhisperer_streaming_client::types::Origin::Cli) @@ -660,6 +665,7 @@ impl From for amzn_qdeveloper_streaming_client::types::UserInp fn from(value: UserInputMessage) -> Self { Self::builder() .content(value.content) + .set_images(value.images.map(|images| images.into_iter().map(Into::into).collect())) .set_user_input_message_context(value.user_input_message_context.map(Into::into)) .set_user_intent(value.user_intent.map(Into::into)) .origin(amzn_qdeveloper_streaming_client::types::Origin::Cli) @@ -701,6 +707,85 @@ impl From for amzn_qdeveloper_streaming_client::types:: } } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImageBlock { + pub format: ImageFormat, + pub source: ImageSource, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum ImageFormat { + Gif, + Jpeg, + Png, + Webp, +} + +impl From for amzn_codewhisperer_streaming_client::types::ImageFormat { + fn from(value: ImageFormat) -> Self { + match value { + ImageFormat::Gif => Self::Gif, + ImageFormat::Jpeg => Self::Jpeg, + ImageFormat::Png => Self::Png, + ImageFormat::Webp => Self::Webp, + } + } +} +impl From for amzn_qdeveloper_streaming_client::types::ImageFormat { + fn from(value: ImageFormat) -> Self { + match value { + ImageFormat::Gif => Self::Gif, + ImageFormat::Jpeg => Self::Jpeg, + ImageFormat::Png => Self::Png, + ImageFormat::Webp => Self::Webp, + } + } +} + +#[non_exhaustive] +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ImageSource { + Bytes(Vec), + #[non_exhaustive] + Unknown, +} + +impl From for amzn_codewhisperer_streaming_client::types::ImageSource { + fn from(value: ImageSource) -> Self { + match value { + ImageSource::Bytes(bytes) => Self::Bytes(Blob::new(bytes)), + ImageSource::Unknown => Self::Unknown, + } + } +} +impl From for amzn_qdeveloper_streaming_client::types::ImageSource { + fn from(value: ImageSource) -> Self { + match value { + ImageSource::Bytes(bytes) => Self::Bytes(Blob::new(bytes)), + ImageSource::Unknown => Self::Unknown, + } + } +} + +impl From for amzn_codewhisperer_streaming_client::types::ImageBlock { + fn from(value: ImageBlock) -> Self { + Self::builder() + .format(value.format.into()) + .source(value.source.into()) + .build() + .expect("Failed to build ImageBlock") + } +} +impl From for amzn_qdeveloper_streaming_client::types::ImageBlock { + fn from(value: ImageBlock) -> Self { + Self::builder() + .format(value.format.into()) + .source(value.source.into()) + .build() + .expect("Failed to build ImageBlock") + } +} + #[derive(Debug, Clone)] pub enum UserIntent { ApplyCommonBestPractices, @@ -729,6 +814,10 @@ mod tests { #[test] fn build_user_input_message() { let user_input_message = UserInputMessage { + images: Some(vec![ImageBlock { + format: ImageFormat::Png, + source: ImageSource::Bytes(vec![1, 2, 3]), + }]), content: "test content".to_string(), user_input_message_context: Some(UserInputMessageContext { shell_state: Some(ShellState { @@ -776,6 +865,7 @@ mod tests { content: "test content".to_string(), user_input_message_context: None, user_intent: None, + images: None, }; let codewhisper_minimal =