diff --git a/crates/chat-cli/src/api_client/model.rs b/crates/chat-cli/src/api_client/model.rs index 16457d09e6..c0b6415df6 100644 --- a/crates/chat-cli/src/api_client/model.rs +++ b/crates/chat-cli/src/api_client/model.rs @@ -1,10 +1,20 @@ +use std::collections::HashMap; + use aws_smithy_types::{ Blob, - Document, + Document as AwsDocument, +}; +use serde::de::{ + self, + MapAccess, + SeqAccess, + Visitor, }; use serde::{ Deserialize, + Deserializer, Serialize, + Serializer, }; #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] @@ -125,8 +135,189 @@ impl TryFrom for amzn_qdeveloper_streaming_client::types::ChatMessa } } -/// Information about a tool that can be used. +/// Wrapper around [aws_smithy_types::Document]. +/// +/// Used primarily so we can implement [Serialize] and [Deserialize] for +/// [aws_smith_types::Document]. +#[derive(Debug, Clone)] +pub struct FigDocument(AwsDocument); + +impl From for FigDocument { + fn from(value: AwsDocument) -> Self { + Self(value) + } +} + +impl From for AwsDocument { + fn from(value: FigDocument) -> Self { + value.0 + } +} + +/// Internal type used only during serialization for `FigDocument` to avoid unnecessary cloning. #[derive(Debug, Clone)] +struct FigDocumentRef<'a>(&'a AwsDocument); + +impl Serialize for FigDocumentRef<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + use aws_smithy_types::Number; + match self.0 { + AwsDocument::Null => serializer.serialize_unit(), + AwsDocument::Bool(b) => serializer.serialize_bool(*b), + AwsDocument::Number(n) => match n { + Number::PosInt(u) => serializer.serialize_u64(*u), + Number::NegInt(i) => serializer.serialize_i64(*i), + Number::Float(f) => serializer.serialize_f64(*f), + }, + AwsDocument::String(s) => serializer.serialize_str(s), + AwsDocument::Array(arr) => { + use serde::ser::SerializeSeq; + let mut seq = serializer.serialize_seq(Some(arr.len()))?; + for value in arr { + seq.serialize_element(&Self(value))?; + } + seq.end() + }, + AwsDocument::Object(m) => { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(m.len()))?; + for (k, v) in m { + map.serialize_entry(k, &Self(v))?; + } + map.end() + }, + } + } +} + +impl Serialize for FigDocument { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + FigDocumentRef(&self.0).serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for FigDocument { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + use aws_smithy_types::Number; + + struct FigDocumentVisitor; + + impl<'de> Visitor<'de> for FigDocumentVisitor { + type Value = FigDocument; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("any valid JSON value") + } + + fn visit_bool(self, value: bool) -> Result + where + E: de::Error, + { + Ok(FigDocument(AwsDocument::Bool(value))) + } + + fn visit_i64(self, value: i64) -> Result + where + E: de::Error, + { + Ok(FigDocument(AwsDocument::Number(if value < 0 { + Number::NegInt(value) + } else { + Number::PosInt(value as u64) + }))) + } + + fn visit_u64(self, value: u64) -> Result + where + E: de::Error, + { + Ok(FigDocument(AwsDocument::Number(Number::PosInt(value)))) + } + + fn visit_f64(self, value: f64) -> Result + where + E: de::Error, + { + Ok(FigDocument(AwsDocument::Number(Number::Float(value)))) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + Ok(FigDocument(AwsDocument::String(value.to_owned()))) + } + + fn visit_string(self, value: String) -> Result + where + E: de::Error, + { + Ok(FigDocument(AwsDocument::String(value))) + } + + fn visit_none(self) -> Result + where + E: de::Error, + { + Ok(FigDocument(AwsDocument::Null)) + } + + fn visit_some(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + Deserialize::deserialize(deserializer) + } + + fn visit_unit(self) -> Result + where + E: de::Error, + { + Ok(FigDocument(AwsDocument::Null)) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut vec = Vec::new(); + + while let Some(elem) = seq.next_element::()? { + vec.push(elem.0); + } + + Ok(FigDocument(AwsDocument::Array(vec))) + } + + fn visit_map(self, mut access: M) -> Result + where + M: MapAccess<'de>, + { + let mut map = HashMap::new(); + + while let Some((key, value)) = access.next_entry::()? { + map.insert(key, value.0); + } + + Ok(FigDocument(AwsDocument::Object(map))) + } + } + + deserializer.deserialize_any(FigDocumentVisitor) + } +} + +/// Information about a tool that can be used. +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum Tool { ToolSpecification(ToolSpecification), } @@ -148,7 +339,7 @@ impl From for amzn_qdeveloper_streaming_client::types::Tool { } /// The specification for the tool. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolSpecification { /// The name for the tool. pub name: String, @@ -181,33 +372,33 @@ impl From for amzn_qdeveloper_streaming_client::types::ToolSp } /// The input schema for the tool in JSON format. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolInputSchema { - pub json: Option, + pub json: Option, } impl From for amzn_codewhisperer_streaming_client::types::ToolInputSchema { fn from(value: ToolInputSchema) -> Self { - Self::builder().set_json(value.json).build() + Self::builder().set_json(value.json.map(Into::into)).build() } } impl From for amzn_qdeveloper_streaming_client::types::ToolInputSchema { fn from(value: ToolInputSchema) -> Self { - Self::builder().set_json(value.json).build() + Self::builder().set_json(value.json.map(Into::into)).build() } } /// Contains information about a tool that the model is requesting be run. The model uses the result /// from the tool to generate a response. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolUse { /// The ID for the tool request. pub tool_use_id: String, /// The name for the tool. pub name: String, /// The input to pass to the tool. - pub input: Document, + pub input: FigDocument, } impl From for amzn_codewhisperer_streaming_client::types::ToolUse { @@ -215,7 +406,7 @@ impl From for amzn_codewhisperer_streaming_client::types::ToolUse { Self::builder() .tool_use_id(value.tool_use_id) .name(value.name) - .input(value.input) + .input(value.input.into()) .build() .expect("building ToolUse should not fail") } @@ -226,7 +417,7 @@ impl From for amzn_qdeveloper_streaming_client::types::ToolUse { Self::builder() .tool_use_id(value.tool_use_id) .name(value.name) - .input(value.input) + .input(value.input.into()) .build() .expect("building ToolUse should not fail") } @@ -268,7 +459,7 @@ impl From for amzn_qdeveloper_streaming_client::types::ToolResult { #[derive(Debug, Clone)] pub enum ToolResultContentBlock { /// A tool result that is JSON format data. - Json(Document), + Json(AwsDocument), /// A tool result that is text. Text(String), } @@ -780,7 +971,7 @@ mod tests { name: "test tool name".to_string(), description: "test tool description".to_string(), input_schema: ToolInputSchema { - json: Some(Document::Null), + json: Some(AwsDocument::Null.into()), }, })]), }), @@ -814,7 +1005,9 @@ mod tests { tool_uses: Some(vec![ToolUse { tool_use_id: "tooluseid_test".to_string(), name: "tool_name_test".to_string(), - input: Document::Object([("key1".to_string(), Document::Null)].into_iter().collect()), + input: FigDocument(AwsDocument::Object( + [("key1".to_string(), AwsDocument::Null)].into_iter().collect(), + )), }]), }; let codewhisper_input = diff --git a/crates/chat-cli/src/cli/chat/command.rs b/crates/chat-cli/src/cli/chat/command.rs index 8b0d2c9518..006eb9b151 100644 --- a/crates/chat-cli/src/cli/chat/command.rs +++ b/crates/chat-cli/src/cli/chat/command.rs @@ -51,6 +51,13 @@ pub enum Command { subcommand: Option, }, Usage, + Import { + path: String, + }, + Export { + path: String, + force: bool, + }, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -811,6 +818,25 @@ impl Command { } }, "usage" => Self::Usage, + "import" => { + let Some(path) = parts.get(1) else { + return Err("path is required".to_string()); + }; + Self::Import { + path: (*path).to_string(), + } + }, + "export" => { + let force = parts.contains(&"-f") || parts.contains(&"--force"); + let Some(path) = parts.get(1) else { + return Err("path is required".to_string()); + }; + let mut path = (*path).to_string(); + if !path.ends_with(".json") { + path.push_str(".json"); + } + Self::Export { path, force } + }, unknown_command => { let looks_like_path = { let after_slash_command_str = parts[1..].join(" "); diff --git a/crates/chat-cli/src/cli/chat/context.rs b/crates/chat-cli/src/cli/chat/context.rs index 4f83d5da5f..c3abf7967c 100644 --- a/crates/chat-cli/src/cli/chat/context.rs +++ b/crates/chat-cli/src/cli/chat/context.rs @@ -42,8 +42,10 @@ pub struct ContextConfig { #[allow(dead_code)] /// Manager for context files and profiles. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ContextManager { + #[serde(skip)] + #[serde(default = "default_context")] ctx: Arc, max_context_files_size: usize, @@ -798,6 +800,10 @@ fn validate_profile_name(name: &str) -> Result<()> { Ok(()) } +fn default_context() -> Arc { + Context::new() +} + #[cfg(test)] mod tests { use std::io::Stdout; diff --git a/crates/chat-cli/src/cli/chat/conversation_state.rs b/crates/chat-cli/src/cli/chat/conversation_state.rs index b5b039094b..94cdcf0817 100644 --- a/crates/chat-cli/src/cli/chat/conversation_state.rs +++ b/crates/chat-cli/src/cli/chat/conversation_state.rs @@ -9,6 +9,10 @@ use crossterm::{ execute, style, }; +use serde::{ + Deserialize, + Serialize, +}; use tracing::{ debug, error, @@ -42,8 +46,8 @@ use super::tools::{ QueuedTool, ToolOrigin, ToolSpec, - serde_value_to_document, }; +use super::util::serde_value_to_document; use crate::api_client::model::{ AssistantResponseMessage, ChatMessage, @@ -67,7 +71,7 @@ const CONTEXT_ENTRY_START_HEADER: &str = "--- CONTEXT ENTRY BEGIN ---\n"; const CONTEXT_ENTRY_END_HEADER: &str = "--- CONTEXT ENTRY END ---\n\n"; /// Tracks state related to an ongoing conversation. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ConversationState { /// Randomly generated on creation. conversation_id: String, @@ -89,7 +93,8 @@ pub struct ConversationState { context_message_length: Option, /// Stores the latest conversation summary created by /compact latest_summary: Option, - updates: Option, + #[serde(skip)] + pub updates: Option, } impl ConversationState { @@ -797,7 +802,7 @@ pub enum TokenWarningLevel { impl From for ToolInputSchema { fn from(value: InputSchema) -> Self { Self { - json: Some(serde_value_to_document(value.0)), + json: Some(serde_value_to_document(value.0).into()), } } } diff --git a/crates/chat-cli/src/cli/chat/hooks.rs b/crates/chat-cli/src/cli/chat/hooks.rs index 036195ba17..f6d857ba38 100644 --- a/crates/chat-cli/src/cli/chat/hooks.rs +++ b/crates/chat-cli/src/cli/chat/hooks.rs @@ -119,14 +119,15 @@ pub enum HookTrigger { PerPrompt, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct CachedHook { output: String, + #[serde(skip)] expiry: Option, } /// Maps a hook name to a [`CachedHook`] -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct HookExecutor { pub global_cache: HashMap, pub profile_cache: HashMap, diff --git a/crates/chat-cli/src/cli/chat/message.rs b/crates/chat-cli/src/cli/chat/message.rs index 053c3f76c3..d6361c3c73 100644 --- a/crates/chat-cli/src/cli/chat/message.rs +++ b/crates/chat-cli/src/cli/chat/message.rs @@ -10,10 +10,12 @@ use super::consts::MAX_CURRENT_WORKING_DIRECTORY_LEN; use super::tools::{ InvokeOutput, OutputKind, +}; +use super::util::{ document_to_serde_value, serde_value_to_document, + truncate_safe, }; -use super::util::truncate_safe; use crate::api_client::model::{ AssistantResponseMessage, EnvState, @@ -355,7 +357,7 @@ impl From for ToolUse { Self { tool_use_id: value.id, name: value.name, - input: serde_value_to_document(value.args), + input: serde_value_to_document(value.args).into(), } } } @@ -365,7 +367,7 @@ impl From for AssistantToolUse { Self { id: value.tool_use_id, name: value.name, - args: document_to_serde_value(value.input), + args: document_to_serde_value(value.input.into()), } } } diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index c94610cfd7..391a6e0924 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -269,7 +269,9 @@ const HELP_TEXT: &str = color_print::cstr! {" rm Remove file(s) from context [--global] clear Clear all files from current context [--global] hooks View and manage context hooks -/usage Show current session's context window usage +/usage Show current session's context window usage +/import Import conversation state from a JSON file +/export Export conversation state to a JSON file MCP: You can now configure the Amazon Q CLI to use MCP servers. \nLearn how: https://docs.aws.amazon.com/en_us/amazonq/latest/qdeveloper-ug/command-line-mcp.html @@ -2743,6 +2745,100 @@ impl ChatContext { skip_printing_tools: true, } }, + Command::Import { path } => { + macro_rules! tri { + ($v:expr) => { + match $v { + Ok(v) => v, + Err(err) => { + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!("\nFailed to import from {}: {}\n\n", &path, &err)), + style::SetAttribute(Attribute::Reset) + )?; + return Ok(ChatState::PromptUser { + tool_uses: Some(tool_uses), + pending_tool_index, + skip_printing_tools: true, + }); + }, + } + }; + } + + let contents = tri!(self.ctx.fs().read_to_string(&path).await); + let new_state: ConversationState = tri!(serde_json::from_str(&contents)); + self.conversation_state = new_state; + self.conversation_state.updates = Some(self.output.clone()); + + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!("\n✔ Imported conversation state from {}\n\n", &path)), + style::SetAttribute(Attribute::Reset) + )?; + + ChatState::PromptUser { + tool_uses: None, + pending_tool_index: None, + skip_printing_tools: true, + } + }, + Command::Export { path, force } => { + macro_rules! tri { + ($v:expr) => { + match $v { + Ok(v) => v, + Err(err) => { + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!("\nFailed to export to {}: {}\n\n", &path, &err)), + style::SetAttribute(Attribute::Reset) + )?; + return Ok(ChatState::PromptUser { + tool_uses: Some(tool_uses), + pending_tool_index, + skip_printing_tools: true, + }); + }, + } + }; + } + + let contents = tri!(serde_json::to_string_pretty(&self.conversation_state)); + if self.ctx.fs().exists(&path) && !force { + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!( + "\nFile at {} already exists. To overwrite, use -f or --force\n\n", + &path + )), + style::SetAttribute(Attribute::Reset) + )?; + return Ok(ChatState::PromptUser { + tool_uses: Some(tool_uses), + pending_tool_index, + skip_printing_tools: true, + }); + } + tri!(self.ctx.fs().write(&path, contents).await); + + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!("\n✔ Exported conversation state to {}\n\n", &path)), + style::SetAttribute(Attribute::Reset) + )?; + + ChatState::PromptUser { + tool_uses: None, + pending_tool_index: None, + skip_printing_tools: true, + } + }, }) } diff --git a/crates/chat-cli/src/cli/chat/prompt.rs b/crates/chat-cli/src/cli/chat/prompt.rs index d786a75f0a..b936555660 100644 --- a/crates/chat-cli/src/cli/chat/prompt.rs +++ b/crates/chat-cli/src/cli/chat/prompt.rs @@ -73,6 +73,8 @@ pub const COMMANDS: &[&str] = &[ "/compact", "/compact help", "/usage", + "/import", + "/export", ]; pub fn generate_prompt(current_profile: Option<&str>, warning: bool) -> String { diff --git a/crates/chat-cli/src/cli/chat/tools/mod.rs b/crates/chat-cli/src/cli/chat/tools/mod.rs index e558e10bea..3174de13ff 100644 --- a/crates/chat-cli/src/cli/chat/tools/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/mod.rs @@ -13,10 +13,6 @@ use std::path::{ PathBuf, }; -use aws_smithy_types::{ - Document, - Number as SmithyNumber, -}; use crossterm::style::Stylize; use custom_tool::CustomTool; use execute_bash::ExecuteBash; @@ -205,7 +201,7 @@ pub struct ToolSpec { pub tool_origin: ToolOrigin, } -#[derive(Debug, Clone, Deserialize, Eq, PartialEq, Hash)] +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq, Hash)] pub enum ToolOrigin { Native, McpServer(String), @@ -266,58 +262,6 @@ impl Default for OutputKind { } } -pub fn serde_value_to_document(value: serde_json::Value) -> Document { - match value { - serde_json::Value::Null => Document::Null, - serde_json::Value::Bool(bool) => Document::Bool(bool), - serde_json::Value::Number(number) => { - if let Some(num) = number.as_u64() { - Document::Number(SmithyNumber::PosInt(num)) - } else if number.as_i64().is_some_and(|n| n < 0) { - Document::Number(SmithyNumber::NegInt(number.as_i64().unwrap())) - } else { - Document::Number(SmithyNumber::Float(number.as_f64().unwrap_or_default())) - } - }, - serde_json::Value::String(string) => Document::String(string), - serde_json::Value::Array(vec) => { - Document::Array(vec.clone().into_iter().map(serde_value_to_document).collect::<_>()) - }, - serde_json::Value::Object(map) => Document::Object( - map.into_iter() - .map(|(k, v)| (k, serde_value_to_document(v))) - .collect::<_>(), - ), - } -} - -pub fn document_to_serde_value(value: Document) -> serde_json::Value { - use serde_json::Value; - match value { - Document::Object(map) => Value::Object( - map.into_iter() - .map(|(k, v)| (k, document_to_serde_value(v))) - .collect::<_>(), - ), - Document::Array(vec) => Value::Array(vec.clone().into_iter().map(document_to_serde_value).collect::<_>()), - Document::Number(number) => { - if let Ok(v) = TryInto::::try_into(number) { - Value::Number(v.into()) - } else if let Ok(v) = TryInto::::try_into(number) { - Value::Number(v.into()) - } else { - Value::Number( - serde_json::Number::from_f64(number.to_f64_lossy()) - .unwrap_or(serde_json::Number::from_f64(0.0).expect("converting from 0.0 will not fail")), - ) - } - }, - Document::String(s) => serde_json::Value::String(s), - Document::Bool(b) => serde_json::Value::Bool(b), - Document::Null => serde_json::Value::Null, - } -} - /// Performs tilde expansion and other required sanitization modifications for handling tool use /// path arguments. /// diff --git a/crates/chat-cli/src/cli/chat/util/mod.rs b/crates/chat-cli/src/cli/chat/util/mod.rs index d4958fabd7..ae36902e4b 100644 --- a/crates/chat-cli/src/cli/chat/util/mod.rs +++ b/crates/chat-cli/src/cli/chat/util/mod.rs @@ -6,6 +6,10 @@ pub mod ui; use std::io::Write; use std::time::Duration; +use aws_smithy_types::{ + Document, + Number as SmithyNumber, +}; use eyre::Result; use super::ChatError; @@ -127,6 +131,58 @@ pub fn drop_matched_context_files(files: &mut [(String, String)], limit: usize) Ok(dropped_files) } +pub fn serde_value_to_document(value: serde_json::Value) -> Document { + match value { + serde_json::Value::Null => Document::Null, + serde_json::Value::Bool(bool) => Document::Bool(bool), + serde_json::Value::Number(number) => { + if let Some(num) = number.as_u64() { + Document::Number(SmithyNumber::PosInt(num)) + } else if number.as_i64().is_some_and(|n| n < 0) { + Document::Number(SmithyNumber::NegInt(number.as_i64().unwrap())) + } else { + Document::Number(SmithyNumber::Float(number.as_f64().unwrap_or_default())) + } + }, + serde_json::Value::String(string) => Document::String(string), + serde_json::Value::Array(vec) => { + Document::Array(vec.clone().into_iter().map(serde_value_to_document).collect::<_>()) + }, + serde_json::Value::Object(map) => Document::Object( + map.into_iter() + .map(|(k, v)| (k, serde_value_to_document(v))) + .collect::<_>(), + ), + } +} + +pub fn document_to_serde_value(value: Document) -> serde_json::Value { + use serde_json::Value; + match value { + Document::Object(map) => Value::Object( + map.into_iter() + .map(|(k, v)| (k, document_to_serde_value(v))) + .collect::<_>(), + ), + Document::Array(vec) => Value::Array(vec.clone().into_iter().map(document_to_serde_value).collect::<_>()), + Document::Number(number) => { + if let Ok(v) = TryInto::::try_into(number) { + Value::Number(v.into()) + } else if let Ok(v) = TryInto::::try_into(number) { + Value::Number(v.into()) + } else { + Value::Number( + serde_json::Number::from_f64(number.to_f64_lossy()) + .unwrap_or(serde_json::Number::from_f64(0.0).expect("converting from 0.0 will not fail")), + ) + } + }, + Document::String(s) => serde_json::Value::String(s), + Document::Bool(b) => serde_json::Value::Bool(b), + Document::Null => serde_json::Value::Null, + } +} + #[cfg(test)] mod tests { use super::*;