From 840066aedac8ddbcf02937da7c66124335f65113 Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Wed, 30 Apr 2025 11:40:26 -0700 Subject: [PATCH 01/26] first commit --- crates/chat-cli/src/cli/chat/tool_manager.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index 39c450f522..f5835fcefb 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -458,10 +458,12 @@ enum OutOfSpecName { EmptyDescription(String), } +// type ClientInProgress = Either + +#[derive(Default)] /// Manages the lifecycle and interactions with tools from various sources, including MCP servers. /// This struct is responsible for initializing tools, handling tool requests, and maintaining /// a cache of available prompts from connected servers. -#[derive(Default)] pub struct ToolManager { /// Unique identifier for the current conversation. /// This ID is used to track and associate tools with a specific chat session. @@ -471,6 +473,12 @@ pub struct ToolManager { /// These clients are used to communicate with MCP servers. pub clients: HashMap>, + /// Map of server names to client instances that are currently being initialized. + /// This tracks MCP server clients that are in the process of being set up but are not yet + /// fully ready for use. Once initialization is complete, these clients will be moved to + /// the main `clients` collection. + _clients_in_progress: Arc>>>, + /// Cache for prompts collected from different servers. /// Key: prompt name /// Value: a list of PromptBundle that has a prompt of this name. From 88db63868efbe034dd4dd9f201532896adccf311 Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Fri, 2 May 2025 13:43:44 -0700 Subject: [PATCH 02/26] adds messenger trait --- crates/chat-cli/src/cli/chat/tool_manager.rs | 2 -- crates/mcp_client/src/client.rs | 6 +++-- crates/mcp_client/src/lib.rs | 2 ++ crates/mcp_client/src/messenger.rs | 28 ++++++++++++++++++++ 4 files changed, 34 insertions(+), 4 deletions(-) create mode 100644 crates/mcp_client/src/messenger.rs diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index f5835fcefb..b6ce44b57b 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -458,8 +458,6 @@ enum OutOfSpecName { EmptyDescription(String), } -// type ClientInProgress = Either - #[derive(Default)] /// Manages the lifecycle and interactions with tools from various sources, including MCP servers. /// This struct is responsible for initializing tools, handling tool requests, and maintaining diff --git a/crates/mcp_client/src/client.rs b/crates/mcp_client/src/client.rs index 01b3794013..aedb027798 100644 --- a/crates/mcp_client/src/client.rs +++ b/crates/mcp_client/src/client.rs @@ -210,8 +210,10 @@ where { /// Exchange of information specified as per https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization /// - /// Also done is the spawn of a background task that constantly listens for incoming messages - /// from the server. + /// Also done are the following: + /// - Spawns task for listening to server driven workflows + /// - Spawns tasks to ask for relevant info such as tools and prompts in accordance to server + /// capabilities received pub async fn init(&self) -> Result { let transport_ref = self.transport.clone(); let server_name = self.server_name.clone(); diff --git a/crates/mcp_client/src/lib.rs b/crates/mcp_client/src/lib.rs index d631f70654..19f23b809a 100644 --- a/crates/mcp_client/src/lib.rs +++ b/crates/mcp_client/src/lib.rs @@ -1,9 +1,11 @@ pub mod client; pub mod error; pub mod facilitator_types; +pub mod messenger; pub mod server; pub mod transport; pub use client::*; pub use facilitator_types::*; +pub use messenger::*; pub use transport::*; diff --git a/crates/mcp_client/src/messenger.rs b/crates/mcp_client/src/messenger.rs new file mode 100644 index 0000000000..9a112a6e8f --- /dev/null +++ b/crates/mcp_client/src/messenger.rs @@ -0,0 +1,28 @@ +use crate::{ + PromptsListResult, + ResourceTemplatesListResult, + ResourcesListResult, + ToolsListResult, +}; + +/// An interface that abstracts the implementation for information delivery from client and its +/// consumer. It is through this interface secondary information (i.e. information that are needed +/// to make requests to mcp servers) are obtained passively. Consumers of client can of course +/// choose to "actively" retrieve these information via explicitly making these requests. +pub trait Messenger: Send + Sync + 'static { + /// Sends the result of a tools list operation to the consumer + /// This function is used to deliver information about available tools + fn send_tools_list_result(result: ToolsListResult); + + /// Sends the result of a prompts list operation to the consumer + /// This function is used to deliver information about available prompts + fn send_prompts_list_result(result: PromptsListResult); + + /// Sends the result of a resources list operation to the consumer + /// This function is used to deliver information about available resources + fn send_resources_list_result(result: ResourcesListResult); + + /// Sends the result of a resource templates list operation to the consumer + /// This function is used to deliver information about available resource templates + fn send_resource_templates_list_result(result: ResourceTemplatesListResult); +} From 1147b5e5ed3f31b1f56f5d03a513cb7104dd2638 Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Fri, 2 May 2025 16:05:42 -0700 Subject: [PATCH 03/26] supplies server init with messenger --- crates/chat-cli/src/cli/chat/tool_manager.rs | 3 +- .../src/cli/chat/tools/custom_tool.rs | 11 +- crates/mcp_client/src/client.rs | 146 ++++++++++++------ crates/mcp_client/src/facilitator_types.rs | 18 +++ crates/mcp_client/src/messenger.rs | 47 +++++- 5 files changed, 166 insertions(+), 59 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index b6ce44b57b..90431990f2 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -56,6 +56,7 @@ use crate::api_client::model::{ }; use crate::mcp_client::{ JsonRpcResponse, + NullMessenger, PromptGet, }; use crate::telemetry::send_mcp_server_init; @@ -528,7 +529,7 @@ impl ToolManager { let tool_specs_clone = tool_specs.clone(); let conversation_id = conversation_id.clone(); async move { - let tool_spec = client_clone.init().await; + let tool_spec = client_clone.init(None::).await; let mut sanitized_mapping = HashMap::::new(); match tool_spec { Ok((server_name, specs)) => { diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index 43580fecea..60bac86a4f 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -27,6 +27,7 @@ use crate::mcp_client::{ JsonRpcResponse, JsonRpcStdioTransport, MessageContent, + Messenger, PromptGet, ServerCapabilities, StdioTransport, @@ -87,7 +88,7 @@ impl CustomToolClient { }) } - pub async fn init(&self) -> Result<(String, Vec)> { + pub async fn init(&self, messenger: Option) -> Result<(String, Vec)> { match self { CustomToolClient::Stdio { client, @@ -96,13 +97,11 @@ impl CustomToolClient { } => { // We'll need to first initialize. This is the handshake every client and server // needs to do before proceeding to anything else - let init_resp = client.init().await?; + let cap = client.init(messenger).await?; // We'll be scrapping this for background server load: https://github.com/aws/amazon-q-developer-cli/issues/1466 // So don't worry about the tidiness for now - let is_tool_supported = init_resp - .get("result") - .is_some_and(|r| r.get("capabilities").is_some_and(|cap| cap.get("tools").is_some())); - server_capabilities.write().await.replace(init_resp); + let is_tool_supported = cap.tools.is_some(); + server_capabilities.write().await.replace(cap); // Assuming a shape of return as per https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#listing-tools let tools = if is_tool_supported { // And now we make the server tell us what tools they have diff --git a/crates/mcp_client/src/client.rs b/crates/mcp_client/src/client.rs index aedb027798..c062d626a4 100644 --- a/crates/mcp_client/src/client.rs +++ b/crates/mcp_client/src/client.rs @@ -37,15 +37,16 @@ use crate::{ JsonRpcResponse, Listener as _, LogListener, + Messenger, PaginationSupportedOps, PromptGet, PromptsListResult, ResourceTemplatesListResult, ResourcesListResult, + ServerCapabilities, ToolsListResult, }; -pub type ServerCapabilities = serde_json::Value; pub type ClientInfo = serde_json::Value; pub type StdioTransport = JsonRpcStdioTransport; @@ -214,7 +215,7 @@ where /// - Spawns task for listening to server driven workflows /// - Spawns tasks to ask for relevant info such as tools and prompts in accordance to server /// capabilities received - pub async fn init(&self) -> Result { + pub async fn init(&self, messenger: Option) -> Result { let transport_ref = self.transport.clone(); let server_name = self.server_name.clone(); @@ -295,63 +296,112 @@ where let client_cap = ClientCapabilities::from(self.client_info.clone()); serde_json::json!(client_cap) }); - let server_capabilities = self.request("initialize", init_params).await?; - if let Err(e) = examine_server_capabilities(&server_capabilities) { + let init_resp = self.request("initialize", init_params).await?; + if let Err(e) = examine_server_capabilities(&init_resp) { return Err(ClientError::NegotiationError(format!( "Client {} has failed to negotiate server capabilities with server: {:?}", self.server_name, e ))); } + let cap = { + let result = init_resp.result.ok_or(ClientError::NegotiationError(format!( + "Server {} init resp is missing result", + self.server_name + )))?; + let cap = result + .get("capabilities") + .ok_or(ClientError::NegotiationError(format!( + "Server {} init resp result is missing capabilities", + self.server_name + )))? + .clone(); + serde_json::from_value::(cap)? + }; self.notify("initialized", None).await?; // TODO: group this into examine_server_capabilities // Prefetch prompts in the background. We should only do this after the server has been // initialized - if let Some(res) = &server_capabilities.result { - if let Some(cap) = res.get("capabilities") { - if cap.get("prompts").is_some() { - self.is_prompts_out_of_date.store(true, Ordering::Relaxed); - let client_ref = (*self).clone(); - tokio::spawn(async move { - let Ok(resp) = client_ref.request("prompts/list", None).await else { - tracing::error!("Prompt list query failed for {0}", client_ref.server_name); - return; - }; - let Some(result) = resp.result else { - tracing::warn!("Prompt list query returned no result for {0}", client_ref.server_name); - return; - }; - let Some(prompts) = result.get("prompts") else { - tracing::warn!( - "Prompt list query result contained no field named prompts for {0}", - client_ref.server_name - ); - return; - }; - let Ok(prompts) = serde_json::from_value::>(prompts.clone()) else { - tracing::error!( - "Prompt list query deserialization failed for {0}", - client_ref.server_name - ); - return; - }; - let Ok(mut lock) = client_ref.prompt_gets.write() else { - tracing::error!( - "Failed to obtain write lock for prompt list query for {0}", - client_ref.server_name - ); - return; - }; - for prompt in prompts { - let name = prompt.name.clone(); - lock.insert(name, prompt); - } - }); + if cap.prompts.is_some() { + self.is_prompts_out_of_date.store(true, Ordering::Relaxed); + let client_ref = (*self).clone(); + tokio::spawn(async move { + let Ok(resp) = client_ref.request("prompts/list", None).await else { + tracing::error!("Prompt list query failed for {0}", client_ref.server_name); + return; + }; + let Some(result) = resp.result else { + tracing::warn!("Prompt list query returned no result for {0}", client_ref.server_name); + return; + }; + let Some(prompts) = result.get("prompts") else { + tracing::warn!( + "Prompt list query result contained no field named prompts for {0}", + client_ref.server_name + ); + return; + }; + let Ok(prompts) = serde_json::from_value::>(prompts.clone()) else { + tracing::error!( + "Prompt list query deserialization failed for {0}", + client_ref.server_name + ); + return; + }; + let Ok(mut lock) = client_ref.prompt_gets.write() else { + tracing::error!( + "Failed to obtain write lock for prompt list query for {0}", + client_ref.server_name + ); + return; + }; + for prompt in prompts { + let name = prompt.name.clone(); + lock.insert(name, prompt); } - } + }); + } + if let (Some(_), Some(messenger)) = (&cap.tools, messenger) { + let client_ref = (*self).clone(); + let msger = messenger.clone(); + tokio::spawn(async move { + let resp = match client_ref.request("tools/list", None).await { + Ok(resp) => resp, + Err(e) => { + tracing::error!("Failed to retrieve tool list from {}: {:?}", client_ref.server_name, e); + return; + }, + }; + if let Some(error) = resp.error { + let msg = format!( + "Failed to retrieve tool list for {}: {:?}", + client_ref.server_name, error + ); + tracing::error!("{}", &msg); + return; + } + let Some(result) = resp.result else { + tracing::error!("Tool list response from {} is missing result", client_ref.server_name); + return; + }; + let tool_list_result = match serde_json::from_value::(result) { + Ok(result) => result, + Err(e) => { + tracing::error!( + "Failed to deserialize tool result from {}: {:?}", + client_ref.server_name, + e + ); + return; + }, + }; + if let Err(e) = msger.send_tools_list_result(tool_list_result).await { + tracing::error!("Failed to send tool result through messenger {:?}", e); + } + }); } - Ok(serde_json::to_value(server_capabilities)?) + Ok(cap) } /// Sends a request to the server associated. @@ -520,6 +570,7 @@ mod tests { use serde_json::Value; use super::*; + use crate::NullMessenger; const TEST_BIN_OUT_DIR: &str = "target/debug"; const TEST_SERVER_NAME: &str = "test_mcp_server"; @@ -609,8 +660,9 @@ mod tests { client: &mut Client, cap_sent: serde_json::Value, ) -> Result<(), Box> { + let test_messenger = Some(NullMessenger); // Test init - let _ = client.init().await.expect("Client init failed"); + let _ = client.init(test_messenger).await.expect("Client init failed"); tokio::time::sleep(time::Duration::from_millis(1500)).await; let client_capabilities_sent = client .request("verify_init_ack_sent", None) diff --git a/crates/mcp_client/src/facilitator_types.rs b/crates/mcp_client/src/facilitator_types.rs index ba56982046..908f555bd2 100644 --- a/crates/mcp_client/src/facilitator_types.rs +++ b/crates/mcp_client/src/facilitator_types.rs @@ -227,3 +227,21 @@ pub struct Resource { /// Resource contents pub contents: ResourceContents, } + +/// Represents the capabilities supported by a Model Context Protocol server +/// This is the "capabilities" field in the result of a response for init +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServerCapabilities { + /// Configuration for server logging capabilities + #[serde(skip_serializing_if = "Option::is_none")] + pub logging: Option, + /// Configuration for prompt-related capabilities + #[serde(skip_serializing_if = "Option::is_none")] + pub prompts: Option, + /// Configuration for resource management capabilities + #[serde(skip_serializing_if = "Option::is_none")] + pub resources: Option, + /// Configuration for tool integration capabilities + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option, +} diff --git a/crates/mcp_client/src/messenger.rs b/crates/mcp_client/src/messenger.rs index 9a112a6e8f..24ffab1c97 100644 --- a/crates/mcp_client/src/messenger.rs +++ b/crates/mcp_client/src/messenger.rs @@ -1,3 +1,5 @@ +use thiserror::Error; + use crate::{ PromptsListResult, ResourceTemplatesListResult, @@ -9,20 +11,55 @@ use crate::{ /// consumer. It is through this interface secondary information (i.e. information that are needed /// to make requests to mcp servers) are obtained passively. Consumers of client can of course /// choose to "actively" retrieve these information via explicitly making these requests. -pub trait Messenger: Send + Sync + 'static { +#[async_trait::async_trait] +pub trait Messenger: Clone + Send + Sync + 'static { /// Sends the result of a tools list operation to the consumer /// This function is used to deliver information about available tools - fn send_tools_list_result(result: ToolsListResult); + async fn send_tools_list_result(&self, result: ToolsListResult) -> Result<(), MessengerError>; /// Sends the result of a prompts list operation to the consumer /// This function is used to deliver information about available prompts - fn send_prompts_list_result(result: PromptsListResult); + async fn send_prompts_list_result(&self, result: PromptsListResult) -> Result<(), MessengerError>; /// Sends the result of a resources list operation to the consumer /// This function is used to deliver information about available resources - fn send_resources_list_result(result: ResourcesListResult); + async fn send_resources_list_result(&self, result: ResourcesListResult) -> Result<(), MessengerError>; /// Sends the result of a resource templates list operation to the consumer /// This function is used to deliver information about available resource templates - fn send_resource_templates_list_result(result: ResourceTemplatesListResult); + async fn send_resource_templates_list_result( + &self, + result: ResourceTemplatesListResult, + ) -> Result<(), MessengerError>; +} + +#[derive(Clone, Debug, Error)] +pub enum MessengerError { + #[error("{0}")] + Custom(String), +} + +#[derive(Clone)] +pub struct NullMessenger; + +#[async_trait::async_trait] +impl Messenger for NullMessenger { + async fn send_tools_list_result(&self, _result: ToolsListResult) -> Result<(), MessengerError> { + Ok(()) + } + + async fn send_prompts_list_result(&self, _result: PromptsListResult) -> Result<(), MessengerError> { + Ok(()) + } + + async fn send_resources_list_result(&self, _result: ResourcesListResult) -> Result<(), MessengerError> { + Ok(()) + } + + async fn send_resource_templates_list_result( + &self, + _result: ResourceTemplatesListResult, + ) -> Result<(), MessengerError> { + Ok(()) + } } From be973493e73729ff92a615959092da3621824c07 Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Mon, 5 May 2025 11:24:03 -0700 Subject: [PATCH 04/26] makes messenger to be used by dynamic dispatch instead --- crates/chat-cli/src/cli/chat/tool_manager.rs | 2 +- crates/chat-cli/src/cli/chat/tools/custom_tool.rs | 4 ++-- crates/mcp_client/src/client.rs | 13 +++++++------ crates/mcp_client/src/messenger.rs | 12 ++++++++++-- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index 90431990f2..e01fedf582 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -529,7 +529,7 @@ impl ToolManager { let tool_specs_clone = tool_specs.clone(); let conversation_id = conversation_id.clone(); async move { - let tool_spec = client_clone.init(None::).await; + let tool_spec = client_clone.init().await; let mut sanitized_mapping = HashMap::::new(); match tool_spec { Ok((server_name, specs)) => { diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index 60bac86a4f..cf80062e54 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -88,7 +88,7 @@ impl CustomToolClient { }) } - pub async fn init(&self, messenger: Option) -> Result<(String, Vec)> { + pub async fn init(&self) -> Result<(String, Vec)> { match self { CustomToolClient::Stdio { client, @@ -97,7 +97,7 @@ impl CustomToolClient { } => { // We'll need to first initialize. This is the handshake every client and server // needs to do before proceeding to anything else - let cap = client.init(messenger).await?; + let cap = client.init().await?; // We'll be scrapping this for background server load: https://github.com/aws/amazon-q-developer-cli/issues/1466 // So don't worry about the tidiness for now let is_tool_supported = cap.tools.is_some(); diff --git a/crates/mcp_client/src/client.rs b/crates/mcp_client/src/client.rs index c062d626a4..9b22b7a1cb 100644 --- a/crates/mcp_client/src/client.rs +++ b/crates/mcp_client/src/client.rs @@ -124,6 +124,7 @@ pub struct Client { server_process_id: Option, client_info: serde_json::Value, current_id: Arc, + pub messenger: Option>, pub prompt_gets: Arc>>, pub is_prompts_out_of_date: Arc, } @@ -139,6 +140,7 @@ impl Clone for Client { server_process_id: None, client_info: self.client_info.clone(), current_id: self.current_id.clone(), + messenger: None, prompt_gets: self.prompt_gets.clone(), is_prompts_out_of_date: self.is_prompts_out_of_date.clone(), } @@ -186,6 +188,7 @@ impl Client { server_process_id, client_info, current_id: Arc::new(AtomicU64::new(0)), + messenger: None, prompt_gets: Arc::new(SyncRwLock::new(HashMap::new())), is_prompts_out_of_date: Arc::new(AtomicBool::new(false)), }) @@ -215,7 +218,7 @@ where /// - Spawns task for listening to server driven workflows /// - Spawns tasks to ask for relevant info such as tools and prompts in accordance to server /// capabilities received - pub async fn init(&self, messenger: Option) -> Result { + pub async fn init(&self) -> Result { let transport_ref = self.transport.clone(); let server_name = self.server_name.clone(); @@ -361,9 +364,9 @@ where } }); } - if let (Some(_), Some(messenger)) = (&cap.tools, messenger) { + if let (Some(_), Some(messenger)) = (&cap.tools, &self.messenger) { let client_ref = (*self).clone(); - let msger = messenger.clone(); + let msger = messenger.duplicate(); tokio::spawn(async move { let resp = match client_ref.request("tools/list", None).await { Ok(resp) => resp, @@ -570,7 +573,6 @@ mod tests { use serde_json::Value; use super::*; - use crate::NullMessenger; const TEST_BIN_OUT_DIR: &str = "target/debug"; const TEST_SERVER_NAME: &str = "test_mcp_server"; @@ -660,9 +662,8 @@ mod tests { client: &mut Client, cap_sent: serde_json::Value, ) -> Result<(), Box> { - let test_messenger = Some(NullMessenger); // Test init - let _ = client.init(test_messenger).await.expect("Client init failed"); + let _ = client.init().await.expect("Client init failed"); tokio::time::sleep(time::Duration::from_millis(1500)).await; let client_capabilities_sent = client .request("verify_init_ack_sent", None) diff --git a/crates/mcp_client/src/messenger.rs b/crates/mcp_client/src/messenger.rs index 24ffab1c97..14e519a6b2 100644 --- a/crates/mcp_client/src/messenger.rs +++ b/crates/mcp_client/src/messenger.rs @@ -12,7 +12,7 @@ use crate::{ /// to make requests to mcp servers) are obtained passively. Consumers of client can of course /// choose to "actively" retrieve these information via explicitly making these requests. #[async_trait::async_trait] -pub trait Messenger: Clone + Send + Sync + 'static { +pub trait Messenger: std::fmt::Debug + Send + Sync + 'static { /// Sends the result of a tools list operation to the consumer /// This function is used to deliver information about available tools async fn send_tools_list_result(&self, result: ToolsListResult) -> Result<(), MessengerError>; @@ -31,6 +31,10 @@ pub trait Messenger: Clone + Send + Sync + 'static { &self, result: ResourceTemplatesListResult, ) -> Result<(), MessengerError>; + + /// Creates a duplicate of the messenger object + /// This function is used to create a new instance of the messenger with the same configuration + fn duplicate(&self) -> Box; } #[derive(Clone, Debug, Error)] @@ -39,7 +43,7 @@ pub enum MessengerError { Custom(String), } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct NullMessenger; #[async_trait::async_trait] @@ -62,4 +66,8 @@ impl Messenger for NullMessenger { ) -> Result<(), MessengerError> { Ok(()) } + + fn duplicate(&self) -> Box { + Box::new(NullMessenger) + } } From f58fc9e79faa4aa1dbe65a5395f00341ae23948a Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Tue, 6 May 2025 10:53:46 -0700 Subject: [PATCH 05/26] loads tools in the background --- Cargo.lock | 52 + crates/chat-cli/src/cli/chat/tool_manager.rs | 460 ++++-- .../src/cli/chat/tools/custom_tool.rs | 31 +- crates/mcp_client/src/client.rs | 6 + crates/q_chat/Cargo.toml | 58 + crates/q_chat/src/tool_manager/mod.rs | 4 + .../src/tool_manager/server_messenger.rs | 116 ++ .../q_chat/src/tool_manager/tool_manager.rs | 1332 +++++++++++++++++ 8 files changed, 1900 insertions(+), 159 deletions(-) create mode 100644 crates/q_chat/Cargo.toml create mode 100644 crates/q_chat/src/tool_manager/mod.rs create mode 100644 crates/q_chat/src/tool_manager/server_messenger.rs create mode 100644 crates/q_chat/src/tool_manager/tool_manager.rs diff --git a/Cargo.lock b/Cargo.lock index e55148dc8b..ecc83f3aec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7028,6 +7028,58 @@ dependencies = [ ] [[package]] +<<<<<<< HEAD +======= +name = "q_chat" +version = "1.10.0" +dependencies = [ + "anstream", + "async-trait", + "aws-smithy-types", + "bstr", + "clap", + "color-print", + "convert_case 0.8.0", + "crossterm", + "eyre", + "fig_api_client", + "fig_auth", + "fig_diagnostic", + "fig_os_shim", + "fig_settings", + "fig_telemetry", + "fig_util", + "futures", + "glob", + "mcp_client", + "rand 0.9.0", + "regex", + "rustyline", + "serde", + "serde_json", + "shell-color", + "shell-words", + "shellexpand", + "shlex", + "similar", + "skim", + "spinners", + "strip-ansi-escapes", + "syntect", + "tempfile", + "thiserror 2.0.12", + "time", + "tokio", + "tracing", + "tracing-subscriber", + "unicode-width 0.2.0", + "url", + "uuid", + "winnow 0.6.22", +] + +[[package]] +>>>>>>> ca627e83 (loads tools in the background) name = "q_cli" version = "1.10.0" dependencies = [ diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index e01fedf582..bd97f3f431 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -5,6 +5,7 @@ use std::hash::{ }; use std::io::Write; use std::path::PathBuf; +use std::sync::atomic::AtomicBool; use std::sync::mpsc::RecvTimeoutError; use std::sync::{ Arc, @@ -23,28 +24,35 @@ use futures::{ StreamExt, stream, }; +use regex::Regex; use serde::{ Deserialize, Serialize, }; use thiserror::Error; use tokio::sync::Mutex; -use tracing::error; +use tracing::{ + error, + warn, +}; -use super::command::PromptsGetCommand; -use super::message::AssistantToolUse; -use super::tools::custom_tool::{ +use crate::command::PromptsGetCommand; +use crate::message::AssistantToolUse; +use crate::tool_manager::server_messenger::{ + ServerMessengerBuilder, + UpdateEventMessage, +}; +use crate::tools::custom_tool::{ CustomTool, CustomToolClient, CustomToolConfig, }; -use super::tools::execute_bash::ExecuteBash; -use super::tools::fs_read::FsRead; -use super::tools::fs_write::FsWrite; -use super::tools::gh_issue::GhIssue; -use super::tools::thinking::Thinking; -use super::tools::use_aws::UseAws; -use super::tools::{ +use crate::tools::execute_bash::ExecuteBash; +use crate::tools::fs_read::FsRead; +use crate::tools::fs_write::FsWrite; +use crate::tools::gh_issue::GhIssue; +use crate::tools::use_aws::UseAws; +use crate::tools::{ Tool, ToolOrigin, ToolSpec, @@ -88,7 +96,7 @@ pub enum GetPromptError { /// Messages used for communication between the tool initialization thread and the loading /// display thread. These messages control the visual loading indicators shown to /// the user during tool initialization. -enum LoadingMsg { +pub enum LoadingMsg { /// Indicates a new tool is being initialized and should be added to the loading /// display. The String parameter is the name of the tool being initialized. Add(String), @@ -102,6 +110,10 @@ enum LoadingMsg { /// Represents a warning that occurred during tool initialization. /// Contains the name of the server that generated the warning and the warning message. Warn { name: String, msg: eyre::Report }, + /// Signals that the loading display thread should terminate. + /// This is sent when all tool initialization is complete or when the application is shutting + /// down. + Terminate, } /// Represents the state of a loading indicator for a tool being initialized. @@ -225,7 +237,7 @@ impl ToolManagerBuilder { let (tx, rx) = std::sync::mpsc::channel::(); // Using a hand rolled thread because it's just easier to do this than do deal with the Send // requirements that comes with holding onto the stdout lock. - let loading_display_task = std::thread::spawn(move || { + let loading_display_task = tokio::task::spawn_blocking(move || { let stdout = std::io::stdout(); let mut stdout_lock = stdout.lock(); let mut loading_servers = HashMap::::new(); @@ -293,6 +305,17 @@ impl ToolManagerBuilder { queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; stdout_lock.flush()?; }, + LoadingMsg::Terminate => { + if !loading_servers.is_empty() { + let msg = loading_servers.iter().fold(String::new(), |mut acc, (server_name, _)| { + acc.push_str(format!("\n - {server_name}").as_str()); + acc + }); + let msg = eyre::eyre!(msg); + queue_incomplete_load_message(&msg, &mut stdout_lock)?; + } + break; + }, }, Err(RecvTimeoutError::Timeout) => { spinner_logo_idx = (spinner_logo_idx + 1) % SPINNER_CHARS.len(); @@ -311,10 +334,58 @@ impl ToolManagerBuilder { Ok::<_, eyre::Report>(()) }); let mut clients = HashMap::>::new(); + let load_msg_sender = tx.clone(); + let conv_id_clone = conversation_id.clone(); + let regex = Arc::new(Regex::new(VALID_TOOL_NAME)?); + let (mut msg_rx, messenger_builder) = ServerMessengerBuilder::new(20); + tokio::spawn(async move { + let mut is_in_display = true; + while let Some(msg) = msg_rx.recv().await { + // For now we will treat every list result as if they contain the + // complete set of tools. This is not necessarily true in the future when + // request method on the mcp client no longer buffers all the pages from + // list calls. + match msg { + UpdateEventMessage::ToolsListResult { server_name, result } => { + error!("## background: from {server_name}: {:?}", result); + let mut specs = result + .tools + .into_iter() + .filter_map(|v| serde_json::from_value::(v).ok()) + .collect::>(); + let mut sanitized_mapping = HashMap::::new(); + if let Some(load_msg) = process_tool_specs( + conv_id_clone.as_str(), + &server_name, + is_in_display, + &mut specs, + &mut sanitized_mapping, + ®ex, + ) { + if let Err(e) = load_msg_sender.send(load_msg) { + warn!( + "Error sending update message to display task: {:?}\nAssume display task has completed", + e + ); + is_in_display = false; + } + } + }, + UpdateEventMessage::PromptsListResult { server_name, result } => {}, + UpdateEventMessage::ResourcesListResult { server_name, result } => {}, + UpdateEventMessage::ResouceTemplatesListResult { server_name, result } => {}, + UpdateEventMessage::DisplayTaskEnded => { + is_in_display = false; + }, + } + } + }); for (mut name, init_res) in pre_initialized { let _ = tx.send(LoadingMsg::Add(name.clone())); match init_res { - Ok(client) => { + Ok(mut client) => { + let messenger = messenger_builder.build_with_name(client.get_server_name().to_owned()); + client.assign_messenger(Box::new(messenger)); let mut client = Arc::new(client); while let Some(collided_client) = clients.insert(name.clone(), client) { // to avoid server name collision we are going to circumvent this by @@ -459,6 +530,8 @@ enum OutOfSpecName { EmptyDescription(String), } +type NewToolSpecs = Arc, Vec)>>>; + #[derive(Default)] /// Manages the lifecycle and interactions with tools from various sources, including MCP servers. /// This struct is responsible for initializing tools, handling tool requests, and maintaining @@ -472,11 +545,9 @@ pub struct ToolManager { /// These clients are used to communicate with MCP servers. pub clients: HashMap>, - /// Map of server names to client instances that are currently being initialized. - /// This tracks MCP server clients that are in the process of being set up but are not yet - /// fully ready for use. Once initialization is complete, these clients will be moved to - /// the main `clients` collection. - _clients_in_progress: Arc>>>, + pub has_new_stuff: Arc, + + new_tool_specs: NewToolSpecs, /// Cache for prompts collected from different servers. /// Key: prompt name @@ -487,7 +558,7 @@ pub struct ToolManager { /// Handle to the thread that displays loading status for tool initialization. /// This thread provides visual feedback to users during the tool loading process. - loading_display_task: Option>>, + loading_display_task: Option>>, /// Channel sender for communicating with the loading display thread. /// Used to send status updates about tool initialization progress. @@ -509,8 +580,8 @@ impl ToolManager { let tx = self.loading_status_sender.take(); let display_task = self.loading_display_task.take(); let tool_specs = { - let mut tool_specs = - serde_json::from_str::>(include_str!("tools/tool_index.json"))?; + let tool_specs = + serde_json::from_str::>(include_str!("../tools/tool_index.json"))?; if !crate::cli::chat::tools::thinking::Thinking::is_enabled() { tool_specs.remove("q_think_tool"); } @@ -518,139 +589,91 @@ impl ToolManager { }; let conversation_id = self.conversation_id.clone(); let regex = Arc::new(regex::Regex::new(VALID_TOOL_NAME)?); - let load_tool = self + self.new_tool_specs = Arc::new(Mutex::new(HashMap::new())); + let load_tools = self .clients - .iter() - .map(|(server_name, client)| { - let client_clone = client.clone(); - let server_name_clone = server_name.clone(); - let tx_clone = tx.clone(); - let regex_clone = regex.clone(); - let tool_specs_clone = tool_specs.clone(); - let conversation_id = conversation_id.clone(); - async move { - let tool_spec = client_clone.init().await; - let mut sanitized_mapping = HashMap::::new(); - match tool_spec { - Ok((server_name, specs)) => { - // Each mcp server might have multiple tools. - // To avoid naming conflicts we are going to namespace it. - // This would also help us locate which mcp server to call the tool from. - let mut out_of_spec_tool_names = Vec::::new(); - let mut hasher = DefaultHasher::new(); - let number_of_tools = specs.len(); - // Sanitize tool names to ensure they comply with the naming requirements: - // 1. If the name already matches the regex pattern and doesn't contain the namespace delimiter, use it as is - // 2. Otherwise, remove invalid characters and handle special cases: - // - Remove namespace delimiters - // - Ensure the name starts with an alphabetic character - // - Generate a hash-based name if the sanitized result is empty - // This ensures all tool names are valid identifiers that can be safely used in the system - // If after all of the aforementioned modification the combined tool - // name we have exceeds a length of 64, we surface it as an error - for mut spec in specs { - let sn = if !regex_clone.is_match(&spec.name) { - let mut sn = sanitize_name(spec.name.clone(), ®ex_clone, &mut hasher); - while sanitized_mapping.contains_key(&sn) { - sn.push('1'); - } - sn - } else { - spec.name.clone() - }; - let full_name = format!("{}{}{}", server_name, NAMESPACE_DELIMITER, sn); - if full_name.len() > 64 { - out_of_spec_tool_names.push(OutOfSpecName::TooLong(spec.name)); - continue; - } else if spec.description.is_empty() { - out_of_spec_tool_names.push(OutOfSpecName::EmptyDescription(spec.name)); - continue; - } - if sn != spec.name { - sanitized_mapping.insert(full_name.clone(), format!("{}{}{}", server_name, NAMESPACE_DELIMITER, spec.name)); - } - spec.name = full_name; - spec.tool_origin = ToolOrigin::McpServer(server_name.clone()); - tool_specs_clone.lock().await.insert(spec.name.clone(), spec); - } - - // Send server load success metric datum - send_mcp_server_init(conversation_id, None, number_of_tools).await; - - // Tool name translation. This is beyond of the scope of what is - // considered a "server load". Reasoning being: - // - Failures here are not related to server load - // - There is not a whole lot we can do with this data - if let Some(tx_clone) = &tx_clone { - let send_result = if !out_of_spec_tool_names.is_empty() { - let msg = out_of_spec_tool_names.iter().fold( - String::from("The following tools are out of spec. They will be excluded from the list of available tools:\n"), - |mut acc, name| { - let (tool_name, msg) = match name { - OutOfSpecName::TooLong(tool_name) => (tool_name.as_str(), "tool name exceeds max length of 64 when combined with server name"), - OutOfSpecName::IllegalChar(tool_name) => (tool_name.as_str(), "tool name must be compliant with ^[a-zA-Z][a-zA-Z0-9_]*$"), - OutOfSpecName::EmptyDescription(tool_name) => (tool_name.as_str(), "tool schema contains empty description"), - }; - acc.push_str(format!(" - {} ({})\n", tool_name, msg).as_str()); - acc - } - ); - tx_clone.send(LoadingMsg::Error { - name: server_name.clone(), - msg: eyre::eyre!(msg), - }) - // TODO: if no tools are valid, we need to offload the server - // from the fleet (i.e. kill the server) - } else if !sanitized_mapping.is_empty() { - let warn = sanitized_mapping.iter().fold(String::from("The following tool names are changed:\n"), |mut acc, (k, v)| { - acc.push_str(format!(" - {} -> {}\n", v, k).as_str()); - acc - }); - tx_clone.send(LoadingMsg::Warn { - name: server_name.clone(), - msg: eyre::eyre!(warn), - }) - } else { - tx_clone.send(LoadingMsg::Done(server_name.clone())) - }; - if let Err(e) = send_result { - error!("Error while sending status update to display task: {:?}", e); - } - } - }, - Err(e) => { - error!("Error obtaining tool spec for {}: {:?}", server_name_clone, e); - let init_failure_reason = Some(e.to_string()); - send_mcp_server_init(conversation_id, init_failure_reason, 0).await; - if let Some(tx_clone) = &tx_clone { - if let Err(e) = tx_clone.send(LoadingMsg::Error { - name: server_name_clone, - msg: e, - }) { - error!("Error while sending status update to display task: {:?}", e); - } - } - }, - } - Ok::<_, eyre::Report>(Some(sanitized_mapping)) - } + .values() + .map(|c| { + let clone = Arc::clone(c); + async move { clone.init().await } }) .collect::>(); - // TODO: do we want to introduce a timeout here? - self.tn_map = stream::iter(load_tool) - .map(|async_closure| tokio::task::spawn(async_closure)) + let some = stream::iter(load_tools) + .map(|async_closure| tokio::spawn(async_closure)) .buffer_unordered(20) .collect::>() - .await - .into_iter() - .filter_map(|r| r.ok()) - .filter_map(|r| r.ok()) - .flatten() - .flatten() - .collect::>(); + .await; + // let load_tool = self + // .clients + // .iter() + // .map(|(server_name, client)| { + // let client_clone = client.clone(); + // let server_name_clone = server_name.clone(); + // let tx_clone = tx.clone(); + // let regex_clone = regex.clone(); + // let tool_specs_clone = tool_specs.clone(); + // let conversation_id = conversation_id.clone(); + // async move { + // let tool_spec = client_clone.init().await; + // let mut sanitized_mapping = HashMap::::new(); + // match tool_spec { + // Ok((server_name, mut specs)) => { + // let msg = process_tool_specs( + // conversation_id.as_str(), + // &server_name, + // true, + // &mut specs, + // &mut sanitized_mapping, + // ®ex_clone, + // ); + // for spec in specs { + // tool_specs_clone.lock().await.insert(spec.name.clone(), spec); + // } + // if let (Some(msg), Some(tx)) = (msg, &tx_clone) { + // let _ = tx.send(msg); + // } + // }, + // Err(e) => { + // error!("Error obtaining tool spec for {}: {:?}", server_name_clone, e); + // let init_failure_reason = Some(e.to_string()); + // tokio::spawn(async move { + // let event = fig_telemetry::EventType::McpServerInit { + // conversation_id, + // init_failure_reason, + // number_of_tools: 0, + // }; + // let app_event = fig_telemetry::AppTelemetryEvent::new(event).await; + // fig_telemetry::dispatch_or_send_event(app_event).await; + // }); + // if let Some(tx_clone) = &tx_clone { + // if let Err(e) = tx_clone.send(LoadingMsg::Error { + // name: server_name_clone, + // msg: e, + // }) { + // error!("Error while sending status update to display task: {:?}", e); + // } + // } + // }, + // } + // Ok::<_, eyre::Report>(Some(sanitized_mapping)) + // } + // }) + // .collect::>(); + // // TODO: do we want to introduce a timeout here? + // self.tn_map = stream::iter(load_tool) + // .map(|async_closure| tokio::task::spawn(async_closure)) + // .buffer_unordered(20) + // .collect::>() + // .await + // .into_iter() + // .filter_map(|r| r.ok()) + // .filter_map(|r| r.ok()) + // .flatten() + // .flatten() + // .collect::>(); drop(tx); if let Some(display_task) = display_task { - if let Err(e) = display_task.join() { + if let Err(e) = display_task.await { error!("Error while joining status display task: {:?}", e); } } @@ -886,6 +909,134 @@ impl ToolManager { } } +#[inline] +fn process_tool_specs( + conversation_id: &str, + server_name: &str, + is_in_display: bool, + specs: &mut Vec, + tn_map: &mut HashMap, + regex: &Arc, +) -> Option { + // Each mcp server might have multiple tools. + // To avoid naming conflicts we are going to namespace it. + // This would also help us locate which mcp server to call the tool from. + let mut out_of_spec_tool_names = Vec::::new(); + let mut hasher = DefaultHasher::new(); + let number_of_tools = specs.len(); + // Sanitize tool names to ensure they comply with the naming requirements: + // 1. If the name already matches the regex pattern and doesn't contain the namespace delimiter, use + // it as is + // 2. Otherwise, remove invalid characters and handle special cases: + // - Remove namespace delimiters + // - Ensure the name starts with an alphabetic character + // - Generate a hash-based name if the sanitized result is empty + // This ensures all tool names are valid identifiers that can be safely used in the system + // If after all of the aforementioned modification the combined tool + // name we have exceeds a length of 64, we surface it as an error + for spec in specs { + let sn = if !regex.is_match(&spec.name) { + let mut sn = sanitize_name(spec.name.clone(), regex, &mut hasher); + while tn_map.contains_key(&sn) { + sn.push('1'); + } + sn + } else { + spec.name.clone() + }; + let full_name = format!("{}{}{}", server_name, NAMESPACE_DELIMITER, sn); + if full_name.len() > 64 { + out_of_spec_tool_names.push(OutOfSpecName::TooLong(spec.name.clone())); + continue; + } else if spec.description.is_empty() { + out_of_spec_tool_names.push(OutOfSpecName::EmptyDescription(spec.name.clone())); + continue; + } + if sn != spec.name { + tn_map.insert( + full_name.clone(), + format!("{}{}{}", server_name, NAMESPACE_DELIMITER, spec.name), + ); + } + spec.name = full_name; + spec.tool_origin = ToolOrigin::McpServer(server_name.to_string()); + } + // Send server load success metric datum + let conversation_id = conversation_id.to_string(); + tokio::spawn(async move { + let event = fig_telemetry::EventType::McpServerInit { + conversation_id, + init_failure_reason: None, + number_of_tools, + }; + let app_event = fig_telemetry::AppTelemetryEvent::new(event).await; + fig_telemetry::dispatch_or_send_event(app_event).await; + }); + // Tool name translation. This is beyond of the scope of what is + // considered a "server load". Reasoning being: + // - Failures here are not related to server load + // - There is not a whole lot we can do with this data + let loading_msg = if !out_of_spec_tool_names.is_empty() { + let msg = out_of_spec_tool_names.iter().fold( + String::from( + "The following tools are out of spec. They will be excluded from the list of available tools:\n", + ), + |mut acc, name| { + let (tool_name, msg) = match name { + OutOfSpecName::TooLong(tool_name) => ( + tool_name.as_str(), + "tool name exceeds max length of 64 when combined with server name", + ), + OutOfSpecName::IllegalChar(tool_name) => ( + tool_name.as_str(), + "tool name must be compliant with ^[a-zA-Z][a-zA-Z0-9_]*$", + ), + OutOfSpecName::EmptyDescription(tool_name) => { + (tool_name.as_str(), "tool schema contains empty description") + }, + }; + acc.push_str(format!(" - {} ({})\n", tool_name, msg).as_str()); + acc + }, + ); + error!( + "Server {} finished loading with the following error: \n{}", + server_name, msg + ); + if is_in_display { + Some(LoadingMsg::Error { + name: server_name.to_string(), + msg: eyre::eyre!(msg), + }) + } else { + None + } + // TODO: if no tools are valid, we need to offload the server + // from the fleet (i.e. kill the server) + } else if !tn_map.is_empty() { + let warn = tn_map.iter().fold( + String::from("The following tool names are changed:\n"), + |mut acc, (k, v)| { + acc.push_str(format!(" - {} -> {}\n", v, k).as_str()); + acc + }, + ); + if is_in_display { + Some(LoadingMsg::Warn { + name: server_name.to_string(), + msg: eyre::eyre!(warn), + }) + } else { + None + } + } else if is_in_display { + Some(LoadingMsg::Done(server_name.to_string())) + } else { + None + }; + loading_msg +} + fn sanitize_name(orig: String, regex: ®ex::Regex, hasher: &mut impl Hasher) -> String { if regex.is_match(&orig) && !orig.contains(NAMESPACE_DELIMITER) { return orig; @@ -993,6 +1144,19 @@ fn queue_warn_message(name: &str, msg: &eyre::Report, output: &mut impl Write) - )?) } +fn queue_incomplete_load_message(msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Yellow), + style::Print("⚠ "), + style::ResetColor, + // We expect the message start with a newline + style::Print("following servers are still loading:"), + style::Print(msg), + style::ResetColor, + )?) +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index cf80062e54..a6fbca2586 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -103,17 +103,26 @@ impl CustomToolClient { let is_tool_supported = cap.tools.is_some(); server_capabilities.write().await.replace(cap); // Assuming a shape of return as per https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#listing-tools - let tools = if is_tool_supported { - // And now we make the server tell us what tools they have - let resp = client.request("tools/list", None).await?; - match resp.result.and_then(|r| r.get("tools").cloned()) { - Some(value) => serde_json::from_value::>(value)?, - None => Default::default(), - } - } else { - Default::default() - }; - Ok((server_name.clone(), tools)) + // let tools = if is_tool_supported { + // // And now we make the server tell us what tools they have + // let resp = client.request("tools/list", None).await?; + // match resp.result.and_then(|r| r.get("tools").cloned()) { + // Some(value) => serde_json::from_value::>(value)?, + // None => Default::default(), + // } + // } else { + // Default::default() + // }; + Ok((server_name.clone(), vec![])) + }, + } + } + + pub fn assign_messenger(&mut self, messenger: Box) { + tracing::error!("## background: assigned {} with messenger", self.get_server_name()); + match self { + CustomToolClient::Stdio { client, .. } => { + client.messenger = Some(messenger); }, } } diff --git a/crates/mcp_client/src/client.rs b/crates/mcp_client/src/client.rs index 9b22b7a1cb..eae73de53a 100644 --- a/crates/mcp_client/src/client.rs +++ b/crates/mcp_client/src/client.rs @@ -365,9 +365,15 @@ where }); } if let (Some(_), Some(messenger)) = (&cap.tools, &self.messenger) { + tracing::error!( + "## background: {} is spawning background task to fetch tools", + self.server_name + ); let client_ref = (*self).clone(); let msger = messenger.duplicate(); tokio::spawn(async move { + // TODO: decouple pagination logic from request and have page fetching logic here + // instead let resp = match client_ref.request("tools/list", None).await { Ok(resp) => resp, Err(e) => { diff --git a/crates/q_chat/Cargo.toml b/crates/q_chat/Cargo.toml new file mode 100644 index 0000000000..8dba7b3117 --- /dev/null +++ b/crates/q_chat/Cargo.toml @@ -0,0 +1,58 @@ +[package] +name = "q_chat" +authors.workspace = true +edition.workspace = true +homepage.workspace = true +publish.workspace = true +version.workspace = true +license.workspace = true + +[dependencies] +async-trait.workspace = true +anstream.workspace = true +aws-smithy-types = "1.2.10" +bstr.workspace = true +clap.workspace = true +color-print.workspace = true +convert_case.workspace = true +crossterm.workspace = true +eyre.workspace = true +fig_api_client.workspace = true +fig_auth.workspace = true +fig_diagnostic.workspace = true +fig_os_shim.workspace = true +fig_settings.workspace = true +fig_telemetry.workspace = true +fig_util.workspace = true +futures.workspace = true +glob.workspace = true +mcp_client.workspace = true +rand.workspace = true +regex.workspace = true +rustyline = { version = "15.0.0", features = ["derive", "custom-bindings"] } +serde.workspace = true +serde_json.workspace = true +shell-color.workspace = true +shell-words = "1.1" +shellexpand.workspace = true +shlex.workspace = true +similar.workspace = true +skim = "0.16.2" +spinners.workspace = true +syntect = { version = "5.2.0", features = [ "default-syntaxes", "default-themes" ]} +tempfile.workspace = true +thiserror.workspace = true +time.workspace = true +tokio.workspace = true +tracing.workspace = true +unicode-width.workspace = true +url.workspace = true +uuid.workspace = true +winnow.workspace = true +strip-ansi-escapes = "0.2.1" + +[dev-dependencies] +tracing-subscriber.workspace = true + +[lints] +workspace = true diff --git a/crates/q_chat/src/tool_manager/mod.rs b/crates/q_chat/src/tool_manager/mod.rs new file mode 100644 index 0000000000..6251b7fb77 --- /dev/null +++ b/crates/q_chat/src/tool_manager/mod.rs @@ -0,0 +1,4 @@ +mod server_messenger; +pub mod tool_manager; + +pub use tool_manager::*; diff --git a/crates/q_chat/src/tool_manager/server_messenger.rs b/crates/q_chat/src/tool_manager/server_messenger.rs new file mode 100644 index 0000000000..aad019bc16 --- /dev/null +++ b/crates/q_chat/src/tool_manager/server_messenger.rs @@ -0,0 +1,116 @@ +use mcp_client::{ + Messenger, + MessengerError, + PromptsListResult, + ResourceTemplatesListResult, + ResourcesListResult, + ToolsListResult, +}; +use tokio::sync::mpsc::{ + Receiver, + Sender, + channel, +}; + +#[derive(Clone, Debug)] +pub enum UpdateEventMessage { + ToolsListResult { + server_name: String, + result: ToolsListResult, + }, + PromptsListResult { + server_name: String, + result: PromptsListResult, + }, + ResourcesListResult { + server_name: String, + result: ResourcesListResult, + }, + ResouceTemplatesListResult { + server_name: String, + result: ResourceTemplatesListResult, + }, + DisplayTaskEnded, +} + +#[derive(Clone, Debug)] +pub struct ServerMessengerBuilder { + pub update_event_sender: Sender, +} + +impl ServerMessengerBuilder { + pub fn new(capacity: usize) -> (Receiver, Self) { + let (tx, rx) = channel::(capacity); + let this = Self { + update_event_sender: tx, + }; + (rx, this) + } + + pub fn build_with_name(&self, server_name: String) -> ServerMessenger { + ServerMessenger { + server_name, + update_event_sender: self.update_event_sender.clone(), + } + } +} + +#[derive(Clone, Debug)] +pub struct ServerMessenger { + pub server_name: String, + pub update_event_sender: Sender, +} + +#[async_trait::async_trait] +impl Messenger for ServerMessenger { + async fn send_tools_list_result(&self, result: ToolsListResult) -> Result<(), MessengerError> { + Ok(self + .update_event_sender + .send(UpdateEventMessage::ToolsListResult { + server_name: self.server_name.clone(), + result, + }) + .await + .map_err(|e| MessengerError::Custom(e.to_string()))?) + } + + async fn send_prompts_list_result(&self, result: PromptsListResult) -> Result<(), MessengerError> { + Ok(self + .update_event_sender + .send(UpdateEventMessage::PromptsListResult { + server_name: self.server_name.clone(), + result, + }) + .await + .map_err(|e| MessengerError::Custom(e.to_string()))?) + } + + async fn send_resources_list_result(&self, result: ResourcesListResult) -> Result<(), MessengerError> { + Ok(self + .update_event_sender + .send(UpdateEventMessage::ResourcesListResult { + server_name: self.server_name.clone(), + result, + }) + .await + .map_err(|e| MessengerError::Custom(e.to_string()))?) + } + + async fn send_resource_templates_list_result( + &self, + result: ResourceTemplatesListResult, + ) -> Result<(), MessengerError> { + Ok(self + .update_event_sender + .send(UpdateEventMessage::ResouceTemplatesListResult { + server_name: self.server_name.clone(), + result, + }) + .await + .map_err(|e| MessengerError::Custom(e.to_string()))?) + } + + fn duplicate(&self) -> Box { + Box::new(self.clone()) + } +} diff --git a/crates/q_chat/src/tool_manager/tool_manager.rs b/crates/q_chat/src/tool_manager/tool_manager.rs new file mode 100644 index 0000000000..faea281d31 --- /dev/null +++ b/crates/q_chat/src/tool_manager/tool_manager.rs @@ -0,0 +1,1332 @@ +use std::collections::HashMap; +use std::hash::{ + DefaultHasher, + Hasher, +}; +use std::io::Write; +use std::path::PathBuf; +use std::sync::atomic::AtomicBool; +use std::sync::mpsc::RecvTimeoutError; +use std::sync::{ + Arc, + RwLock as SyncRwLock, +}; + +use convert_case::Casing; +use crossterm::{ + cursor, + execute, + queue, + style, + terminal, +}; +use futures::{ + StreamExt, + stream, +}; +<<<<<<<< HEAD:crates/chat-cli/src/cli/chat/tool_manager.rs +======== +use mcp_client::{ + JsonRpcResponse, + PromptGet, +}; +use regex::Regex; +>>>>>>>> ca627e83 (loads tools in the background):crates/q_chat/src/tool_manager/tool_manager.rs +use serde::{ + Deserialize, + Serialize, +}; +use thiserror::Error; +use tokio::sync::Mutex; +use tracing::{ + error, + warn, +}; + +<<<<<<<< HEAD:crates/chat-cli/src/cli/chat/tool_manager.rs +use super::command::PromptsGetCommand; +use super::message::AssistantToolUse; +use super::tools::custom_tool::{ +======== +use crate::command::PromptsGetCommand; +use crate::message::AssistantToolUse; +use crate::tool_manager::server_messenger::{ + ServerMessengerBuilder, + UpdateEventMessage, +}; +use crate::tools::custom_tool::{ +>>>>>>>> ca627e83 (loads tools in the background):crates/q_chat/src/tool_manager/tool_manager.rs + CustomTool, + CustomToolClient, + CustomToolConfig, +}; +<<<<<<<< HEAD:crates/chat-cli/src/cli/chat/tool_manager.rs +use super::tools::execute_bash::ExecuteBash; +use super::tools::fs_read::FsRead; +use super::tools::fs_write::FsWrite; +use super::tools::gh_issue::GhIssue; +use super::tools::thinking::Thinking; +use super::tools::use_aws::UseAws; +use super::tools::{ +======== +use crate::tools::execute_bash::ExecuteBash; +use crate::tools::fs_read::FsRead; +use crate::tools::fs_write::FsWrite; +use crate::tools::gh_issue::GhIssue; +use crate::tools::use_aws::UseAws; +use crate::tools::{ +>>>>>>>> ca627e83 (loads tools in the background):crates/q_chat/src/tool_manager/tool_manager.rs + Tool, + ToolOrigin, + ToolSpec, +}; +<<<<<<<< HEAD:crates/chat-cli/src/cli/chat/tool_manager.rs +use crate::api_client::model::{ + ToolResult, + ToolResultContentBlock, + ToolResultStatus, +}; +use crate::mcp_client::{ + JsonRpcResponse, + NullMessenger, + PromptGet, +}; +use crate::telemetry::send_mcp_server_init; +======== +>>>>>>>> ca627e83 (loads tools in the background):crates/q_chat/src/tool_manager/tool_manager.rs + +const NAMESPACE_DELIMITER: &str = "___"; +// This applies for both mcp server and tool name since in the end the tool name as seen by the +// model is just {server_name}{NAMESPACE_DELIMITER}{tool_name} +const VALID_TOOL_NAME: &str = "^[a-zA-Z][a-zA-Z0-9_]*$"; +const SPINNER_CHARS: [char; 10] = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']; + +#[derive(Debug, Error)] +pub enum GetPromptError { + #[error("Prompt with name {0} does not exist")] + PromptNotFound(String), + #[error("Prompt {0} is offered by more than one server. Use one of the following {1}")] + AmbiguousPrompt(String, String), + #[error("Missing client")] + MissingClient, + #[error("Missing prompt name")] + MissingPromptName, + #[error("Synchronization error: {0}")] + Synchronization(String), + #[error("Missing prompt bundle")] + MissingPromptInfo, + #[error(transparent)] + General(#[from] eyre::Report), +} + +/// Messages used for communication between the tool initialization thread and the loading +/// display thread. These messages control the visual loading indicators shown to +/// the user during tool initialization. +pub enum LoadingMsg { + /// Indicates a new tool is being initialized and should be added to the loading + /// display. The String parameter is the name of the tool being initialized. + Add(String), + /// Indicates a tool has finished initializing successfully and should be removed from + /// the loading display. The String parameter is the name of the tool that + /// completed initialization. + Done(String), + /// Represents an error that occurred during tool initialization. + /// Contains the name of the server that failed to initialize and the error message. + Error { name: String, msg: eyre::Report }, + /// Represents a warning that occurred during tool initialization. + /// Contains the name of the server that generated the warning and the warning message. + Warn { name: String, msg: eyre::Report }, + /// Signals that the loading display thread should terminate. + /// This is sent when all tool initialization is complete or when the application is shutting + /// down. + Terminate, +} + +/// Represents the state of a loading indicator for a tool being initialized. +/// +/// This struct tracks timing information for each tool's loading status display in the terminal. +/// +/// # Fields +/// * `init_time` - When initialization for this tool began, used to calculate load time +struct StatusLine { + init_time: std::time::Instant, +} + +// This is to mirror claude's config set up +#[derive(Clone, Serialize, Deserialize, Debug, Default)] +#[serde(rename_all = "camelCase")] +pub struct McpServerConfig { + mcp_servers: HashMap, +} + +impl McpServerConfig { + pub async fn load_config(output: &mut impl Write) -> eyre::Result { + let mut cwd = std::env::current_dir()?; + cwd.push(".amazonq/mcp.json"); + let expanded_path = shellexpand::tilde("~/.aws/amazonq/mcp.json"); + let global_path = PathBuf::from(expanded_path.as_ref()); + let global_buf = tokio::fs::read(global_path).await.ok(); + let local_buf = tokio::fs::read(cwd).await.ok(); + let conf = match (global_buf, local_buf) { + (Some(global_buf), Some(local_buf)) => { + let mut global_conf = Self::from_slice(&global_buf, output, "global")?; + let local_conf = Self::from_slice(&local_buf, output, "local")?; + for (server_name, config) in local_conf.mcp_servers { + if global_conf.mcp_servers.insert(server_name.clone(), config).is_some() { + queue!( + output, + style::SetForegroundColor(style::Color::Yellow), + style::Print("WARNING: "), + style::ResetColor, + style::Print("MCP config conflict for "), + style::SetForegroundColor(style::Color::Green), + style::Print(server_name), + style::ResetColor, + style::Print(". Using workspace version.\n") + )?; + } + } + global_conf + }, + (None, Some(local_buf)) => Self::from_slice(&local_buf, output, "local")?, + (Some(global_buf), None) => Self::from_slice(&global_buf, output, "global")?, + _ => Default::default(), + }; + output.flush()?; + Ok(conf) + } + + fn from_slice(slice: &[u8], output: &mut impl Write, location: &str) -> eyre::Result { + match serde_json::from_slice::(slice) { + Ok(config) => Ok(config), + Err(e) => { + queue!( + output, + style::SetForegroundColor(style::Color::Yellow), + style::Print("WARNING: "), + style::ResetColor, + style::Print(format!("Error reading {location} mcp config: {e}\n")), + style::Print("Please check to make sure config is correct. Discarding.\n"), + )?; + Ok(McpServerConfig::default()) + }, + } + } +} + +#[derive(Default)] +pub struct ToolManagerBuilder { + mcp_server_config: Option, + prompt_list_sender: Option>>, + prompt_list_receiver: Option>>, + conversation_id: Option, +} + +impl ToolManagerBuilder { + pub fn mcp_server_config(mut self, config: McpServerConfig) -> Self { + self.mcp_server_config.replace(config); + self + } + + pub fn prompt_list_sender(mut self, sender: std::sync::mpsc::Sender>) -> Self { + self.prompt_list_sender.replace(sender); + self + } + + pub fn prompt_list_receiver(mut self, receiver: std::sync::mpsc::Receiver>) -> Self { + self.prompt_list_receiver.replace(receiver); + self + } + + pub fn conversation_id(mut self, conversation_id: &str) -> Self { + self.conversation_id.replace(conversation_id.to_string()); + self + } + + pub async fn build(mut self) -> eyre::Result { + let McpServerConfig { mcp_servers } = self.mcp_server_config.ok_or(eyre::eyre!("Missing mcp server config"))?; + debug_assert!(self.conversation_id.is_some()); + let conversation_id = self.conversation_id.ok_or(eyre::eyre!("Missing conversation id"))?; + let regex = regex::Regex::new(VALID_TOOL_NAME)?; + let mut hasher = DefaultHasher::new(); + let pre_initialized = mcp_servers + .into_iter() + .map(|(server_name, server_config)| { + let snaked_cased_name = server_name.to_case(convert_case::Case::Snake); + let sanitized_server_name = sanitize_name(snaked_cased_name, ®ex, &mut hasher); + let custom_tool_client = CustomToolClient::from_config(sanitized_server_name.clone(), server_config); + (sanitized_server_name, custom_tool_client) + }) + .collect::>(); + + // Send up task to update user on server loading status + let (tx, rx) = std::sync::mpsc::channel::(); + // Using a hand rolled thread because it's just easier to do this than do deal with the Send + // requirements that comes with holding onto the stdout lock. + let loading_display_task = tokio::task::spawn_blocking(move || { + let stdout = std::io::stdout(); + let mut stdout_lock = stdout.lock(); + let mut loading_servers = HashMap::::new(); + let mut spinner_logo_idx: usize = 0; + let mut complete: usize = 0; + let mut failed: usize = 0; + loop { + match rx.recv_timeout(std::time::Duration::from_millis(50)) { + Ok(recv_result) => match recv_result { + LoadingMsg::Add(name) => { + let init_time = std::time::Instant::now(); + let status_line = StatusLine { init_time }; + execute!(stdout_lock, cursor::MoveToColumn(0))?; + if !loading_servers.is_empty() { + // TODO: account for terminal width + execute!(stdout_lock, cursor::MoveUp(1))?; + } + loading_servers.insert(name.clone(), status_line); + let total = loading_servers.len(); + execute!(stdout_lock, terminal::Clear(terminal::ClearType::CurrentLine))?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; + stdout_lock.flush()?; + }, + LoadingMsg::Done(name) => { + if let Some(status_line) = loading_servers.get(&name) { + complete += 1; + let time_taken = + (std::time::Instant::now() - status_line.init_time).as_secs_f64().abs(); + let time_taken = format!("{:.2}", time_taken); + execute!( + stdout_lock, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + queue_success_message(&name, &time_taken, &mut stdout_lock)?; + let total = loading_servers.len(); + queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; + stdout_lock.flush()?; + } + }, + LoadingMsg::Error { name, msg } => { + failed += 1; + execute!( + stdout_lock, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + queue_failure_message(&name, &msg, &mut stdout_lock)?; + let total = loading_servers.len(); + queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; + }, + LoadingMsg::Warn { name, msg } => { + complete += 1; + execute!( + stdout_lock, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + let msg = eyre::eyre!(msg.to_string()); + queue_warn_message(&name, &msg, &mut stdout_lock)?; + let total = loading_servers.len(); + queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; + stdout_lock.flush()?; + }, + LoadingMsg::Terminate => { + if !loading_servers.is_empty() { + let msg = loading_servers.iter().fold(String::new(), |mut acc, (server_name, _)| { + acc.push_str(format!("\n - {server_name}").as_str()); + acc + }); + let msg = eyre::eyre!(msg); + queue_incomplete_load_message(&msg, &mut stdout_lock)?; + } + break; + }, + }, + Err(RecvTimeoutError::Timeout) => { + spinner_logo_idx = (spinner_logo_idx + 1) % SPINNER_CHARS.len(); + execute!( + stdout_lock, + cursor::SavePosition, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + style::Print(SPINNER_CHARS[spinner_logo_idx]), + cursor::RestorePosition + )?; + }, + _ => break, + } + } + Ok::<_, eyre::Report>(()) + }); + let mut clients = HashMap::>::new(); + let load_msg_sender = tx.clone(); + let conv_id_clone = conversation_id.clone(); + let regex = Arc::new(Regex::new(VALID_TOOL_NAME)?); + let (mut msg_rx, messenger_builder) = ServerMessengerBuilder::new(20); + tokio::spawn(async move { + let mut is_in_display = true; + while let Some(msg) = msg_rx.recv().await { + // For now we will treat every list result as if they contain the + // complete set of tools. This is not necessarily true in the future when + // request method on the mcp client no longer buffers all the pages from + // list calls. + match msg { + UpdateEventMessage::ToolsListResult { server_name, result } => { + error!("## background: from {server_name}: {:?}", result); + let mut specs = result + .tools + .into_iter() + .filter_map(|v| serde_json::from_value::(v).ok()) + .collect::>(); + let mut sanitized_mapping = HashMap::::new(); + if let Some(load_msg) = process_tool_specs( + conv_id_clone.as_str(), + &server_name, + is_in_display, + &mut specs, + &mut sanitized_mapping, + ®ex, + ) { + if let Err(e) = load_msg_sender.send(load_msg) { + warn!( + "Error sending update message to display task: {:?}\nAssume display task has completed", + e + ); + is_in_display = false; + } + } + }, + UpdateEventMessage::PromptsListResult { server_name, result } => {}, + UpdateEventMessage::ResourcesListResult { server_name, result } => {}, + UpdateEventMessage::ResouceTemplatesListResult { server_name, result } => {}, + UpdateEventMessage::DisplayTaskEnded => { + is_in_display = false; + }, + } + } + }); + for (mut name, init_res) in pre_initialized { + let _ = tx.send(LoadingMsg::Add(name.clone())); + match init_res { + Ok(mut client) => { + let messenger = messenger_builder.build_with_name(client.get_server_name().to_owned()); + client.assign_messenger(Box::new(messenger)); + let mut client = Arc::new(client); + while let Some(collided_client) = clients.insert(name.clone(), client) { + // to avoid server name collision we are going to circumvent this by + // appending the name with 1 + name.push('1'); + client = collided_client; + } + }, + Err(e) => { + error!("Error initializing mcp client for server {}: {:?}", name, &e); + send_mcp_server_init(conversation_id.clone(), Some(e.to_string()), 0).await; + + let _ = tx.send(LoadingMsg::Error { + name: name.clone(), + msg: e, + }); + }, + } + } + let loading_display_task = Some(loading_display_task); + let loading_status_sender = Some(tx); + + // Set up task to handle prompt requests + let sender = self.prompt_list_sender.take(); + let receiver = self.prompt_list_receiver.take(); + let prompts = Arc::new(SyncRwLock::new(HashMap::default())); + // TODO: accommodate hot reload of mcp servers + if let (Some(sender), Some(receiver)) = (sender, receiver) { + let clients = clients.iter().fold(HashMap::new(), |mut acc, (n, c)| { + acc.insert(n.to_string(), Arc::downgrade(c)); + acc + }); + let prompts_clone = prompts.clone(); + tokio::task::spawn_blocking(move || { + let receiver = Arc::new(std::sync::Mutex::new(receiver)); + loop { + let search_word = receiver.lock().map_err(|e| eyre::eyre!("{:?}", e))?.recv()?; + if clients + .values() + .any(|client| client.upgrade().is_some_and(|c| c.is_prompts_out_of_date())) + { + let mut prompts_wl = prompts_clone.write().map_err(|e| { + eyre::eyre!( + "Error retrieving write lock on prompts for tab complete {}", + e.to_string() + ) + })?; + *prompts_wl = clients.iter().fold( + HashMap::>::new(), + |mut acc, (server_name, client)| { + let Some(client) = client.upgrade() else { + return acc; + }; + let prompt_gets = client.list_prompt_gets(); + let Ok(prompt_gets) = prompt_gets.read() else { + tracing::error!("Error retrieving read lock for prompt gets for tab complete"); + return acc; + }; + for (prompt_name, prompt_get) in prompt_gets.iter() { + acc.entry(prompt_name.to_string()) + .and_modify(|bundles| { + bundles.push(PromptBundle { + server_name: server_name.to_owned(), + prompt_get: prompt_get.clone(), + }); + }) + .or_insert(vec![PromptBundle { + server_name: server_name.to_owned(), + prompt_get: prompt_get.clone(), + }]); + } + client.prompts_updated(); + acc + }, + ); + } + let prompts_rl = prompts_clone.read().map_err(|e| { + eyre::eyre!( + "Error retrieving read lock on prompts for tab complete {}", + e.to_string() + ) + })?; + let filtered_prompts = prompts_rl + .iter() + .flat_map(|(prompt_name, bundles)| { + if bundles.len() > 1 { + bundles + .iter() + .map(|b| format!("{}/{}", b.server_name, prompt_name)) + .collect() + } else { + vec![prompt_name.to_owned()] + } + }) + .filter(|n| { + if let Some(p) = &search_word { + n.contains(p) + } else { + true + } + }) + .collect::>(); + if let Err(e) = sender.send(filtered_prompts) { + error!("Error sending prompts to chat helper: {:?}", e); + } + } + #[allow(unreachable_code)] + Ok::<(), eyre::Report>(()) + }); + } + + Ok(ToolManager { + conversation_id, + clients, + prompts, + loading_display_task, + loading_status_sender, + ..Default::default() + }) + } +} + +#[derive(Clone, Debug)] +/// A collection of information that is used for the following purposes: +/// - Checking if prompt info cached is out of date +/// - Retrieve new prompt info +pub struct PromptBundle { + /// The server name from which the prompt is offered / exposed + pub server_name: String, + /// The prompt get (info with which a prompt is retrieved) cached + pub prompt_get: PromptGet, +} + +/// Categorizes different types of tool name validation failures: +/// - `TooLong`: The tool name exceeds the maximum allowed length +/// - `IllegalChar`: The tool name contains characters that are not allowed +/// - `EmptyDescription`: The tool description is empty or missing +#[allow(dead_code)] +enum OutOfSpecName { + TooLong(String), + IllegalChar(String), + EmptyDescription(String), +} + +type NewToolSpecs = Arc, Vec)>>>; + +#[derive(Default)] +/// Manages the lifecycle and interactions with tools from various sources, including MCP servers. +/// This struct is responsible for initializing tools, handling tool requests, and maintaining +/// a cache of available prompts from connected servers. +pub struct ToolManager { + /// Unique identifier for the current conversation. + /// This ID is used to track and associate tools with a specific chat session. + pub conversation_id: String, + + /// Map of server names to their corresponding client instances. + /// These clients are used to communicate with MCP servers. + pub clients: HashMap>, + + pub has_new_stuff: Arc, + + new_tool_specs: NewToolSpecs, + + /// Cache for prompts collected from different servers. + /// Key: prompt name + /// Value: a list of PromptBundle that has a prompt of this name. + /// This cache helps resolve prompt requests efficiently and handles + /// cases where multiple servers offer prompts with the same name. + pub prompts: Arc>>>, + + /// Handle to the thread that displays loading status for tool initialization. + /// This thread provides visual feedback to users during the tool loading process. + loading_display_task: Option>>, + + /// Channel sender for communicating with the loading display thread. + /// Used to send status updates about tool initialization progress. + loading_status_sender: Option>, + + /// Mapping from sanitized tool names to original tool names. + /// This is used to handle tool name transformations that may occur during initialization + /// to ensure tool names comply with naming requirements. + pub tn_map: HashMap, + + /// A cache of tool's input schema for all of the available tools. + /// This is mainly used to show the user what the tools look like from the perspective of the + /// model. + pub schema: HashMap, +} + +impl ToolManager { + pub async fn load_tools(&mut self) -> eyre::Result> { + let tx = self.loading_status_sender.take(); + let display_task = self.loading_display_task.take(); + let tool_specs = { +<<<<<<<< HEAD:crates/chat-cli/src/cli/chat/tool_manager.rs + let mut tool_specs = + serde_json::from_str::>(include_str!("tools/tool_index.json"))?; + if !crate::cli::chat::tools::thinking::Thinking::is_enabled() { + tool_specs.remove("q_think_tool"); + } +======== + let tool_specs = + serde_json::from_str::>(include_str!("../tools/tool_index.json"))?; +>>>>>>>> ca627e83 (loads tools in the background):crates/q_chat/src/tool_manager/tool_manager.rs + Arc::new(Mutex::new(tool_specs)) + }; + let conversation_id = self.conversation_id.clone(); + let regex = Arc::new(regex::Regex::new(VALID_TOOL_NAME)?); + self.new_tool_specs = Arc::new(Mutex::new(HashMap::new())); + let load_tools = self + .clients +<<<<<<<< HEAD:crates/chat-cli/src/cli/chat/tool_manager.rs + .iter() + .map(|(server_name, client)| { + let client_clone = client.clone(); + let server_name_clone = server_name.clone(); + let tx_clone = tx.clone(); + let regex_clone = regex.clone(); + let tool_specs_clone = tool_specs.clone(); + let conversation_id = conversation_id.clone(); + async move { + let tool_spec = client_clone.init().await; + let mut sanitized_mapping = HashMap::::new(); + match tool_spec { + Ok((server_name, specs)) => { + // Each mcp server might have multiple tools. + // To avoid naming conflicts we are going to namespace it. + // This would also help us locate which mcp server to call the tool from. + let mut out_of_spec_tool_names = Vec::::new(); + let mut hasher = DefaultHasher::new(); + let number_of_tools = specs.len(); + // Sanitize tool names to ensure they comply with the naming requirements: + // 1. If the name already matches the regex pattern and doesn't contain the namespace delimiter, use it as is + // 2. Otherwise, remove invalid characters and handle special cases: + // - Remove namespace delimiters + // - Ensure the name starts with an alphabetic character + // - Generate a hash-based name if the sanitized result is empty + // This ensures all tool names are valid identifiers that can be safely used in the system + // If after all of the aforementioned modification the combined tool + // name we have exceeds a length of 64, we surface it as an error + for mut spec in specs { + let sn = if !regex_clone.is_match(&spec.name) { + let mut sn = sanitize_name(spec.name.clone(), ®ex_clone, &mut hasher); + while sanitized_mapping.contains_key(&sn) { + sn.push('1'); + } + sn + } else { + spec.name.clone() + }; + let full_name = format!("{}{}{}", server_name, NAMESPACE_DELIMITER, sn); + if full_name.len() > 64 { + out_of_spec_tool_names.push(OutOfSpecName::TooLong(spec.name)); + continue; + } else if spec.description.is_empty() { + out_of_spec_tool_names.push(OutOfSpecName::EmptyDescription(spec.name)); + continue; + } + if sn != spec.name { + sanitized_mapping.insert(full_name.clone(), format!("{}{}{}", server_name, NAMESPACE_DELIMITER, spec.name)); + } + spec.name = full_name; + spec.tool_origin = ToolOrigin::McpServer(server_name.clone()); + tool_specs_clone.lock().await.insert(spec.name.clone(), spec); + } + + // Send server load success metric datum + send_mcp_server_init(conversation_id, None, number_of_tools).await; + + // Tool name translation. This is beyond of the scope of what is + // considered a "server load". Reasoning being: + // - Failures here are not related to server load + // - There is not a whole lot we can do with this data + if let Some(tx_clone) = &tx_clone { + let send_result = if !out_of_spec_tool_names.is_empty() { + let msg = out_of_spec_tool_names.iter().fold( + String::from("The following tools are out of spec. They will be excluded from the list of available tools:\n"), + |mut acc, name| { + let (tool_name, msg) = match name { + OutOfSpecName::TooLong(tool_name) => (tool_name.as_str(), "tool name exceeds max length of 64 when combined with server name"), + OutOfSpecName::IllegalChar(tool_name) => (tool_name.as_str(), "tool name must be compliant with ^[a-zA-Z][a-zA-Z0-9_]*$"), + OutOfSpecName::EmptyDescription(tool_name) => (tool_name.as_str(), "tool schema contains empty description"), + }; + acc.push_str(format!(" - {} ({})\n", tool_name, msg).as_str()); + acc + } + ); + tx_clone.send(LoadingMsg::Error { + name: server_name.clone(), + msg: eyre::eyre!(msg), + }) + // TODO: if no tools are valid, we need to offload the server + // from the fleet (i.e. kill the server) + } else if !sanitized_mapping.is_empty() { + let warn = sanitized_mapping.iter().fold(String::from("The following tool names are changed:\n"), |mut acc, (k, v)| { + acc.push_str(format!(" - {} -> {}\n", v, k).as_str()); + acc + }); + tx_clone.send(LoadingMsg::Warn { + name: server_name.clone(), + msg: eyre::eyre!(warn), + }) + } else { + tx_clone.send(LoadingMsg::Done(server_name.clone())) + }; + if let Err(e) = send_result { + error!("Error while sending status update to display task: {:?}", e); + } + } + }, + Err(e) => { + error!("Error obtaining tool spec for {}: {:?}", server_name_clone, e); + let init_failure_reason = Some(e.to_string()); + send_mcp_server_init(conversation_id, init_failure_reason, 0).await; + if let Some(tx_clone) = &tx_clone { + if let Err(e) = tx_clone.send(LoadingMsg::Error { + name: server_name_clone, + msg: e, + }) { + error!("Error while sending status update to display task: {:?}", e); + } + } + }, + } + Ok::<_, eyre::Report>(Some(sanitized_mapping)) + } +======== + .values() + .map(|c| { + let clone = Arc::clone(c); + async move { clone.init().await } +>>>>>>>> ca627e83 (loads tools in the background):crates/q_chat/src/tool_manager/tool_manager.rs + }) + .collect::>(); + let some = stream::iter(load_tools) + .map(|async_closure| tokio::spawn(async_closure)) + .buffer_unordered(20) + .collect::>() + .await; + // let load_tool = self + // .clients + // .iter() + // .map(|(server_name, client)| { + // let client_clone = client.clone(); + // let server_name_clone = server_name.clone(); + // let tx_clone = tx.clone(); + // let regex_clone = regex.clone(); + // let tool_specs_clone = tool_specs.clone(); + // let conversation_id = conversation_id.clone(); + // async move { + // let tool_spec = client_clone.init().await; + // let mut sanitized_mapping = HashMap::::new(); + // match tool_spec { + // Ok((server_name, mut specs)) => { + // let msg = process_tool_specs( + // conversation_id.as_str(), + // &server_name, + // true, + // &mut specs, + // &mut sanitized_mapping, + // ®ex_clone, + // ); + // for spec in specs { + // tool_specs_clone.lock().await.insert(spec.name.clone(), spec); + // } + // if let (Some(msg), Some(tx)) = (msg, &tx_clone) { + // let _ = tx.send(msg); + // } + // }, + // Err(e) => { + // error!("Error obtaining tool spec for {}: {:?}", server_name_clone, e); + // let init_failure_reason = Some(e.to_string()); + // tokio::spawn(async move { + // let event = fig_telemetry::EventType::McpServerInit { + // conversation_id, + // init_failure_reason, + // number_of_tools: 0, + // }; + // let app_event = fig_telemetry::AppTelemetryEvent::new(event).await; + // fig_telemetry::dispatch_or_send_event(app_event).await; + // }); + // if let Some(tx_clone) = &tx_clone { + // if let Err(e) = tx_clone.send(LoadingMsg::Error { + // name: server_name_clone, + // msg: e, + // }) { + // error!("Error while sending status update to display task: {:?}", e); + // } + // } + // }, + // } + // Ok::<_, eyre::Report>(Some(sanitized_mapping)) + // } + // }) + // .collect::>(); + // // TODO: do we want to introduce a timeout here? + // self.tn_map = stream::iter(load_tool) + // .map(|async_closure| tokio::task::spawn(async_closure)) + // .buffer_unordered(20) + // .collect::>() + // .await + // .into_iter() + // .filter_map(|r| r.ok()) + // .filter_map(|r| r.ok()) + // .flatten() + // .flatten() + // .collect::>(); + drop(tx); + if let Some(display_task) = display_task { + if let Err(e) = display_task.await { + error!("Error while joining status display task: {:?}", e); + } + } + let tool_specs = { + let mutex = + Arc::try_unwrap(tool_specs).map_err(|e| eyre::eyre!("Error unwrapping arc for tool specs {:?}", e))?; + mutex.into_inner() + }; + // caching the tool names for skim operations + for tool_name in tool_specs.keys() { + if !self.tn_map.contains_key(tool_name) { + self.tn_map.insert(tool_name.clone(), tool_name.clone()); + } + } + self.schema = tool_specs.clone(); + Ok(tool_specs) + } + + pub fn get_tool_from_tool_use(&self, value: AssistantToolUse) -> Result { + let map_err = |parse_error| ToolResult { + tool_use_id: value.id.clone(), + content: vec![ToolResultContentBlock::Text(format!( + "Failed to validate tool parameters: {parse_error}. The model has either suggested tool parameters which are incompatible with the existing tools, or has suggested one or more tool that does not exist in the list of known tools." + ))], + status: ToolResultStatus::Error, + }; + + Ok(match value.name.as_str() { + "fs_read" => Tool::FsRead(serde_json::from_value::(value.args).map_err(map_err)?), + "fs_write" => Tool::FsWrite(serde_json::from_value::(value.args).map_err(map_err)?), + "execute_bash" => Tool::ExecuteBash(serde_json::from_value::(value.args).map_err(map_err)?), + "use_aws" => Tool::UseAws(serde_json::from_value::(value.args).map_err(map_err)?), + "report_issue" => Tool::GhIssue(serde_json::from_value::(value.args).map_err(map_err)?), + "q_think_tool" => Tool::Thinking(serde_json::from_value::(value.args).map_err(map_err)?), + // Note that this name is namespaced with server_name{DELIMITER}tool_name + name => { + let name = self.tn_map.get(name).map_or(name, String::as_str); + let (server_name, tool_name) = name.split_once(NAMESPACE_DELIMITER).ok_or(ToolResult { + tool_use_id: value.id.clone(), + content: vec![ToolResultContentBlock::Text(format!( + "The tool, \"{name}\" is supplied with incorrect name" + ))], + status: ToolResultStatus::Error, + })?; + let Some(client) = self.clients.get(server_name) else { + return Err(ToolResult { + tool_use_id: value.id, + content: vec![ToolResultContentBlock::Text(format!( + "The tool, \"{server_name}\" is not supported by the client" + ))], + status: ToolResultStatus::Error, + }); + }; + // The tool input schema has the shape of { type, properties }. + // The field "params" expected by MCP is { name, arguments }, where name is the + // name of the tool being invoked, + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#calling-tools. + // The field "arguments" is where ToolUse::args belong. + let mut params = serde_json::Map::::new(); + params.insert("name".to_owned(), serde_json::Value::String(tool_name.to_owned())); + params.insert("arguments".to_owned(), value.args); + let params = serde_json::Value::Object(params); + let custom_tool = CustomTool { + name: tool_name.to_owned(), + client: client.clone(), + method: "tools/call".to_owned(), + params: Some(params), + }; + Tool::Custom(custom_tool) + }, + }) + } + + #[allow(clippy::await_holding_lock)] + pub async fn get_prompt(&self, get_command: PromptsGetCommand) -> Result { + let (server_name, prompt_name) = match get_command.params.name.split_once('/') { + None => (None::, Some(get_command.params.name.clone())), + Some((server_name, prompt_name)) => (Some(server_name.to_string()), Some(prompt_name.to_string())), + }; + let prompt_name = prompt_name.ok_or(GetPromptError::MissingPromptName)?; + // We need to use a sync lock here because this lock is also used in a blocking thread, + // necessitated by the fact that said thread is also responsible for using a sync channel, + // which is itself necessitated by the fact that consumer of said channel is calling from a + // sync function + let mut prompts_wl = self + .prompts + .write() + .map_err(|e| GetPromptError::Synchronization(e.to_string()))?; + let mut maybe_bundles = prompts_wl.get(&prompt_name); + let mut has_retried = false; + 'blk: loop { + match (maybe_bundles, server_name.as_ref(), has_retried) { + // If we have more than one eligible clients but no server name specified + (Some(bundles), None, _) if bundles.len() > 1 => { + break 'blk Err(GetPromptError::AmbiguousPrompt(prompt_name.clone(), { + bundles.iter().fold("\n".to_string(), |mut acc, b| { + acc.push_str(&format!("- @{}/{}\n", b.server_name, prompt_name)); + acc + }) + })); + }, + // Normal case where we have enough info to proceed + // Note that if bundle exists, it should never be empty + (Some(bundles), sn, _) => { + let bundle = if bundles.len() > 1 { + let Some(server_name) = sn else { + maybe_bundles = None; + continue 'blk; + }; + let bundle = bundles.iter().find(|b| b.server_name == *server_name); + match bundle { + Some(bundle) => bundle, + None => { + maybe_bundles = None; + continue 'blk; + }, + } + } else { + bundles.first().ok_or(GetPromptError::MissingPromptInfo)? + }; + let server_name = bundle.server_name.clone(); + let client = self.clients.get(&server_name).ok_or(GetPromptError::MissingClient)?; + // Here we lazily update the out of date cache + if client.is_prompts_out_of_date() { + let prompt_gets = client.list_prompt_gets(); + let prompt_gets = prompt_gets + .read() + .map_err(|e| GetPromptError::Synchronization(e.to_string()))?; + for (prompt_name, prompt_get) in prompt_gets.iter() { + prompts_wl + .entry(prompt_name.to_string()) + .and_modify(|bundles| { + let mut is_modified = false; + for bundle in &mut *bundles { + let mut updated_bundle = PromptBundle { + server_name: server_name.clone(), + prompt_get: prompt_get.clone(), + }; + if bundle.server_name == *server_name { + std::mem::swap(bundle, &mut updated_bundle); + is_modified = true; + break; + } + } + if !is_modified { + bundles.push(PromptBundle { + server_name: server_name.clone(), + prompt_get: prompt_get.clone(), + }); + } + }) + .or_insert(vec![PromptBundle { + server_name: server_name.clone(), + prompt_get: prompt_get.clone(), + }]); + } + client.prompts_updated(); + } + let PromptsGetCommand { params, .. } = get_command; + let PromptBundle { prompt_get, .. } = prompts_wl + .get(&prompt_name) + .and_then(|bundles| bundles.iter().find(|b| b.server_name == server_name)) + .ok_or(GetPromptError::MissingPromptInfo)?; + // Here we need to convert the positional arguments into key value pair + // The assignment order is assumed to be the order of args as they are + // presented in PromptGet::arguments + let args = if let (Some(schema), Some(value)) = (&prompt_get.arguments, ¶ms.arguments) { + let params = schema.iter().zip(value.iter()).fold( + HashMap::::new(), + |mut acc, (prompt_get_arg, value)| { + acc.insert(prompt_get_arg.name.clone(), value.clone()); + acc + }, + ); + Some(serde_json::json!(params)) + } else { + None + }; + let params = { + let mut params = serde_json::Map::new(); + params.insert("name".to_string(), serde_json::Value::String(prompt_name)); + if let Some(args) = args { + params.insert("arguments".to_string(), args); + } + Some(serde_json::Value::Object(params)) + }; + let resp = client.request("prompts/get", params).await?; + break 'blk Ok(resp); + }, + // If we have no eligible clients this would mean one of the following: + // - The prompt does not exist, OR + // - This is the first time we have a query / our cache is out of date + // Both of which means we would have to requery + (None, _, false) => { + has_retried = true; + self.refresh_prompts(&mut prompts_wl)?; + maybe_bundles = prompts_wl.get(&prompt_name); + continue 'blk; + }, + (_, _, true) => { + break 'blk Err(GetPromptError::PromptNotFound(prompt_name)); + }, + } + } + } + + pub fn refresh_prompts(&self, prompts_wl: &mut HashMap>) -> Result<(), GetPromptError> { + *prompts_wl = self.clients.iter().fold( + HashMap::>::new(), + |mut acc, (server_name, client)| { + let prompt_gets = client.list_prompt_gets(); + let Ok(prompt_gets) = prompt_gets.read() else { + tracing::error!("Error encountered while retrieving read lock"); + return acc; + }; + for (prompt_name, prompt_get) in prompt_gets.iter() { + acc.entry(prompt_name.to_string()) + .and_modify(|bundles| { + bundles.push(PromptBundle { + server_name: server_name.to_owned(), + prompt_get: prompt_get.clone(), + }); + }) + .or_insert(vec![PromptBundle { + server_name: server_name.to_owned(), + prompt_get: prompt_get.clone(), + }]); + } + acc + }, + ); + Ok(()) + } +} + +#[inline] +fn process_tool_specs( + conversation_id: &str, + server_name: &str, + is_in_display: bool, + specs: &mut Vec, + tn_map: &mut HashMap, + regex: &Arc, +) -> Option { + // Each mcp server might have multiple tools. + // To avoid naming conflicts we are going to namespace it. + // This would also help us locate which mcp server to call the tool from. + let mut out_of_spec_tool_names = Vec::::new(); + let mut hasher = DefaultHasher::new(); + let number_of_tools = specs.len(); + // Sanitize tool names to ensure they comply with the naming requirements: + // 1. If the name already matches the regex pattern and doesn't contain the namespace delimiter, use + // it as is + // 2. Otherwise, remove invalid characters and handle special cases: + // - Remove namespace delimiters + // - Ensure the name starts with an alphabetic character + // - Generate a hash-based name if the sanitized result is empty + // This ensures all tool names are valid identifiers that can be safely used in the system + // If after all of the aforementioned modification the combined tool + // name we have exceeds a length of 64, we surface it as an error + for spec in specs { + let sn = if !regex.is_match(&spec.name) { + let mut sn = sanitize_name(spec.name.clone(), regex, &mut hasher); + while tn_map.contains_key(&sn) { + sn.push('1'); + } + sn + } else { + spec.name.clone() + }; + let full_name = format!("{}{}{}", server_name, NAMESPACE_DELIMITER, sn); + if full_name.len() > 64 { + out_of_spec_tool_names.push(OutOfSpecName::TooLong(spec.name.clone())); + continue; + } else if spec.description.is_empty() { + out_of_spec_tool_names.push(OutOfSpecName::EmptyDescription(spec.name.clone())); + continue; + } + if sn != spec.name { + tn_map.insert( + full_name.clone(), + format!("{}{}{}", server_name, NAMESPACE_DELIMITER, spec.name), + ); + } + spec.name = full_name; + spec.tool_origin = ToolOrigin::McpServer(server_name.to_string()); + } + // Send server load success metric datum + let conversation_id = conversation_id.to_string(); + tokio::spawn(async move { + let event = fig_telemetry::EventType::McpServerInit { + conversation_id, + init_failure_reason: None, + number_of_tools, + }; + let app_event = fig_telemetry::AppTelemetryEvent::new(event).await; + fig_telemetry::dispatch_or_send_event(app_event).await; + }); + // Tool name translation. This is beyond of the scope of what is + // considered a "server load". Reasoning being: + // - Failures here are not related to server load + // - There is not a whole lot we can do with this data + let loading_msg = if !out_of_spec_tool_names.is_empty() { + let msg = out_of_spec_tool_names.iter().fold( + String::from( + "The following tools are out of spec. They will be excluded from the list of available tools:\n", + ), + |mut acc, name| { + let (tool_name, msg) = match name { + OutOfSpecName::TooLong(tool_name) => ( + tool_name.as_str(), + "tool name exceeds max length of 64 when combined with server name", + ), + OutOfSpecName::IllegalChar(tool_name) => ( + tool_name.as_str(), + "tool name must be compliant with ^[a-zA-Z][a-zA-Z0-9_]*$", + ), + OutOfSpecName::EmptyDescription(tool_name) => { + (tool_name.as_str(), "tool schema contains empty description") + }, + }; + acc.push_str(format!(" - {} ({})\n", tool_name, msg).as_str()); + acc + }, + ); + error!( + "Server {} finished loading with the following error: \n{}", + server_name, msg + ); + if is_in_display { + Some(LoadingMsg::Error { + name: server_name.to_string(), + msg: eyre::eyre!(msg), + }) + } else { + None + } + // TODO: if no tools are valid, we need to offload the server + // from the fleet (i.e. kill the server) + } else if !tn_map.is_empty() { + let warn = tn_map.iter().fold( + String::from("The following tool names are changed:\n"), + |mut acc, (k, v)| { + acc.push_str(format!(" - {} -> {}\n", v, k).as_str()); + acc + }, + ); + if is_in_display { + Some(LoadingMsg::Warn { + name: server_name.to_string(), + msg: eyre::eyre!(warn), + }) + } else { + None + } + } else if is_in_display { + Some(LoadingMsg::Done(server_name.to_string())) + } else { + None + }; + loading_msg +} + +fn sanitize_name(orig: String, regex: ®ex::Regex, hasher: &mut impl Hasher) -> String { + if regex.is_match(&orig) && !orig.contains(NAMESPACE_DELIMITER) { + return orig; + } + let sanitized: String = orig + .chars() + .filter(|c| c.is_ascii_alphabetic() || c.is_ascii_digit() || *c == '_') + .collect::() + .replace(NAMESPACE_DELIMITER, ""); + if sanitized.is_empty() { + hasher.write(orig.as_bytes()); + let hash = format!("{:03}", hasher.finish() % 1000); + return format!("a{}", hash); + } + match sanitized.chars().next() { + Some(c) if c.is_ascii_alphabetic() => sanitized, + Some(_) => { + format!("a{}", sanitized) + }, + None => { + hasher.write(orig.as_bytes()); + format!("a{}", hasher.finish()) + }, + } +} + +fn queue_success_message(name: &str, time_taken: &str, output: &mut impl Write) -> eyre::Result<()> { + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Green), + style::Print("✓ "), + style::SetForegroundColor(style::Color::Blue), + style::Print(name), + style::ResetColor, + style::Print(" loaded in "), + style::SetForegroundColor(style::Color::Yellow), + style::Print(format!("{time_taken} s\n")), + )?) +} + +fn queue_init_message( + spinner_logo_idx: usize, + complete: usize, + failed: usize, + total: usize, + output: &mut impl Write, +) -> eyre::Result<()> { + if total == complete { + queue!( + output, + style::SetForegroundColor(style::Color::Green), + style::Print("✓"), + style::ResetColor, + )?; + } else if total == complete + failed { + queue!( + output, + style::SetForegroundColor(style::Color::Red), + style::Print("✗"), + style::ResetColor, + )?; + } else { + queue!(output, style::Print(SPINNER_CHARS[spinner_logo_idx]))?; + } + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Blue), + style::Print(format!(" {}", complete)), + style::ResetColor, + style::Print(" of "), + style::SetForegroundColor(style::Color::Blue), + style::Print(format!("{} ", total)), + style::ResetColor, + style::Print("mcp servers initialized\n"), + )?) +} + +fn queue_failure_message(name: &str, fail_load_msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Red), + style::Print("✗ "), + style::SetForegroundColor(style::Color::Blue), + style::Print(name), + style::ResetColor, + style::Print(" has failed to load:\n- "), + style::Print(fail_load_msg), + style::Print("\n"), + style::Print("- run with Q_LOG_LEVEL=trace and see $TMPDIR/qlog for detail\n"), + style::ResetColor, + )?) +} + +fn queue_warn_message(name: &str, msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Yellow), + style::Print("⚠ "), + style::SetForegroundColor(style::Color::Blue), + style::Print(name), + style::ResetColor, + style::Print(" has the following warning:\n"), + style::Print(msg), + style::ResetColor, + )?) +} + +fn queue_incomplete_load_message(msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Yellow), + style::Print("⚠ "), + style::ResetColor, + // We expect the message start with a newline + style::Print("following servers are still loading:"), + style::Print(msg), + style::ResetColor, + )?) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sanitize_server_name() { + let regex = regex::Regex::new(VALID_TOOL_NAME).unwrap(); + let mut hasher = DefaultHasher::new(); + let orig_name = "@awslabs.cdk-mcp-server"; + let sanitized_server_name = sanitize_name(orig_name.to_string(), ®ex, &mut hasher); + assert_eq!(sanitized_server_name, "awslabscdkmcpserver"); + + let orig_name = "good_name"; + let sanitized_good_name = sanitize_name(orig_name.to_string(), ®ex, &mut hasher); + assert_eq!(sanitized_good_name, orig_name); + + let all_bad_name = "@@@@@"; + let sanitized_all_bad_name = sanitize_name(all_bad_name.to_string(), ®ex, &mut hasher); + assert!(regex.is_match(&sanitized_all_bad_name)); + + let with_delim = format!("a{}b{}c", NAMESPACE_DELIMITER, NAMESPACE_DELIMITER); + let sanitized = sanitize_name(with_delim, ®ex, &mut hasher); + assert_eq!(sanitized, "abc"); + } +} From 5bedc537dee1f7fb51dc533af105f48acf0b3e44 Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Tue, 6 May 2025 14:27:57 -0700 Subject: [PATCH 06/26] makes necessary changes for refactor --- Cargo.lock | 52 - crates/chat-cli/Cargo.toml | 6 + .../src/cli/chat}/tool_manager/mod.rs | 0 .../chat}/tool_manager/server_messenger.rs | 13 +- .../chat/{ => tool_manager}/tool_manager.rs | 42 +- crates/chat-cli/src/lib.rs | 6 + crates/chat-cli/src/mcp_client/client.rs | 206 ++- crates/chat-cli/src/mcp_client/error.rs | 66 + .../src/mcp_client/facilitator_types.rs | 42 +- crates/chat-cli/src/mcp_client/messenger.rs | 73 + crates/chat-cli/src/mcp_client/mod.rs | 81 +- crates/chat-cli/src/mcp_client/server.rs | 24 +- .../chat-cli/src/mcp_client/transport/mod.rs | 2 +- .../src/mcp_client/transport/stdio.rs | 7 +- .../chat-cli/test_mcp_server/test_server.rs | 351 +++++ crates/q_chat/Cargo.toml | 58 - .../q_chat/src/tool_manager/tool_manager.rs | 1332 ----------------- 17 files changed, 715 insertions(+), 1646 deletions(-) rename crates/{q_chat/src => chat-cli/src/cli/chat}/tool_manager/mod.rs (100%) rename crates/{q_chat/src => chat-cli/src/cli/chat}/tool_manager/server_messenger.rs (99%) rename crates/chat-cli/src/cli/chat/{ => tool_manager}/tool_manager.rs (98%) create mode 100644 crates/chat-cli/src/lib.rs create mode 100644 crates/chat-cli/src/mcp_client/error.rs create mode 100644 crates/chat-cli/src/mcp_client/messenger.rs create mode 100644 crates/chat-cli/test_mcp_server/test_server.rs delete mode 100644 crates/q_chat/Cargo.toml delete mode 100644 crates/q_chat/src/tool_manager/tool_manager.rs diff --git a/Cargo.lock b/Cargo.lock index ecc83f3aec..e55148dc8b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7028,58 +7028,6 @@ dependencies = [ ] [[package]] -<<<<<<< HEAD -======= -name = "q_chat" -version = "1.10.0" -dependencies = [ - "anstream", - "async-trait", - "aws-smithy-types", - "bstr", - "clap", - "color-print", - "convert_case 0.8.0", - "crossterm", - "eyre", - "fig_api_client", - "fig_auth", - "fig_diagnostic", - "fig_os_shim", - "fig_settings", - "fig_telemetry", - "fig_util", - "futures", - "glob", - "mcp_client", - "rand 0.9.0", - "regex", - "rustyline", - "serde", - "serde_json", - "shell-color", - "shell-words", - "shellexpand", - "shlex", - "similar", - "skim", - "spinners", - "strip-ansi-escapes", - "syntect", - "tempfile", - "thiserror 2.0.12", - "time", - "tokio", - "tracing", - "tracing-subscriber", - "unicode-width 0.2.0", - "url", - "uuid", - "winnow 0.6.22", -] - -[[package]] ->>>>>>> ca627e83 (loads tools in the background) name = "q_cli" version = "1.10.0" dependencies = [ diff --git a/crates/chat-cli/Cargo.toml b/crates/chat-cli/Cargo.toml index ddb704f6b5..e2a5a77486 100644 --- a/crates/chat-cli/Cargo.toml +++ b/crates/chat-cli/Cargo.toml @@ -14,6 +14,12 @@ workspace = true default = [] wayland = ["arboard/wayland-data-control"] +[[bin]] +name = "test_mcp_server" +path = "test_mcp_server/test_server.rs" +test = true +doc = false + [dependencies] amzn-codewhisperer-client = { path = "../amzn-codewhisperer-client" } amzn-codewhisperer-streaming-client = { path = "../amzn-codewhisperer-streaming-client" } diff --git a/crates/q_chat/src/tool_manager/mod.rs b/crates/chat-cli/src/cli/chat/tool_manager/mod.rs similarity index 100% rename from crates/q_chat/src/tool_manager/mod.rs rename to crates/chat-cli/src/cli/chat/tool_manager/mod.rs diff --git a/crates/q_chat/src/tool_manager/server_messenger.rs b/crates/chat-cli/src/cli/chat/tool_manager/server_messenger.rs similarity index 99% rename from crates/q_chat/src/tool_manager/server_messenger.rs rename to crates/chat-cli/src/cli/chat/tool_manager/server_messenger.rs index aad019bc16..dad1648d65 100644 --- a/crates/q_chat/src/tool_manager/server_messenger.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager/server_messenger.rs @@ -1,4 +1,10 @@ -use mcp_client::{ +use tokio::sync::mpsc::{ + Receiver, + Sender, + channel, +}; + +use crate::mcp_client::{ Messenger, MessengerError, PromptsListResult, @@ -6,11 +12,6 @@ use mcp_client::{ ResourcesListResult, ToolsListResult, }; -use tokio::sync::mpsc::{ - Receiver, - Sender, - channel, -}; #[derive(Clone, Debug)] pub enum UpdateEventMessage { diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager/tool_manager.rs similarity index 98% rename from crates/chat-cli/src/cli/chat/tool_manager.rs rename to crates/chat-cli/src/cli/chat/tool_manager/tool_manager.rs index bd97f3f431..8ba5af6801 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager/tool_manager.rs @@ -36,35 +36,35 @@ use tracing::{ warn, }; -use crate::command::PromptsGetCommand; -use crate::message::AssistantToolUse; -use crate::tool_manager::server_messenger::{ +use crate::api_client::model::{ + ToolResult, + ToolResultContentBlock, + ToolResultStatus, +}; +use crate::cli::chat::command::PromptsGetCommand; +use crate::cli::chat::message::AssistantToolUse; +use crate::cli::chat::tool_manager::server_messenger::{ ServerMessengerBuilder, UpdateEventMessage, }; -use crate::tools::custom_tool::{ +use crate::cli::chat::tools::custom_tool::{ CustomTool, CustomToolClient, CustomToolConfig, }; -use crate::tools::execute_bash::ExecuteBash; -use crate::tools::fs_read::FsRead; -use crate::tools::fs_write::FsWrite; -use crate::tools::gh_issue::GhIssue; -use crate::tools::use_aws::UseAws; -use crate::tools::{ +use crate::cli::chat::tools::execute_bash::ExecuteBash; +use crate::cli::chat::tools::fs_read::FsRead; +use crate::cli::chat::tools::fs_write::FsWrite; +use crate::cli::chat::tools::gh_issue::GhIssue; +use crate::cli::chat::tools::thinking::Thinking; +use crate::cli::chat::tools::use_aws::UseAws; +use crate::cli::chat::tools::{ Tool, ToolOrigin, ToolSpec, }; -use crate::api_client::model::{ - ToolResult, - ToolResultContentBlock, - ToolResultStatus, -}; use crate::mcp_client::{ JsonRpcResponse, - NullMessenger, PromptGet, }; use crate::telemetry::send_mcp_server_init; @@ -580,7 +580,7 @@ impl ToolManager { let tx = self.loading_status_sender.take(); let display_task = self.loading_display_task.take(); let tool_specs = { - let tool_specs = + let mut tool_specs = serde_json::from_str::>(include_str!("../tools/tool_index.json"))?; if !crate::cli::chat::tools::thinking::Thinking::is_enabled() { tool_specs.remove("q_think_tool"); @@ -964,13 +964,7 @@ fn process_tool_specs( // Send server load success metric datum let conversation_id = conversation_id.to_string(); tokio::spawn(async move { - let event = fig_telemetry::EventType::McpServerInit { - conversation_id, - init_failure_reason: None, - number_of_tools, - }; - let app_event = fig_telemetry::AppTelemetryEvent::new(event).await; - fig_telemetry::dispatch_or_send_event(app_event).await; + send_mcp_server_init(conversation_id, None, number_of_tools).await; }); // Tool name translation. This is beyond of the scope of what is // considered a "server load". Reasoning being: diff --git a/crates/chat-cli/src/lib.rs b/crates/chat-cli/src/lib.rs new file mode 100644 index 0000000000..2b584b4c47 --- /dev/null +++ b/crates/chat-cli/src/lib.rs @@ -0,0 +1,6 @@ +//! This lib.rs is only here for testing purposes. +//! `test_mcp_server/test_server.rs` is declared as a separate binary and would need a way to +//! reference types defined inside of this crate, hence the export. +pub mod mcp_client; + +pub use mcp_client::*; diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs index 9dc9f6bc98..15ce904b91 100644 --- a/crates/chat-cli/src/mcp_client/client.rs +++ b/crates/chat-cli/src/mcp_client/client.rs @@ -11,9 +11,7 @@ use std::sync::{ }; use std::time::Duration; -#[cfg(unix)] use nix::sys::signal::Signal; -#[cfg(unix)] use nix::unistd::Pid; use serde::{ Deserialize, @@ -23,31 +21,32 @@ use thiserror::Error; use tokio::time; use tokio::time::error::Elapsed; -use crate::mcp_client::transport::base_protocol::{ +use super::transport::base_protocol::{ JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcVersion, }; -use crate::mcp_client::transport::stdio::JsonRpcStdioTransport; -use crate::mcp_client::transport::{ +use super::transport::stdio::JsonRpcStdioTransport; +use super::transport::{ self, Transport, TransportError, }; -use crate::mcp_client::{ +use super::{ JsonRpcResponse, Listener as _, LogListener, + Messenger, PaginationSupportedOps, PromptGet, PromptsListResult, ResourceTemplatesListResult, ResourcesListResult, + ServerCapabilities, ToolsListResult, }; -pub type ServerCapabilities = serde_json::Value; pub type ClientInfo = serde_json::Value; pub type StdioTransport = JsonRpcStdioTransport; @@ -97,10 +96,18 @@ pub enum ClientError { source: tokio::time::error::Elapsed, context: String, }, + #[error("Unexpected msg type encountered")] + UnexpectedMsgType, #[error("{0}")] NegotiationError(String), #[error("Failed to obtain process id")] MissingProcessId, + #[error("Invalid path received")] + InvalidPath, + #[error("{0}")] + ProcessKillError(String), + #[error("{0}")] + PoisonError(String), } impl From<(tokio::time::error::Elapsed, String)> for ClientError { @@ -114,10 +121,10 @@ pub struct Client { server_name: String, transport: Arc, timeout: u64, - #[cfg(unix)] server_process_id: Option, client_info: serde_json::Value, current_id: Arc, + pub messenger: Option>, pub prompt_gets: Arc>>, pub is_prompts_out_of_date: Arc, } @@ -130,10 +137,10 @@ impl Clone for Client { timeout: self.timeout, // Note that we cannot have an id for the clone because we would kill the original // process when we drop the clone - #[cfg(unix)] server_process_id: None, client_info: self.client_info.clone(), current_id: self.current_id.clone(), + messenger: None, prompt_gets: self.prompt_gets.clone(), is_prompts_out_of_date: self.is_prompts_out_of_date.clone(), } @@ -156,11 +163,8 @@ impl Client { .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) + .process_group(0) .envs(std::env::vars()); - - #[cfg(unix)] - command.process_group(0); - if let Some(env) = env { for (env_name, env_value) in env { command.env(env_name, env_value); @@ -168,32 +172,29 @@ impl Client { } command.args(args).spawn()? }; - - #[cfg(unix)] - let server_process_id = Some(Pid::from_raw( - child - .id() - .ok_or(ClientError::MissingProcessId)? + let server_process_id = child.id().ok_or(ClientError::MissingProcessId)?; + #[allow(clippy::map_err_ignore)] + let server_process_id = Pid::from_raw( + server_process_id .try_into() - .map_err(|_err| ClientError::MissingProcessId)?, - )); - + .map_err(|_| ClientError::MissingProcessId)?, + ); + let server_process_id = Some(server_process_id); let transport = Arc::new(transport::stdio::JsonRpcStdioTransport::client(child)?); Ok(Self { server_name, transport, timeout, - #[cfg(unix)] server_process_id, client_info, current_id: Arc::new(AtomicU64::new(0)), + messenger: None, prompt_gets: Arc::new(SyncRwLock::new(HashMap::new())), is_prompts_out_of_date: Arc::new(AtomicBool::new(false)), }) } } -#[cfg(unix)] impl Drop for Client where T: Transport, @@ -213,8 +214,10 @@ where { /// Exchange of information specified as per https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization /// - /// Also done is the spawn of a background task that constantly listens for incoming messages - /// from the server. + /// Also done are the following: + /// - Spawns task for listening to server driven workflows + /// - Spawns tasks to ask for relevant info such as tools and prompts in accordance to server + /// capabilities received pub async fn init(&self) -> Result { let transport_ref = self.transport.clone(); let server_name = self.server_name.clone(); @@ -296,63 +299,118 @@ where let client_cap = ClientCapabilities::from(self.client_info.clone()); serde_json::json!(client_cap) }); - let server_capabilities = self.request("initialize", init_params).await?; - if let Err(e) = examine_server_capabilities(&server_capabilities) { + let init_resp = self.request("initialize", init_params).await?; + if let Err(e) = examine_server_capabilities(&init_resp) { return Err(ClientError::NegotiationError(format!( "Client {} has failed to negotiate server capabilities with server: {:?}", self.server_name, e ))); } + let cap = { + let result = init_resp.result.ok_or(ClientError::NegotiationError(format!( + "Server {} init resp is missing result", + self.server_name + )))?; + let cap = result + .get("capabilities") + .ok_or(ClientError::NegotiationError(format!( + "Server {} init resp result is missing capabilities", + self.server_name + )))? + .clone(); + serde_json::from_value::(cap)? + }; self.notify("initialized", None).await?; // TODO: group this into examine_server_capabilities // Prefetch prompts in the background. We should only do this after the server has been // initialized - if let Some(res) = &server_capabilities.result { - if let Some(cap) = res.get("capabilities") { - if cap.get("prompts").is_some() { - self.is_prompts_out_of_date.store(true, Ordering::Relaxed); - let client_ref = (*self).clone(); - tokio::spawn(async move { - let Ok(resp) = client_ref.request("prompts/list", None).await else { - tracing::error!("Prompt list query failed for {0}", client_ref.server_name); - return; - }; - let Some(result) = resp.result else { - tracing::warn!("Prompt list query returned no result for {0}", client_ref.server_name); - return; - }; - let Some(prompts) = result.get("prompts") else { - tracing::warn!( - "Prompt list query result contained no field named prompts for {0}", - client_ref.server_name - ); - return; - }; - let Ok(prompts) = serde_json::from_value::>(prompts.clone()) else { - tracing::error!( - "Prompt list query deserialization failed for {0}", - client_ref.server_name - ); - return; - }; - let Ok(mut lock) = client_ref.prompt_gets.write() else { - tracing::error!( - "Failed to obtain write lock for prompt list query for {0}", - client_ref.server_name - ); - return; - }; - for prompt in prompts { - let name = prompt.name.clone(); - lock.insert(name, prompt); - } - }); + if cap.prompts.is_some() { + self.is_prompts_out_of_date.store(true, Ordering::Relaxed); + let client_ref = (*self).clone(); + tokio::spawn(async move { + let Ok(resp) = client_ref.request("prompts/list", None).await else { + tracing::error!("Prompt list query failed for {0}", client_ref.server_name); + return; + }; + let Some(result) = resp.result else { + tracing::warn!("Prompt list query returned no result for {0}", client_ref.server_name); + return; + }; + let Some(prompts) = result.get("prompts") else { + tracing::warn!( + "Prompt list query result contained no field named prompts for {0}", + client_ref.server_name + ); + return; + }; + let Ok(prompts) = serde_json::from_value::>(prompts.clone()) else { + tracing::error!( + "Prompt list query deserialization failed for {0}", + client_ref.server_name + ); + return; + }; + let Ok(mut lock) = client_ref.prompt_gets.write() else { + tracing::error!( + "Failed to obtain write lock for prompt list query for {0}", + client_ref.server_name + ); + return; + }; + for prompt in prompts { + let name = prompt.name.clone(); + lock.insert(name, prompt); } - } + }); + } + if let (Some(_), Some(messenger)) = (&cap.tools, &self.messenger) { + tracing::error!( + "## background: {} is spawning background task to fetch tools", + self.server_name + ); + let client_ref = (*self).clone(); + let msger = messenger.duplicate(); + tokio::spawn(async move { + // TODO: decouple pagination logic from request and have page fetching logic here + // instead + let resp = match client_ref.request("tools/list", None).await { + Ok(resp) => resp, + Err(e) => { + tracing::error!("Failed to retrieve tool list from {}: {:?}", client_ref.server_name, e); + return; + }, + }; + if let Some(error) = resp.error { + let msg = format!( + "Failed to retrieve tool list for {}: {:?}", + client_ref.server_name, error + ); + tracing::error!("{}", &msg); + return; + } + let Some(result) = resp.result else { + tracing::error!("Tool list response from {} is missing result", client_ref.server_name); + return; + }; + let tool_list_result = match serde_json::from_value::(result) { + Ok(result) => result, + Err(e) => { + tracing::error!( + "Failed to deserialize tool result from {}: {:?}", + client_ref.server_name, + e + ); + return; + }, + }; + if let Err(e) = msger.send_tools_list_result(tool_list_result).await { + tracing::error!("Failed to send tool result through messenger {:?}", e); + } + }); } - Ok(serde_json::to_value(server_capabilities)?) + Ok(cap) } /// Sends a request to the server associated. @@ -403,13 +461,13 @@ where loop { let result = current_resp.result.as_ref().cloned().unwrap(); let mut list: Vec = match ops { - PaginationSupportedOps::Resources => { + PaginationSupportedOps::ResourcesList => { let ResourcesListResult { resources: list, .. } = serde_json::from_value::(result) .map_err(ClientError::Serialization)?; list }, - PaginationSupportedOps::ResourceTemplates => { + PaginationSupportedOps::ResourceTemplatesList => { let ResourceTemplatesListResult { resource_templates: list, .. @@ -417,13 +475,13 @@ where .map_err(ClientError::Serialization)?; list }, - PaginationSupportedOps::Prompts => { + PaginationSupportedOps::PromptsList => { let PromptsListResult { prompts: list, .. } = serde_json::from_value::(result) .map_err(ClientError::Serialization)?; list }, - PaginationSupportedOps::Tools => { + PaginationSupportedOps::ToolsList => { let ToolsListResult { tools: list, .. } = serde_json::from_value::(result) .map_err(ClientError::Serialization)?; list diff --git a/crates/chat-cli/src/mcp_client/error.rs b/crates/chat-cli/src/mcp_client/error.rs new file mode 100644 index 0000000000..d05e7efa4d --- /dev/null +++ b/crates/chat-cli/src/mcp_client/error.rs @@ -0,0 +1,66 @@ +/// Error codes as defined in the MCP protocol. +/// +/// These error codes are based on the JSON-RPC 2.0 specification with additional +/// MCP-specific error codes in the -32000 to -32099 range. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(i32)] +pub enum ErrorCode { + /// Invalid JSON was received by the server. + /// An error occurred on the server while parsing the JSON text. + ParseError = -32700, + + /// The JSON sent is not a valid Request object. + InvalidRequest = -32600, + + /// The method does not exist / is not available. + MethodNotFound = -32601, + + /// Invalid method parameter(s). + InvalidParams = -32602, + + /// Internal JSON-RPC error. + InternalError = -32603, + + /// Server has not been initialized. + /// This error is returned when a request is made before the server + /// has been properly initialized. + ServerNotInitialized = -32002, + + /// Unknown error code. + /// This error is returned when an error code is received that is not + /// recognized by the implementation. + UnknownErrorCode = -32001, + + /// Request failed. + /// This error is returned when a request fails for a reason not covered + /// by other error codes. + RequestFailed = -32000, +} + +impl From for ErrorCode { + fn from(code: i32) -> Self { + match code { + -32700 => ErrorCode::ParseError, + -32600 => ErrorCode::InvalidRequest, + -32601 => ErrorCode::MethodNotFound, + -32602 => ErrorCode::InvalidParams, + -32603 => ErrorCode::InternalError, + -32002 => ErrorCode::ServerNotInitialized, + -32001 => ErrorCode::UnknownErrorCode, + -32000 => ErrorCode::RequestFailed, + _ => ErrorCode::UnknownErrorCode, + } + } +} + +impl From for i32 { + fn from(code: ErrorCode) -> Self { + code as i32 + } +} + +impl std::fmt::Display for ErrorCode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} diff --git a/crates/chat-cli/src/mcp_client/facilitator_types.rs b/crates/chat-cli/src/mcp_client/facilitator_types.rs index 38d4aca280..908f555bd2 100644 --- a/crates/chat-cli/src/mcp_client/facilitator_types.rs +++ b/crates/chat-cli/src/mcp_client/facilitator_types.rs @@ -7,19 +7,19 @@ use thiserror::Error; /// https://spec.modelcontextprotocol.io/specification/2024-11-05/server/utilities/pagination/#operations-supporting-pagination #[derive(Debug, Clone, PartialEq, Eq)] pub enum PaginationSupportedOps { - Resources, - ResourceTemplates, - Prompts, - Tools, + ResourcesList, + ResourceTemplatesList, + PromptsList, + ToolsList, } impl PaginationSupportedOps { pub fn as_key(&self) -> &str { match self { - PaginationSupportedOps::Resources => "resources", - PaginationSupportedOps::ResourceTemplates => "resourceTemplates", - PaginationSupportedOps::Prompts => "prompts", - PaginationSupportedOps::Tools => "tools", + PaginationSupportedOps::ResourcesList => "resources", + PaginationSupportedOps::ResourceTemplatesList => "resourceTemplates", + PaginationSupportedOps::PromptsList => "prompts", + PaginationSupportedOps::ToolsList => "tools", } } } @@ -29,10 +29,10 @@ impl TryFrom<&str> for PaginationSupportedOps { fn try_from(value: &str) -> Result { match value { - "resources/list" => Ok(PaginationSupportedOps::Resources), - "resources/templates/list" => Ok(PaginationSupportedOps::ResourceTemplates), - "prompts/list" => Ok(PaginationSupportedOps::Prompts), - "tools/list" => Ok(PaginationSupportedOps::Tools), + "resources/list" => Ok(PaginationSupportedOps::ResourcesList), + "resources/templates/list" => Ok(PaginationSupportedOps::ResourceTemplatesList), + "prompts/list" => Ok(PaginationSupportedOps::PromptsList), + "tools/list" => Ok(PaginationSupportedOps::ToolsList), _ => Err(OpsConversionError::InvalidMethod), } } @@ -227,3 +227,21 @@ pub struct Resource { /// Resource contents pub contents: ResourceContents, } + +/// Represents the capabilities supported by a Model Context Protocol server +/// This is the "capabilities" field in the result of a response for init +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServerCapabilities { + /// Configuration for server logging capabilities + #[serde(skip_serializing_if = "Option::is_none")] + pub logging: Option, + /// Configuration for prompt-related capabilities + #[serde(skip_serializing_if = "Option::is_none")] + pub prompts: Option, + /// Configuration for resource management capabilities + #[serde(skip_serializing_if = "Option::is_none")] + pub resources: Option, + /// Configuration for tool integration capabilities + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option, +} diff --git a/crates/chat-cli/src/mcp_client/messenger.rs b/crates/chat-cli/src/mcp_client/messenger.rs new file mode 100644 index 0000000000..caa6cf20e2 --- /dev/null +++ b/crates/chat-cli/src/mcp_client/messenger.rs @@ -0,0 +1,73 @@ +use thiserror::Error; + +use super::{ + PromptsListResult, + ResourceTemplatesListResult, + ResourcesListResult, + ToolsListResult, +}; + +/// An interface that abstracts the implementation for information delivery from client and its +/// consumer. It is through this interface secondary information (i.e. information that are needed +/// to make requests to mcp servers) are obtained passively. Consumers of client can of course +/// choose to "actively" retrieve these information via explicitly making these requests. +#[async_trait::async_trait] +pub trait Messenger: std::fmt::Debug + Send + Sync + 'static { + /// Sends the result of a tools list operation to the consumer + /// This function is used to deliver information about available tools + async fn send_tools_list_result(&self, result: ToolsListResult) -> Result<(), MessengerError>; + + /// Sends the result of a prompts list operation to the consumer + /// This function is used to deliver information about available prompts + async fn send_prompts_list_result(&self, result: PromptsListResult) -> Result<(), MessengerError>; + + /// Sends the result of a resources list operation to the consumer + /// This function is used to deliver information about available resources + async fn send_resources_list_result(&self, result: ResourcesListResult) -> Result<(), MessengerError>; + + /// Sends the result of a resource templates list operation to the consumer + /// This function is used to deliver information about available resource templates + async fn send_resource_templates_list_result( + &self, + result: ResourceTemplatesListResult, + ) -> Result<(), MessengerError>; + + /// Creates a duplicate of the messenger object + /// This function is used to create a new instance of the messenger with the same configuration + fn duplicate(&self) -> Box; +} + +#[derive(Clone, Debug, Error)] +pub enum MessengerError { + #[error("{0}")] + Custom(String), +} + +#[derive(Clone, Debug)] +pub struct NullMessenger; + +#[async_trait::async_trait] +impl Messenger for NullMessenger { + async fn send_tools_list_result(&self, _result: ToolsListResult) -> Result<(), MessengerError> { + Ok(()) + } + + async fn send_prompts_list_result(&self, _result: PromptsListResult) -> Result<(), MessengerError> { + Ok(()) + } + + async fn send_resources_list_result(&self, _result: ResourcesListResult) -> Result<(), MessengerError> { + Ok(()) + } + + async fn send_resource_templates_list_result( + &self, + _result: ResourceTemplatesListResult, + ) -> Result<(), MessengerError> { + Ok(()) + } + + fn duplicate(&self) -> Box { + Box::new(NullMessenger) + } +} diff --git a/crates/chat-cli/src/mcp_client/mod.rs b/crates/chat-cli/src/mcp_client/mod.rs index 1f0298dbb5..465dcf6cec 100644 --- a/crates/chat-cli/src/mcp_client/mod.rs +++ b/crates/chat-cli/src/mcp_client/mod.rs @@ -1,77 +1,12 @@ -#![allow(dead_code)] - -mod client; -mod facilitator_types; -mod server; -mod transport; +pub mod client; +pub mod error; +pub mod facilitator_types; +pub mod messenger; +pub mod server; +pub mod transport; pub use client::*; pub use facilitator_types::*; +pub use messenger::*; +pub use server::*; pub use transport::*; - -/// Error codes as defined in the MCP protocol. -/// -/// These error codes are based on the JSON-RPC 2.0 specification with additional -/// MCP-specific error codes in the -32000 to -32099 range. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(i32)] -pub enum McpError { - /// Invalid JSON was received by the server. - /// An error occurred on the server while parsing the JSON text. - ParseError = -32700, - - /// The JSON sent is not a valid Request object. - InvalidRequest = -32600, - - /// The method does not exist / is not available. - MethodNotFound = -32601, - - /// Invalid method parameter(s). - InvalidParams = -32602, - - /// Internal JSON-RPC error. - InternalError = -32603, - - /// Server has not been initialized. - /// This error is returned when a request is made before the server - /// has been properly initialized. - ServerNotInitialized = -32002, - - /// Unknown error code. - /// This error is returned when an error code is received that is not - /// recognized by the implementation. - UnknownErrorCode = -32001, - - /// Request failed. - /// This error is returned when a request fails for a reason not covered - /// by other error codes. - RequestFailed = -32000, -} - -impl From for McpError { - fn from(code: i32) -> Self { - match code { - -32700 => McpError::ParseError, - -32600 => McpError::InvalidRequest, - -32601 => McpError::MethodNotFound, - -32602 => McpError::InvalidParams, - -32603 => McpError::InternalError, - -32002 => McpError::ServerNotInitialized, - -32001 => McpError::UnknownErrorCode, - -32000 => McpError::RequestFailed, - _ => McpError::UnknownErrorCode, - } - } -} - -impl From for i32 { - fn from(code: McpError) -> Self { - code as i32 - } -} - -impl std::fmt::Display for McpError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) - } -} diff --git a/crates/chat-cli/src/mcp_client/server.rs b/crates/chat-cli/src/mcp_client/server.rs index d51d19125e..0b251f1ccf 100644 --- a/crates/chat-cli/src/mcp_client/server.rs +++ b/crates/chat-cli/src/mcp_client/server.rs @@ -15,24 +15,22 @@ use tokio::io::{ }; use tokio::task::JoinHandle; -use crate::mcp_client::client::StdioTransport; -use crate::mcp_client::transport::base_protocol::{ +use super::Listener as _; +use super::client::StdioTransport; +use super::error::ErrorCode; +use super::transport::base_protocol::{ JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, }; -use crate::mcp_client::transport::stdio::JsonRpcStdioTransport; -use crate::mcp_client::transport::{ +use super::transport::stdio::JsonRpcStdioTransport; +use super::transport::{ JsonRpcVersion, Transport, TransportError, }; -use crate::mcp_client::{ - Listener as _, - McpError, -}; pub type Request = serde_json::Value; pub type Response = Option; @@ -179,7 +177,7 @@ async fn process_request( jsonrpc: JsonRpcVersion::default(), id, error: Some(JsonRpcError { - code: McpError::InvalidRequest.into(), + code: ErrorCode::InvalidRequest.into(), message: "Server has already been initialized".to_owned(), data: None, }), @@ -193,7 +191,7 @@ async fn process_request( jsonrpc: JsonRpcVersion::default(), id, error: Some(JsonRpcError { - code: McpError::InvalidRequest.into(), + code: ErrorCode::InvalidRequest.into(), message: "Invalid method for initialization (use request)".to_owned(), data: None, }), @@ -218,7 +216,7 @@ async fn process_request( jsonrpc: JsonRpcVersion::default(), id, error: Some(JsonRpcError { - code: McpError::InternalError.into(), + code: ErrorCode::InternalError.into(), message: "Error producing initialization response".to_owned(), data: None, }), @@ -242,7 +240,7 @@ async fn process_request( let resp = handler.handle_incoming(method, params).await.map_or_else( |error| { let err = JsonRpcError { - code: McpError::InternalError.into(), + code: ErrorCode::InternalError.into(), message: error.to_string(), data: None, }; @@ -280,7 +278,7 @@ async fn process_request( jsonrpc: JsonRpcVersion::default(), id, error: Some(JsonRpcError { - code: McpError::ServerNotInitialized.into(), + code: ErrorCode::ServerNotInitialized.into(), message: "Server has not been initialized".to_owned(), data: None, }), diff --git a/crates/chat-cli/src/mcp_client/transport/mod.rs b/crates/chat-cli/src/mcp_client/transport/mod.rs index f86fc498f3..5796ba5323 100644 --- a/crates/chat-cli/src/mcp_client/transport/mod.rs +++ b/crates/chat-cli/src/mcp_client/transport/mod.rs @@ -4,7 +4,7 @@ pub mod stdio; use std::fmt::Debug; pub use base_protocol::*; -pub use stdio::JsonRpcStdioTransport; +pub use stdio::*; use thiserror::Error; #[derive(Clone, Debug, Error)] diff --git a/crates/chat-cli/src/mcp_client/transport/stdio.rs b/crates/chat-cli/src/mcp_client/transport/stdio.rs index 270756f2d9..2505c39ddd 100644 --- a/crates/chat-cli/src/mcp_client/transport/stdio.rs +++ b/crates/chat-cli/src/mcp_client/transport/stdio.rs @@ -204,7 +204,12 @@ mod tests { }; use tokio::process::Command; - use super::*; + use super::{ + JsonRpcMessage, + JsonRpcStdioTransport, + Listener, + Transport, + }; // Helpers for testing fn create_test_message() -> JsonRpcMessage { diff --git a/crates/chat-cli/test_mcp_server/test_server.rs b/crates/chat-cli/test_mcp_server/test_server.rs new file mode 100644 index 0000000000..6851f33922 --- /dev/null +++ b/crates/chat-cli/test_mcp_server/test_server.rs @@ -0,0 +1,351 @@ +//! This is a bin used solely for testing the client +use std::collections::HashMap; +use std::str::FromStr; +use std::sync::atomic::{ + AtomicU8, + Ordering, +}; + +use chat_cli::{ + JsonRpcRequest, + JsonRpcResponse, + JsonRpcStdioTransport, + PreServerRequestHandler, + Response, + Server, + ServerError, + ServerRequestHandler, +}; +use tokio::sync::Mutex; + +#[derive(Default)] +struct Handler { + pending_request: Option Option + Send + Sync>>, + #[allow(clippy::type_complexity)] + send_request: Option) -> Result<(), ServerError> + Send + Sync>>, + storage: Mutex>, + tool_spec: Mutex>, + tool_spec_key_list: Mutex>, + prompts: Mutex>, + prompt_key_list: Mutex>, + prompt_list_call_no: AtomicU8, +} + +impl PreServerRequestHandler for Handler { + fn register_pending_request_callback( + &mut self, + cb: impl Fn(u64) -> Option + Send + Sync + 'static, + ) { + self.pending_request = Some(Box::new(cb)); + } + + fn register_send_request_callback( + &mut self, + cb: impl Fn(&str, Option) -> Result<(), ServerError> + Send + Sync + 'static, + ) { + self.send_request = Some(Box::new(cb)); + } +} + +#[async_trait::async_trait] +impl ServerRequestHandler for Handler { + async fn handle_initialize(&self, params: Option) -> Result { + let mut storage = self.storage.lock().await; + if let Some(params) = params { + storage.insert("client_cap".to_owned(), params); + } + let capabilities = serde_json::json!({ + "protocolVersion": "2024-11-05", + "capabilities": { + "logging": {}, + "prompts": { + "listChanged": true + }, + "resources": { + "subscribe": true, + "listChanged": true + }, + "tools": { + "listChanged": true + } + }, + "serverInfo": { + "name": "TestServer", + "version": "1.0.0" + } + }); + Ok(Some(capabilities)) + } + + async fn handle_incoming(&self, method: &str, params: Option) -> Result { + match method { + "notifications/initialized" => { + { + let mut storage = self.storage.lock().await; + storage.insert( + "init_ack_sent".to_owned(), + serde_json::Value::from_str("true").expect("Failed to convert string to value"), + ); + } + Ok(None) + }, + "verify_init_params_sent" => { + let client_capabilities = { + let storage = self.storage.lock().await; + storage.get("client_cap").cloned() + }; + Ok(client_capabilities) + }, + "verify_init_ack_sent" => { + let result = { + let storage = self.storage.lock().await; + storage.get("init_ack_sent").cloned() + }; + Ok(result) + }, + "store_mock_tool_spec" => { + let Some(params) = params else { + eprintln!("Params missing from store mock tool spec"); + return Ok(None); + }; + // expecting a mock_specs: { key: String, value: serde_json::Value }[]; + let Ok(mock_specs) = serde_json::from_value::>(params) else { + eprintln!("Failed to convert to mock specs from value"); + return Ok(None); + }; + let self_tool_specs = self.tool_spec.lock().await; + let mut self_tool_spec_key_list = self.tool_spec_key_list.lock().await; + let _ = mock_specs.iter().fold(self_tool_specs, |mut acc, spec| { + let Some(key) = spec.get("key").cloned() else { + return acc; + }; + let Ok(key) = serde_json::from_value::(key) else { + eprintln!("Failed to convert serde value to string for key"); + return acc; + }; + self_tool_spec_key_list.push(key.clone()); + acc.insert(key, spec.get("value").cloned()); + acc + }); + Ok(None) + }, + "tools/list" => { + if let Some(params) = params { + if let Some(cursor) = params.get("cursor").cloned() { + let Ok(cursor) = serde_json::from_value::(cursor) else { + eprintln!("Failed to convert cursor to string: {:#?}", params); + return Ok(None); + }; + let self_tool_spec_key_list = self.tool_spec_key_list.lock().await; + let self_tool_spec = self.tool_spec.lock().await; + let (next_cursor, spec) = { + 'blk: { + for (i, item) in self_tool_spec_key_list.iter().enumerate() { + if item == &cursor { + break 'blk ( + self_tool_spec_key_list.get(i + 1).cloned(), + self_tool_spec.get(&cursor).cloned().unwrap(), + ); + } + } + (None, None) + } + }; + if let Some(next_cursor) = next_cursor { + return Ok(Some(serde_json::json!({ + "tools": [spec.unwrap()], + "nextCursor": next_cursor, + }))); + } else { + return Ok(Some(serde_json::json!({ + "tools": [spec.unwrap()], + }))); + } + } else { + eprintln!("Params exist but cursor is missing"); + return Ok(None); + } + } else { + let first_key = self + .tool_spec_key_list + .lock() + .await + .first() + .expect("First key missing from tool specs") + .clone(); + let first_value = self + .tool_spec + .lock() + .await + .get(&first_key) + .expect("First value missing from tool specs") + .clone(); + let second_key = self + .tool_spec_key_list + .lock() + .await + .get(1) + .expect("Second key missing from tool specs") + .clone(); + return Ok(Some(serde_json::json!({ + "tools": [first_value], + "nextCursor": second_key + }))); + }; + }, + "get_env_vars" => { + let kv = std::env::vars().fold(HashMap::::new(), |mut acc, (k, v)| { + acc.insert(k, v); + acc + }); + Ok(Some(serde_json::json!(kv))) + }, + // This is a test path relevant only to sampling + "trigger_server_request" => { + let Some(ref send_request) = self.send_request else { + return Err(ServerError::MissingMethod); + }; + let params = Some(serde_json::json!({ + "messages": [ + { + "role": "user", + "content": { + "type": "text", + "text": "What is the capital of France?" + } + } + ], + "modelPreferences": { + "hints": [ + { + "name": "claude-3-sonnet" + } + ], + "intelligencePriority": 0.8, + "speedPriority": 0.5 + }, + "systemPrompt": "You are a helpful assistant.", + "maxTokens": 100 + })); + send_request("sampling/createMessage", params)?; + Ok(None) + }, + "store_mock_prompts" => { + let Some(params) = params else { + eprintln!("Params missing from store mock prompts"); + return Ok(None); + }; + // expecting a mock_prompts: { key: String, value: serde_json::Value }[]; + let Ok(mock_prompts) = serde_json::from_value::>(params) else { + eprintln!("Failed to convert to mock specs from value"); + return Ok(None); + }; + let self_prompts = self.prompts.lock().await; + let mut self_prompt_key_list = self.prompt_key_list.lock().await; + let _ = mock_prompts.iter().fold(self_prompts, |mut acc, spec| { + let Some(key) = spec.get("key").cloned() else { + return acc; + }; + let Ok(key) = serde_json::from_value::(key) else { + eprintln!("Failed to convert serde value to string for key"); + return acc; + }; + self_prompt_key_list.push(key.clone()); + acc.insert(key, spec.get("value").cloned()); + acc + }); + Ok(None) + }, + "prompts/list" => { + self.prompt_list_call_no.fetch_add(1, Ordering::Relaxed); + if let Some(params) = params { + if let Some(cursor) = params.get("cursor").cloned() { + let Ok(cursor) = serde_json::from_value::(cursor) else { + eprintln!("Failed to convert cursor to string: {:#?}", params); + return Ok(None); + }; + let self_prompt_key_list = self.prompt_key_list.lock().await; + let self_prompts = self.prompts.lock().await; + let (next_cursor, spec) = { + 'blk: { + for (i, item) in self_prompt_key_list.iter().enumerate() { + if item == &cursor { + break 'blk ( + self_prompt_key_list.get(i + 1).cloned(), + self_prompts.get(&cursor).cloned().unwrap(), + ); + } + } + (None, None) + } + }; + if let Some(next_cursor) = next_cursor { + return Ok(Some(serde_json::json!({ + "prompts": [spec.unwrap()], + "nextCursor": next_cursor, + }))); + } else { + return Ok(Some(serde_json::json!({ + "prompts": [spec.unwrap()], + }))); + } + } else { + eprintln!("Params exist but cursor is missing"); + return Ok(None); + } + } else { + let first_key = self + .prompt_key_list + .lock() + .await + .first() + .expect("First key missing from prompts") + .clone(); + let first_value = self + .prompts + .lock() + .await + .get(&first_key) + .expect("First value missing from prompts") + .clone(); + let second_key = self + .prompt_key_list + .lock() + .await + .get(1) + .expect("Second key missing from prompts") + .clone(); + return Ok(Some(serde_json::json!({ + "prompts": [first_value], + "nextCursor": second_key + }))); + }; + }, + "get_prompt_list_call_no" => Ok(Some( + serde_json::to_value::(self.prompt_list_call_no.load(Ordering::Relaxed)) + .expect("Failed to convert list call no to u8"), + )), + _ => Err(ServerError::MissingMethod), + } + } + + // This is a test path relevant only to sampling + async fn handle_response(&self, resp: JsonRpcResponse) -> Result<(), ServerError> { + let JsonRpcResponse { id, .. } = resp; + let _pending = self.pending_request.as_ref().and_then(|f| f(id)); + Ok(()) + } + + async fn handle_shutdown(&self) -> Result<(), ServerError> { + Ok(()) + } +} + +#[tokio::main] +async fn main() { + let handler = Handler::default(); + let stdin = tokio::io::stdin(); + let stdout = tokio::io::stdout(); + let test_server = Server::::new(handler, stdin, stdout).expect("Failed to create server"); + let _ = test_server.init().expect("Test server failed to init").await; +} diff --git a/crates/q_chat/Cargo.toml b/crates/q_chat/Cargo.toml deleted file mode 100644 index 8dba7b3117..0000000000 --- a/crates/q_chat/Cargo.toml +++ /dev/null @@ -1,58 +0,0 @@ -[package] -name = "q_chat" -authors.workspace = true -edition.workspace = true -homepage.workspace = true -publish.workspace = true -version.workspace = true -license.workspace = true - -[dependencies] -async-trait.workspace = true -anstream.workspace = true -aws-smithy-types = "1.2.10" -bstr.workspace = true -clap.workspace = true -color-print.workspace = true -convert_case.workspace = true -crossterm.workspace = true -eyre.workspace = true -fig_api_client.workspace = true -fig_auth.workspace = true -fig_diagnostic.workspace = true -fig_os_shim.workspace = true -fig_settings.workspace = true -fig_telemetry.workspace = true -fig_util.workspace = true -futures.workspace = true -glob.workspace = true -mcp_client.workspace = true -rand.workspace = true -regex.workspace = true -rustyline = { version = "15.0.0", features = ["derive", "custom-bindings"] } -serde.workspace = true -serde_json.workspace = true -shell-color.workspace = true -shell-words = "1.1" -shellexpand.workspace = true -shlex.workspace = true -similar.workspace = true -skim = "0.16.2" -spinners.workspace = true -syntect = { version = "5.2.0", features = [ "default-syntaxes", "default-themes" ]} -tempfile.workspace = true -thiserror.workspace = true -time.workspace = true -tokio.workspace = true -tracing.workspace = true -unicode-width.workspace = true -url.workspace = true -uuid.workspace = true -winnow.workspace = true -strip-ansi-escapes = "0.2.1" - -[dev-dependencies] -tracing-subscriber.workspace = true - -[lints] -workspace = true diff --git a/crates/q_chat/src/tool_manager/tool_manager.rs b/crates/q_chat/src/tool_manager/tool_manager.rs deleted file mode 100644 index faea281d31..0000000000 --- a/crates/q_chat/src/tool_manager/tool_manager.rs +++ /dev/null @@ -1,1332 +0,0 @@ -use std::collections::HashMap; -use std::hash::{ - DefaultHasher, - Hasher, -}; -use std::io::Write; -use std::path::PathBuf; -use std::sync::atomic::AtomicBool; -use std::sync::mpsc::RecvTimeoutError; -use std::sync::{ - Arc, - RwLock as SyncRwLock, -}; - -use convert_case::Casing; -use crossterm::{ - cursor, - execute, - queue, - style, - terminal, -}; -use futures::{ - StreamExt, - stream, -}; -<<<<<<<< HEAD:crates/chat-cli/src/cli/chat/tool_manager.rs -======== -use mcp_client::{ - JsonRpcResponse, - PromptGet, -}; -use regex::Regex; ->>>>>>>> ca627e83 (loads tools in the background):crates/q_chat/src/tool_manager/tool_manager.rs -use serde::{ - Deserialize, - Serialize, -}; -use thiserror::Error; -use tokio::sync::Mutex; -use tracing::{ - error, - warn, -}; - -<<<<<<<< HEAD:crates/chat-cli/src/cli/chat/tool_manager.rs -use super::command::PromptsGetCommand; -use super::message::AssistantToolUse; -use super::tools::custom_tool::{ -======== -use crate::command::PromptsGetCommand; -use crate::message::AssistantToolUse; -use crate::tool_manager::server_messenger::{ - ServerMessengerBuilder, - UpdateEventMessage, -}; -use crate::tools::custom_tool::{ ->>>>>>>> ca627e83 (loads tools in the background):crates/q_chat/src/tool_manager/tool_manager.rs - CustomTool, - CustomToolClient, - CustomToolConfig, -}; -<<<<<<<< HEAD:crates/chat-cli/src/cli/chat/tool_manager.rs -use super::tools::execute_bash::ExecuteBash; -use super::tools::fs_read::FsRead; -use super::tools::fs_write::FsWrite; -use super::tools::gh_issue::GhIssue; -use super::tools::thinking::Thinking; -use super::tools::use_aws::UseAws; -use super::tools::{ -======== -use crate::tools::execute_bash::ExecuteBash; -use crate::tools::fs_read::FsRead; -use crate::tools::fs_write::FsWrite; -use crate::tools::gh_issue::GhIssue; -use crate::tools::use_aws::UseAws; -use crate::tools::{ ->>>>>>>> ca627e83 (loads tools in the background):crates/q_chat/src/tool_manager/tool_manager.rs - Tool, - ToolOrigin, - ToolSpec, -}; -<<<<<<<< HEAD:crates/chat-cli/src/cli/chat/tool_manager.rs -use crate::api_client::model::{ - ToolResult, - ToolResultContentBlock, - ToolResultStatus, -}; -use crate::mcp_client::{ - JsonRpcResponse, - NullMessenger, - PromptGet, -}; -use crate::telemetry::send_mcp_server_init; -======== ->>>>>>>> ca627e83 (loads tools in the background):crates/q_chat/src/tool_manager/tool_manager.rs - -const NAMESPACE_DELIMITER: &str = "___"; -// This applies for both mcp server and tool name since in the end the tool name as seen by the -// model is just {server_name}{NAMESPACE_DELIMITER}{tool_name} -const VALID_TOOL_NAME: &str = "^[a-zA-Z][a-zA-Z0-9_]*$"; -const SPINNER_CHARS: [char; 10] = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']; - -#[derive(Debug, Error)] -pub enum GetPromptError { - #[error("Prompt with name {0} does not exist")] - PromptNotFound(String), - #[error("Prompt {0} is offered by more than one server. Use one of the following {1}")] - AmbiguousPrompt(String, String), - #[error("Missing client")] - MissingClient, - #[error("Missing prompt name")] - MissingPromptName, - #[error("Synchronization error: {0}")] - Synchronization(String), - #[error("Missing prompt bundle")] - MissingPromptInfo, - #[error(transparent)] - General(#[from] eyre::Report), -} - -/// Messages used for communication between the tool initialization thread and the loading -/// display thread. These messages control the visual loading indicators shown to -/// the user during tool initialization. -pub enum LoadingMsg { - /// Indicates a new tool is being initialized and should be added to the loading - /// display. The String parameter is the name of the tool being initialized. - Add(String), - /// Indicates a tool has finished initializing successfully and should be removed from - /// the loading display. The String parameter is the name of the tool that - /// completed initialization. - Done(String), - /// Represents an error that occurred during tool initialization. - /// Contains the name of the server that failed to initialize and the error message. - Error { name: String, msg: eyre::Report }, - /// Represents a warning that occurred during tool initialization. - /// Contains the name of the server that generated the warning and the warning message. - Warn { name: String, msg: eyre::Report }, - /// Signals that the loading display thread should terminate. - /// This is sent when all tool initialization is complete or when the application is shutting - /// down. - Terminate, -} - -/// Represents the state of a loading indicator for a tool being initialized. -/// -/// This struct tracks timing information for each tool's loading status display in the terminal. -/// -/// # Fields -/// * `init_time` - When initialization for this tool began, used to calculate load time -struct StatusLine { - init_time: std::time::Instant, -} - -// This is to mirror claude's config set up -#[derive(Clone, Serialize, Deserialize, Debug, Default)] -#[serde(rename_all = "camelCase")] -pub struct McpServerConfig { - mcp_servers: HashMap, -} - -impl McpServerConfig { - pub async fn load_config(output: &mut impl Write) -> eyre::Result { - let mut cwd = std::env::current_dir()?; - cwd.push(".amazonq/mcp.json"); - let expanded_path = shellexpand::tilde("~/.aws/amazonq/mcp.json"); - let global_path = PathBuf::from(expanded_path.as_ref()); - let global_buf = tokio::fs::read(global_path).await.ok(); - let local_buf = tokio::fs::read(cwd).await.ok(); - let conf = match (global_buf, local_buf) { - (Some(global_buf), Some(local_buf)) => { - let mut global_conf = Self::from_slice(&global_buf, output, "global")?; - let local_conf = Self::from_slice(&local_buf, output, "local")?; - for (server_name, config) in local_conf.mcp_servers { - if global_conf.mcp_servers.insert(server_name.clone(), config).is_some() { - queue!( - output, - style::SetForegroundColor(style::Color::Yellow), - style::Print("WARNING: "), - style::ResetColor, - style::Print("MCP config conflict for "), - style::SetForegroundColor(style::Color::Green), - style::Print(server_name), - style::ResetColor, - style::Print(". Using workspace version.\n") - )?; - } - } - global_conf - }, - (None, Some(local_buf)) => Self::from_slice(&local_buf, output, "local")?, - (Some(global_buf), None) => Self::from_slice(&global_buf, output, "global")?, - _ => Default::default(), - }; - output.flush()?; - Ok(conf) - } - - fn from_slice(slice: &[u8], output: &mut impl Write, location: &str) -> eyre::Result { - match serde_json::from_slice::(slice) { - Ok(config) => Ok(config), - Err(e) => { - queue!( - output, - style::SetForegroundColor(style::Color::Yellow), - style::Print("WARNING: "), - style::ResetColor, - style::Print(format!("Error reading {location} mcp config: {e}\n")), - style::Print("Please check to make sure config is correct. Discarding.\n"), - )?; - Ok(McpServerConfig::default()) - }, - } - } -} - -#[derive(Default)] -pub struct ToolManagerBuilder { - mcp_server_config: Option, - prompt_list_sender: Option>>, - prompt_list_receiver: Option>>, - conversation_id: Option, -} - -impl ToolManagerBuilder { - pub fn mcp_server_config(mut self, config: McpServerConfig) -> Self { - self.mcp_server_config.replace(config); - self - } - - pub fn prompt_list_sender(mut self, sender: std::sync::mpsc::Sender>) -> Self { - self.prompt_list_sender.replace(sender); - self - } - - pub fn prompt_list_receiver(mut self, receiver: std::sync::mpsc::Receiver>) -> Self { - self.prompt_list_receiver.replace(receiver); - self - } - - pub fn conversation_id(mut self, conversation_id: &str) -> Self { - self.conversation_id.replace(conversation_id.to_string()); - self - } - - pub async fn build(mut self) -> eyre::Result { - let McpServerConfig { mcp_servers } = self.mcp_server_config.ok_or(eyre::eyre!("Missing mcp server config"))?; - debug_assert!(self.conversation_id.is_some()); - let conversation_id = self.conversation_id.ok_or(eyre::eyre!("Missing conversation id"))?; - let regex = regex::Regex::new(VALID_TOOL_NAME)?; - let mut hasher = DefaultHasher::new(); - let pre_initialized = mcp_servers - .into_iter() - .map(|(server_name, server_config)| { - let snaked_cased_name = server_name.to_case(convert_case::Case::Snake); - let sanitized_server_name = sanitize_name(snaked_cased_name, ®ex, &mut hasher); - let custom_tool_client = CustomToolClient::from_config(sanitized_server_name.clone(), server_config); - (sanitized_server_name, custom_tool_client) - }) - .collect::>(); - - // Send up task to update user on server loading status - let (tx, rx) = std::sync::mpsc::channel::(); - // Using a hand rolled thread because it's just easier to do this than do deal with the Send - // requirements that comes with holding onto the stdout lock. - let loading_display_task = tokio::task::spawn_blocking(move || { - let stdout = std::io::stdout(); - let mut stdout_lock = stdout.lock(); - let mut loading_servers = HashMap::::new(); - let mut spinner_logo_idx: usize = 0; - let mut complete: usize = 0; - let mut failed: usize = 0; - loop { - match rx.recv_timeout(std::time::Duration::from_millis(50)) { - Ok(recv_result) => match recv_result { - LoadingMsg::Add(name) => { - let init_time = std::time::Instant::now(); - let status_line = StatusLine { init_time }; - execute!(stdout_lock, cursor::MoveToColumn(0))?; - if !loading_servers.is_empty() { - // TODO: account for terminal width - execute!(stdout_lock, cursor::MoveUp(1))?; - } - loading_servers.insert(name.clone(), status_line); - let total = loading_servers.len(); - execute!(stdout_lock, terminal::Clear(terminal::ClearType::CurrentLine))?; - queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; - stdout_lock.flush()?; - }, - LoadingMsg::Done(name) => { - if let Some(status_line) = loading_servers.get(&name) { - complete += 1; - let time_taken = - (std::time::Instant::now() - status_line.init_time).as_secs_f64().abs(); - let time_taken = format!("{:.2}", time_taken); - execute!( - stdout_lock, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - queue_success_message(&name, &time_taken, &mut stdout_lock)?; - let total = loading_servers.len(); - queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; - stdout_lock.flush()?; - } - }, - LoadingMsg::Error { name, msg } => { - failed += 1; - execute!( - stdout_lock, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - queue_failure_message(&name, &msg, &mut stdout_lock)?; - let total = loading_servers.len(); - queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; - }, - LoadingMsg::Warn { name, msg } => { - complete += 1; - execute!( - stdout_lock, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - let msg = eyre::eyre!(msg.to_string()); - queue_warn_message(&name, &msg, &mut stdout_lock)?; - let total = loading_servers.len(); - queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; - stdout_lock.flush()?; - }, - LoadingMsg::Terminate => { - if !loading_servers.is_empty() { - let msg = loading_servers.iter().fold(String::new(), |mut acc, (server_name, _)| { - acc.push_str(format!("\n - {server_name}").as_str()); - acc - }); - let msg = eyre::eyre!(msg); - queue_incomplete_load_message(&msg, &mut stdout_lock)?; - } - break; - }, - }, - Err(RecvTimeoutError::Timeout) => { - spinner_logo_idx = (spinner_logo_idx + 1) % SPINNER_CHARS.len(); - execute!( - stdout_lock, - cursor::SavePosition, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - style::Print(SPINNER_CHARS[spinner_logo_idx]), - cursor::RestorePosition - )?; - }, - _ => break, - } - } - Ok::<_, eyre::Report>(()) - }); - let mut clients = HashMap::>::new(); - let load_msg_sender = tx.clone(); - let conv_id_clone = conversation_id.clone(); - let regex = Arc::new(Regex::new(VALID_TOOL_NAME)?); - let (mut msg_rx, messenger_builder) = ServerMessengerBuilder::new(20); - tokio::spawn(async move { - let mut is_in_display = true; - while let Some(msg) = msg_rx.recv().await { - // For now we will treat every list result as if they contain the - // complete set of tools. This is not necessarily true in the future when - // request method on the mcp client no longer buffers all the pages from - // list calls. - match msg { - UpdateEventMessage::ToolsListResult { server_name, result } => { - error!("## background: from {server_name}: {:?}", result); - let mut specs = result - .tools - .into_iter() - .filter_map(|v| serde_json::from_value::(v).ok()) - .collect::>(); - let mut sanitized_mapping = HashMap::::new(); - if let Some(load_msg) = process_tool_specs( - conv_id_clone.as_str(), - &server_name, - is_in_display, - &mut specs, - &mut sanitized_mapping, - ®ex, - ) { - if let Err(e) = load_msg_sender.send(load_msg) { - warn!( - "Error sending update message to display task: {:?}\nAssume display task has completed", - e - ); - is_in_display = false; - } - } - }, - UpdateEventMessage::PromptsListResult { server_name, result } => {}, - UpdateEventMessage::ResourcesListResult { server_name, result } => {}, - UpdateEventMessage::ResouceTemplatesListResult { server_name, result } => {}, - UpdateEventMessage::DisplayTaskEnded => { - is_in_display = false; - }, - } - } - }); - for (mut name, init_res) in pre_initialized { - let _ = tx.send(LoadingMsg::Add(name.clone())); - match init_res { - Ok(mut client) => { - let messenger = messenger_builder.build_with_name(client.get_server_name().to_owned()); - client.assign_messenger(Box::new(messenger)); - let mut client = Arc::new(client); - while let Some(collided_client) = clients.insert(name.clone(), client) { - // to avoid server name collision we are going to circumvent this by - // appending the name with 1 - name.push('1'); - client = collided_client; - } - }, - Err(e) => { - error!("Error initializing mcp client for server {}: {:?}", name, &e); - send_mcp_server_init(conversation_id.clone(), Some(e.to_string()), 0).await; - - let _ = tx.send(LoadingMsg::Error { - name: name.clone(), - msg: e, - }); - }, - } - } - let loading_display_task = Some(loading_display_task); - let loading_status_sender = Some(tx); - - // Set up task to handle prompt requests - let sender = self.prompt_list_sender.take(); - let receiver = self.prompt_list_receiver.take(); - let prompts = Arc::new(SyncRwLock::new(HashMap::default())); - // TODO: accommodate hot reload of mcp servers - if let (Some(sender), Some(receiver)) = (sender, receiver) { - let clients = clients.iter().fold(HashMap::new(), |mut acc, (n, c)| { - acc.insert(n.to_string(), Arc::downgrade(c)); - acc - }); - let prompts_clone = prompts.clone(); - tokio::task::spawn_blocking(move || { - let receiver = Arc::new(std::sync::Mutex::new(receiver)); - loop { - let search_word = receiver.lock().map_err(|e| eyre::eyre!("{:?}", e))?.recv()?; - if clients - .values() - .any(|client| client.upgrade().is_some_and(|c| c.is_prompts_out_of_date())) - { - let mut prompts_wl = prompts_clone.write().map_err(|e| { - eyre::eyre!( - "Error retrieving write lock on prompts for tab complete {}", - e.to_string() - ) - })?; - *prompts_wl = clients.iter().fold( - HashMap::>::new(), - |mut acc, (server_name, client)| { - let Some(client) = client.upgrade() else { - return acc; - }; - let prompt_gets = client.list_prompt_gets(); - let Ok(prompt_gets) = prompt_gets.read() else { - tracing::error!("Error retrieving read lock for prompt gets for tab complete"); - return acc; - }; - for (prompt_name, prompt_get) in prompt_gets.iter() { - acc.entry(prompt_name.to_string()) - .and_modify(|bundles| { - bundles.push(PromptBundle { - server_name: server_name.to_owned(), - prompt_get: prompt_get.clone(), - }); - }) - .or_insert(vec![PromptBundle { - server_name: server_name.to_owned(), - prompt_get: prompt_get.clone(), - }]); - } - client.prompts_updated(); - acc - }, - ); - } - let prompts_rl = prompts_clone.read().map_err(|e| { - eyre::eyre!( - "Error retrieving read lock on prompts for tab complete {}", - e.to_string() - ) - })?; - let filtered_prompts = prompts_rl - .iter() - .flat_map(|(prompt_name, bundles)| { - if bundles.len() > 1 { - bundles - .iter() - .map(|b| format!("{}/{}", b.server_name, prompt_name)) - .collect() - } else { - vec![prompt_name.to_owned()] - } - }) - .filter(|n| { - if let Some(p) = &search_word { - n.contains(p) - } else { - true - } - }) - .collect::>(); - if let Err(e) = sender.send(filtered_prompts) { - error!("Error sending prompts to chat helper: {:?}", e); - } - } - #[allow(unreachable_code)] - Ok::<(), eyre::Report>(()) - }); - } - - Ok(ToolManager { - conversation_id, - clients, - prompts, - loading_display_task, - loading_status_sender, - ..Default::default() - }) - } -} - -#[derive(Clone, Debug)] -/// A collection of information that is used for the following purposes: -/// - Checking if prompt info cached is out of date -/// - Retrieve new prompt info -pub struct PromptBundle { - /// The server name from which the prompt is offered / exposed - pub server_name: String, - /// The prompt get (info with which a prompt is retrieved) cached - pub prompt_get: PromptGet, -} - -/// Categorizes different types of tool name validation failures: -/// - `TooLong`: The tool name exceeds the maximum allowed length -/// - `IllegalChar`: The tool name contains characters that are not allowed -/// - `EmptyDescription`: The tool description is empty or missing -#[allow(dead_code)] -enum OutOfSpecName { - TooLong(String), - IllegalChar(String), - EmptyDescription(String), -} - -type NewToolSpecs = Arc, Vec)>>>; - -#[derive(Default)] -/// Manages the lifecycle and interactions with tools from various sources, including MCP servers. -/// This struct is responsible for initializing tools, handling tool requests, and maintaining -/// a cache of available prompts from connected servers. -pub struct ToolManager { - /// Unique identifier for the current conversation. - /// This ID is used to track and associate tools with a specific chat session. - pub conversation_id: String, - - /// Map of server names to their corresponding client instances. - /// These clients are used to communicate with MCP servers. - pub clients: HashMap>, - - pub has_new_stuff: Arc, - - new_tool_specs: NewToolSpecs, - - /// Cache for prompts collected from different servers. - /// Key: prompt name - /// Value: a list of PromptBundle that has a prompt of this name. - /// This cache helps resolve prompt requests efficiently and handles - /// cases where multiple servers offer prompts with the same name. - pub prompts: Arc>>>, - - /// Handle to the thread that displays loading status for tool initialization. - /// This thread provides visual feedback to users during the tool loading process. - loading_display_task: Option>>, - - /// Channel sender for communicating with the loading display thread. - /// Used to send status updates about tool initialization progress. - loading_status_sender: Option>, - - /// Mapping from sanitized tool names to original tool names. - /// This is used to handle tool name transformations that may occur during initialization - /// to ensure tool names comply with naming requirements. - pub tn_map: HashMap, - - /// A cache of tool's input schema for all of the available tools. - /// This is mainly used to show the user what the tools look like from the perspective of the - /// model. - pub schema: HashMap, -} - -impl ToolManager { - pub async fn load_tools(&mut self) -> eyre::Result> { - let tx = self.loading_status_sender.take(); - let display_task = self.loading_display_task.take(); - let tool_specs = { -<<<<<<<< HEAD:crates/chat-cli/src/cli/chat/tool_manager.rs - let mut tool_specs = - serde_json::from_str::>(include_str!("tools/tool_index.json"))?; - if !crate::cli::chat::tools::thinking::Thinking::is_enabled() { - tool_specs.remove("q_think_tool"); - } -======== - let tool_specs = - serde_json::from_str::>(include_str!("../tools/tool_index.json"))?; ->>>>>>>> ca627e83 (loads tools in the background):crates/q_chat/src/tool_manager/tool_manager.rs - Arc::new(Mutex::new(tool_specs)) - }; - let conversation_id = self.conversation_id.clone(); - let regex = Arc::new(regex::Regex::new(VALID_TOOL_NAME)?); - self.new_tool_specs = Arc::new(Mutex::new(HashMap::new())); - let load_tools = self - .clients -<<<<<<<< HEAD:crates/chat-cli/src/cli/chat/tool_manager.rs - .iter() - .map(|(server_name, client)| { - let client_clone = client.clone(); - let server_name_clone = server_name.clone(); - let tx_clone = tx.clone(); - let regex_clone = regex.clone(); - let tool_specs_clone = tool_specs.clone(); - let conversation_id = conversation_id.clone(); - async move { - let tool_spec = client_clone.init().await; - let mut sanitized_mapping = HashMap::::new(); - match tool_spec { - Ok((server_name, specs)) => { - // Each mcp server might have multiple tools. - // To avoid naming conflicts we are going to namespace it. - // This would also help us locate which mcp server to call the tool from. - let mut out_of_spec_tool_names = Vec::::new(); - let mut hasher = DefaultHasher::new(); - let number_of_tools = specs.len(); - // Sanitize tool names to ensure they comply with the naming requirements: - // 1. If the name already matches the regex pattern and doesn't contain the namespace delimiter, use it as is - // 2. Otherwise, remove invalid characters and handle special cases: - // - Remove namespace delimiters - // - Ensure the name starts with an alphabetic character - // - Generate a hash-based name if the sanitized result is empty - // This ensures all tool names are valid identifiers that can be safely used in the system - // If after all of the aforementioned modification the combined tool - // name we have exceeds a length of 64, we surface it as an error - for mut spec in specs { - let sn = if !regex_clone.is_match(&spec.name) { - let mut sn = sanitize_name(spec.name.clone(), ®ex_clone, &mut hasher); - while sanitized_mapping.contains_key(&sn) { - sn.push('1'); - } - sn - } else { - spec.name.clone() - }; - let full_name = format!("{}{}{}", server_name, NAMESPACE_DELIMITER, sn); - if full_name.len() > 64 { - out_of_spec_tool_names.push(OutOfSpecName::TooLong(spec.name)); - continue; - } else if spec.description.is_empty() { - out_of_spec_tool_names.push(OutOfSpecName::EmptyDescription(spec.name)); - continue; - } - if sn != spec.name { - sanitized_mapping.insert(full_name.clone(), format!("{}{}{}", server_name, NAMESPACE_DELIMITER, spec.name)); - } - spec.name = full_name; - spec.tool_origin = ToolOrigin::McpServer(server_name.clone()); - tool_specs_clone.lock().await.insert(spec.name.clone(), spec); - } - - // Send server load success metric datum - send_mcp_server_init(conversation_id, None, number_of_tools).await; - - // Tool name translation. This is beyond of the scope of what is - // considered a "server load". Reasoning being: - // - Failures here are not related to server load - // - There is not a whole lot we can do with this data - if let Some(tx_clone) = &tx_clone { - let send_result = if !out_of_spec_tool_names.is_empty() { - let msg = out_of_spec_tool_names.iter().fold( - String::from("The following tools are out of spec. They will be excluded from the list of available tools:\n"), - |mut acc, name| { - let (tool_name, msg) = match name { - OutOfSpecName::TooLong(tool_name) => (tool_name.as_str(), "tool name exceeds max length of 64 when combined with server name"), - OutOfSpecName::IllegalChar(tool_name) => (tool_name.as_str(), "tool name must be compliant with ^[a-zA-Z][a-zA-Z0-9_]*$"), - OutOfSpecName::EmptyDescription(tool_name) => (tool_name.as_str(), "tool schema contains empty description"), - }; - acc.push_str(format!(" - {} ({})\n", tool_name, msg).as_str()); - acc - } - ); - tx_clone.send(LoadingMsg::Error { - name: server_name.clone(), - msg: eyre::eyre!(msg), - }) - // TODO: if no tools are valid, we need to offload the server - // from the fleet (i.e. kill the server) - } else if !sanitized_mapping.is_empty() { - let warn = sanitized_mapping.iter().fold(String::from("The following tool names are changed:\n"), |mut acc, (k, v)| { - acc.push_str(format!(" - {} -> {}\n", v, k).as_str()); - acc - }); - tx_clone.send(LoadingMsg::Warn { - name: server_name.clone(), - msg: eyre::eyre!(warn), - }) - } else { - tx_clone.send(LoadingMsg::Done(server_name.clone())) - }; - if let Err(e) = send_result { - error!("Error while sending status update to display task: {:?}", e); - } - } - }, - Err(e) => { - error!("Error obtaining tool spec for {}: {:?}", server_name_clone, e); - let init_failure_reason = Some(e.to_string()); - send_mcp_server_init(conversation_id, init_failure_reason, 0).await; - if let Some(tx_clone) = &tx_clone { - if let Err(e) = tx_clone.send(LoadingMsg::Error { - name: server_name_clone, - msg: e, - }) { - error!("Error while sending status update to display task: {:?}", e); - } - } - }, - } - Ok::<_, eyre::Report>(Some(sanitized_mapping)) - } -======== - .values() - .map(|c| { - let clone = Arc::clone(c); - async move { clone.init().await } ->>>>>>>> ca627e83 (loads tools in the background):crates/q_chat/src/tool_manager/tool_manager.rs - }) - .collect::>(); - let some = stream::iter(load_tools) - .map(|async_closure| tokio::spawn(async_closure)) - .buffer_unordered(20) - .collect::>() - .await; - // let load_tool = self - // .clients - // .iter() - // .map(|(server_name, client)| { - // let client_clone = client.clone(); - // let server_name_clone = server_name.clone(); - // let tx_clone = tx.clone(); - // let regex_clone = regex.clone(); - // let tool_specs_clone = tool_specs.clone(); - // let conversation_id = conversation_id.clone(); - // async move { - // let tool_spec = client_clone.init().await; - // let mut sanitized_mapping = HashMap::::new(); - // match tool_spec { - // Ok((server_name, mut specs)) => { - // let msg = process_tool_specs( - // conversation_id.as_str(), - // &server_name, - // true, - // &mut specs, - // &mut sanitized_mapping, - // ®ex_clone, - // ); - // for spec in specs { - // tool_specs_clone.lock().await.insert(spec.name.clone(), spec); - // } - // if let (Some(msg), Some(tx)) = (msg, &tx_clone) { - // let _ = tx.send(msg); - // } - // }, - // Err(e) => { - // error!("Error obtaining tool spec for {}: {:?}", server_name_clone, e); - // let init_failure_reason = Some(e.to_string()); - // tokio::spawn(async move { - // let event = fig_telemetry::EventType::McpServerInit { - // conversation_id, - // init_failure_reason, - // number_of_tools: 0, - // }; - // let app_event = fig_telemetry::AppTelemetryEvent::new(event).await; - // fig_telemetry::dispatch_or_send_event(app_event).await; - // }); - // if let Some(tx_clone) = &tx_clone { - // if let Err(e) = tx_clone.send(LoadingMsg::Error { - // name: server_name_clone, - // msg: e, - // }) { - // error!("Error while sending status update to display task: {:?}", e); - // } - // } - // }, - // } - // Ok::<_, eyre::Report>(Some(sanitized_mapping)) - // } - // }) - // .collect::>(); - // // TODO: do we want to introduce a timeout here? - // self.tn_map = stream::iter(load_tool) - // .map(|async_closure| tokio::task::spawn(async_closure)) - // .buffer_unordered(20) - // .collect::>() - // .await - // .into_iter() - // .filter_map(|r| r.ok()) - // .filter_map(|r| r.ok()) - // .flatten() - // .flatten() - // .collect::>(); - drop(tx); - if let Some(display_task) = display_task { - if let Err(e) = display_task.await { - error!("Error while joining status display task: {:?}", e); - } - } - let tool_specs = { - let mutex = - Arc::try_unwrap(tool_specs).map_err(|e| eyre::eyre!("Error unwrapping arc for tool specs {:?}", e))?; - mutex.into_inner() - }; - // caching the tool names for skim operations - for tool_name in tool_specs.keys() { - if !self.tn_map.contains_key(tool_name) { - self.tn_map.insert(tool_name.clone(), tool_name.clone()); - } - } - self.schema = tool_specs.clone(); - Ok(tool_specs) - } - - pub fn get_tool_from_tool_use(&self, value: AssistantToolUse) -> Result { - let map_err = |parse_error| ToolResult { - tool_use_id: value.id.clone(), - content: vec![ToolResultContentBlock::Text(format!( - "Failed to validate tool parameters: {parse_error}. The model has either suggested tool parameters which are incompatible with the existing tools, or has suggested one or more tool that does not exist in the list of known tools." - ))], - status: ToolResultStatus::Error, - }; - - Ok(match value.name.as_str() { - "fs_read" => Tool::FsRead(serde_json::from_value::(value.args).map_err(map_err)?), - "fs_write" => Tool::FsWrite(serde_json::from_value::(value.args).map_err(map_err)?), - "execute_bash" => Tool::ExecuteBash(serde_json::from_value::(value.args).map_err(map_err)?), - "use_aws" => Tool::UseAws(serde_json::from_value::(value.args).map_err(map_err)?), - "report_issue" => Tool::GhIssue(serde_json::from_value::(value.args).map_err(map_err)?), - "q_think_tool" => Tool::Thinking(serde_json::from_value::(value.args).map_err(map_err)?), - // Note that this name is namespaced with server_name{DELIMITER}tool_name - name => { - let name = self.tn_map.get(name).map_or(name, String::as_str); - let (server_name, tool_name) = name.split_once(NAMESPACE_DELIMITER).ok_or(ToolResult { - tool_use_id: value.id.clone(), - content: vec![ToolResultContentBlock::Text(format!( - "The tool, \"{name}\" is supplied with incorrect name" - ))], - status: ToolResultStatus::Error, - })?; - let Some(client) = self.clients.get(server_name) else { - return Err(ToolResult { - tool_use_id: value.id, - content: vec![ToolResultContentBlock::Text(format!( - "The tool, \"{server_name}\" is not supported by the client" - ))], - status: ToolResultStatus::Error, - }); - }; - // The tool input schema has the shape of { type, properties }. - // The field "params" expected by MCP is { name, arguments }, where name is the - // name of the tool being invoked, - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#calling-tools. - // The field "arguments" is where ToolUse::args belong. - let mut params = serde_json::Map::::new(); - params.insert("name".to_owned(), serde_json::Value::String(tool_name.to_owned())); - params.insert("arguments".to_owned(), value.args); - let params = serde_json::Value::Object(params); - let custom_tool = CustomTool { - name: tool_name.to_owned(), - client: client.clone(), - method: "tools/call".to_owned(), - params: Some(params), - }; - Tool::Custom(custom_tool) - }, - }) - } - - #[allow(clippy::await_holding_lock)] - pub async fn get_prompt(&self, get_command: PromptsGetCommand) -> Result { - let (server_name, prompt_name) = match get_command.params.name.split_once('/') { - None => (None::, Some(get_command.params.name.clone())), - Some((server_name, prompt_name)) => (Some(server_name.to_string()), Some(prompt_name.to_string())), - }; - let prompt_name = prompt_name.ok_or(GetPromptError::MissingPromptName)?; - // We need to use a sync lock here because this lock is also used in a blocking thread, - // necessitated by the fact that said thread is also responsible for using a sync channel, - // which is itself necessitated by the fact that consumer of said channel is calling from a - // sync function - let mut prompts_wl = self - .prompts - .write() - .map_err(|e| GetPromptError::Synchronization(e.to_string()))?; - let mut maybe_bundles = prompts_wl.get(&prompt_name); - let mut has_retried = false; - 'blk: loop { - match (maybe_bundles, server_name.as_ref(), has_retried) { - // If we have more than one eligible clients but no server name specified - (Some(bundles), None, _) if bundles.len() > 1 => { - break 'blk Err(GetPromptError::AmbiguousPrompt(prompt_name.clone(), { - bundles.iter().fold("\n".to_string(), |mut acc, b| { - acc.push_str(&format!("- @{}/{}\n", b.server_name, prompt_name)); - acc - }) - })); - }, - // Normal case where we have enough info to proceed - // Note that if bundle exists, it should never be empty - (Some(bundles), sn, _) => { - let bundle = if bundles.len() > 1 { - let Some(server_name) = sn else { - maybe_bundles = None; - continue 'blk; - }; - let bundle = bundles.iter().find(|b| b.server_name == *server_name); - match bundle { - Some(bundle) => bundle, - None => { - maybe_bundles = None; - continue 'blk; - }, - } - } else { - bundles.first().ok_or(GetPromptError::MissingPromptInfo)? - }; - let server_name = bundle.server_name.clone(); - let client = self.clients.get(&server_name).ok_or(GetPromptError::MissingClient)?; - // Here we lazily update the out of date cache - if client.is_prompts_out_of_date() { - let prompt_gets = client.list_prompt_gets(); - let prompt_gets = prompt_gets - .read() - .map_err(|e| GetPromptError::Synchronization(e.to_string()))?; - for (prompt_name, prompt_get) in prompt_gets.iter() { - prompts_wl - .entry(prompt_name.to_string()) - .and_modify(|bundles| { - let mut is_modified = false; - for bundle in &mut *bundles { - let mut updated_bundle = PromptBundle { - server_name: server_name.clone(), - prompt_get: prompt_get.clone(), - }; - if bundle.server_name == *server_name { - std::mem::swap(bundle, &mut updated_bundle); - is_modified = true; - break; - } - } - if !is_modified { - bundles.push(PromptBundle { - server_name: server_name.clone(), - prompt_get: prompt_get.clone(), - }); - } - }) - .or_insert(vec![PromptBundle { - server_name: server_name.clone(), - prompt_get: prompt_get.clone(), - }]); - } - client.prompts_updated(); - } - let PromptsGetCommand { params, .. } = get_command; - let PromptBundle { prompt_get, .. } = prompts_wl - .get(&prompt_name) - .and_then(|bundles| bundles.iter().find(|b| b.server_name == server_name)) - .ok_or(GetPromptError::MissingPromptInfo)?; - // Here we need to convert the positional arguments into key value pair - // The assignment order is assumed to be the order of args as they are - // presented in PromptGet::arguments - let args = if let (Some(schema), Some(value)) = (&prompt_get.arguments, ¶ms.arguments) { - let params = schema.iter().zip(value.iter()).fold( - HashMap::::new(), - |mut acc, (prompt_get_arg, value)| { - acc.insert(prompt_get_arg.name.clone(), value.clone()); - acc - }, - ); - Some(serde_json::json!(params)) - } else { - None - }; - let params = { - let mut params = serde_json::Map::new(); - params.insert("name".to_string(), serde_json::Value::String(prompt_name)); - if let Some(args) = args { - params.insert("arguments".to_string(), args); - } - Some(serde_json::Value::Object(params)) - }; - let resp = client.request("prompts/get", params).await?; - break 'blk Ok(resp); - }, - // If we have no eligible clients this would mean one of the following: - // - The prompt does not exist, OR - // - This is the first time we have a query / our cache is out of date - // Both of which means we would have to requery - (None, _, false) => { - has_retried = true; - self.refresh_prompts(&mut prompts_wl)?; - maybe_bundles = prompts_wl.get(&prompt_name); - continue 'blk; - }, - (_, _, true) => { - break 'blk Err(GetPromptError::PromptNotFound(prompt_name)); - }, - } - } - } - - pub fn refresh_prompts(&self, prompts_wl: &mut HashMap>) -> Result<(), GetPromptError> { - *prompts_wl = self.clients.iter().fold( - HashMap::>::new(), - |mut acc, (server_name, client)| { - let prompt_gets = client.list_prompt_gets(); - let Ok(prompt_gets) = prompt_gets.read() else { - tracing::error!("Error encountered while retrieving read lock"); - return acc; - }; - for (prompt_name, prompt_get) in prompt_gets.iter() { - acc.entry(prompt_name.to_string()) - .and_modify(|bundles| { - bundles.push(PromptBundle { - server_name: server_name.to_owned(), - prompt_get: prompt_get.clone(), - }); - }) - .or_insert(vec![PromptBundle { - server_name: server_name.to_owned(), - prompt_get: prompt_get.clone(), - }]); - } - acc - }, - ); - Ok(()) - } -} - -#[inline] -fn process_tool_specs( - conversation_id: &str, - server_name: &str, - is_in_display: bool, - specs: &mut Vec, - tn_map: &mut HashMap, - regex: &Arc, -) -> Option { - // Each mcp server might have multiple tools. - // To avoid naming conflicts we are going to namespace it. - // This would also help us locate which mcp server to call the tool from. - let mut out_of_spec_tool_names = Vec::::new(); - let mut hasher = DefaultHasher::new(); - let number_of_tools = specs.len(); - // Sanitize tool names to ensure they comply with the naming requirements: - // 1. If the name already matches the regex pattern and doesn't contain the namespace delimiter, use - // it as is - // 2. Otherwise, remove invalid characters and handle special cases: - // - Remove namespace delimiters - // - Ensure the name starts with an alphabetic character - // - Generate a hash-based name if the sanitized result is empty - // This ensures all tool names are valid identifiers that can be safely used in the system - // If after all of the aforementioned modification the combined tool - // name we have exceeds a length of 64, we surface it as an error - for spec in specs { - let sn = if !regex.is_match(&spec.name) { - let mut sn = sanitize_name(spec.name.clone(), regex, &mut hasher); - while tn_map.contains_key(&sn) { - sn.push('1'); - } - sn - } else { - spec.name.clone() - }; - let full_name = format!("{}{}{}", server_name, NAMESPACE_DELIMITER, sn); - if full_name.len() > 64 { - out_of_spec_tool_names.push(OutOfSpecName::TooLong(spec.name.clone())); - continue; - } else if spec.description.is_empty() { - out_of_spec_tool_names.push(OutOfSpecName::EmptyDescription(spec.name.clone())); - continue; - } - if sn != spec.name { - tn_map.insert( - full_name.clone(), - format!("{}{}{}", server_name, NAMESPACE_DELIMITER, spec.name), - ); - } - spec.name = full_name; - spec.tool_origin = ToolOrigin::McpServer(server_name.to_string()); - } - // Send server load success metric datum - let conversation_id = conversation_id.to_string(); - tokio::spawn(async move { - let event = fig_telemetry::EventType::McpServerInit { - conversation_id, - init_failure_reason: None, - number_of_tools, - }; - let app_event = fig_telemetry::AppTelemetryEvent::new(event).await; - fig_telemetry::dispatch_or_send_event(app_event).await; - }); - // Tool name translation. This is beyond of the scope of what is - // considered a "server load". Reasoning being: - // - Failures here are not related to server load - // - There is not a whole lot we can do with this data - let loading_msg = if !out_of_spec_tool_names.is_empty() { - let msg = out_of_spec_tool_names.iter().fold( - String::from( - "The following tools are out of spec. They will be excluded from the list of available tools:\n", - ), - |mut acc, name| { - let (tool_name, msg) = match name { - OutOfSpecName::TooLong(tool_name) => ( - tool_name.as_str(), - "tool name exceeds max length of 64 when combined with server name", - ), - OutOfSpecName::IllegalChar(tool_name) => ( - tool_name.as_str(), - "tool name must be compliant with ^[a-zA-Z][a-zA-Z0-9_]*$", - ), - OutOfSpecName::EmptyDescription(tool_name) => { - (tool_name.as_str(), "tool schema contains empty description") - }, - }; - acc.push_str(format!(" - {} ({})\n", tool_name, msg).as_str()); - acc - }, - ); - error!( - "Server {} finished loading with the following error: \n{}", - server_name, msg - ); - if is_in_display { - Some(LoadingMsg::Error { - name: server_name.to_string(), - msg: eyre::eyre!(msg), - }) - } else { - None - } - // TODO: if no tools are valid, we need to offload the server - // from the fleet (i.e. kill the server) - } else if !tn_map.is_empty() { - let warn = tn_map.iter().fold( - String::from("The following tool names are changed:\n"), - |mut acc, (k, v)| { - acc.push_str(format!(" - {} -> {}\n", v, k).as_str()); - acc - }, - ); - if is_in_display { - Some(LoadingMsg::Warn { - name: server_name.to_string(), - msg: eyre::eyre!(warn), - }) - } else { - None - } - } else if is_in_display { - Some(LoadingMsg::Done(server_name.to_string())) - } else { - None - }; - loading_msg -} - -fn sanitize_name(orig: String, regex: ®ex::Regex, hasher: &mut impl Hasher) -> String { - if regex.is_match(&orig) && !orig.contains(NAMESPACE_DELIMITER) { - return orig; - } - let sanitized: String = orig - .chars() - .filter(|c| c.is_ascii_alphabetic() || c.is_ascii_digit() || *c == '_') - .collect::() - .replace(NAMESPACE_DELIMITER, ""); - if sanitized.is_empty() { - hasher.write(orig.as_bytes()); - let hash = format!("{:03}", hasher.finish() % 1000); - return format!("a{}", hash); - } - match sanitized.chars().next() { - Some(c) if c.is_ascii_alphabetic() => sanitized, - Some(_) => { - format!("a{}", sanitized) - }, - None => { - hasher.write(orig.as_bytes()); - format!("a{}", hasher.finish()) - }, - } -} - -fn queue_success_message(name: &str, time_taken: &str, output: &mut impl Write) -> eyre::Result<()> { - Ok(queue!( - output, - style::SetForegroundColor(style::Color::Green), - style::Print("✓ "), - style::SetForegroundColor(style::Color::Blue), - style::Print(name), - style::ResetColor, - style::Print(" loaded in "), - style::SetForegroundColor(style::Color::Yellow), - style::Print(format!("{time_taken} s\n")), - )?) -} - -fn queue_init_message( - spinner_logo_idx: usize, - complete: usize, - failed: usize, - total: usize, - output: &mut impl Write, -) -> eyre::Result<()> { - if total == complete { - queue!( - output, - style::SetForegroundColor(style::Color::Green), - style::Print("✓"), - style::ResetColor, - )?; - } else if total == complete + failed { - queue!( - output, - style::SetForegroundColor(style::Color::Red), - style::Print("✗"), - style::ResetColor, - )?; - } else { - queue!(output, style::Print(SPINNER_CHARS[spinner_logo_idx]))?; - } - Ok(queue!( - output, - style::SetForegroundColor(style::Color::Blue), - style::Print(format!(" {}", complete)), - style::ResetColor, - style::Print(" of "), - style::SetForegroundColor(style::Color::Blue), - style::Print(format!("{} ", total)), - style::ResetColor, - style::Print("mcp servers initialized\n"), - )?) -} - -fn queue_failure_message(name: &str, fail_load_msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { - Ok(queue!( - output, - style::SetForegroundColor(style::Color::Red), - style::Print("✗ "), - style::SetForegroundColor(style::Color::Blue), - style::Print(name), - style::ResetColor, - style::Print(" has failed to load:\n- "), - style::Print(fail_load_msg), - style::Print("\n"), - style::Print("- run with Q_LOG_LEVEL=trace and see $TMPDIR/qlog for detail\n"), - style::ResetColor, - )?) -} - -fn queue_warn_message(name: &str, msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { - Ok(queue!( - output, - style::SetForegroundColor(style::Color::Yellow), - style::Print("⚠ "), - style::SetForegroundColor(style::Color::Blue), - style::Print(name), - style::ResetColor, - style::Print(" has the following warning:\n"), - style::Print(msg), - style::ResetColor, - )?) -} - -fn queue_incomplete_load_message(msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { - Ok(queue!( - output, - style::SetForegroundColor(style::Color::Yellow), - style::Print("⚠ "), - style::ResetColor, - // We expect the message start with a newline - style::Print("following servers are still loading:"), - style::Print(msg), - style::ResetColor, - )?) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_sanitize_server_name() { - let regex = regex::Regex::new(VALID_TOOL_NAME).unwrap(); - let mut hasher = DefaultHasher::new(); - let orig_name = "@awslabs.cdk-mcp-server"; - let sanitized_server_name = sanitize_name(orig_name.to_string(), ®ex, &mut hasher); - assert_eq!(sanitized_server_name, "awslabscdkmcpserver"); - - let orig_name = "good_name"; - let sanitized_good_name = sanitize_name(orig_name.to_string(), ®ex, &mut hasher); - assert_eq!(sanitized_good_name, orig_name); - - let all_bad_name = "@@@@@"; - let sanitized_all_bad_name = sanitize_name(all_bad_name.to_string(), ®ex, &mut hasher); - assert!(regex.is_match(&sanitized_all_bad_name)); - - let with_delim = format!("a{}b{}c", NAMESPACE_DELIMITER, NAMESPACE_DELIMITER); - let sanitized = sanitize_name(with_delim, ®ex, &mut hasher); - assert_eq!(sanitized, "abc"); - } -} From 46ef665e3ec07abcd0034354ed612cb932d61d3b Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Tue, 6 May 2025 14:32:15 -0700 Subject: [PATCH 07/26] removes mcp client crate --- Cargo.lock | 15 - Cargo.toml | 1 - .../cli/chat/tool_manager/server_messenger.rs | 4 +- .../src/cli/chat/tool_manager/tool_manager.rs | 2 +- crates/mcp_client/Cargo.toml | 30 - crates/mcp_client/src/client.rs | 833 ------------------ crates/mcp_client/src/error.rs | 66 -- crates/mcp_client/src/facilitator_types.rs | 247 ------ crates/mcp_client/src/lib.rs | 11 - crates/mcp_client/src/messenger.rs | 73 -- crates/mcp_client/src/server.rs | 293 ------ .../mcp_client/src/transport/base_protocol.rs | 108 --- crates/mcp_client/src/transport/mod.rs | 56 -- crates/mcp_client/src/transport/stdio.rs | 277 ------ crates/mcp_client/src/transport/websocket.rs | 0 .../mcp_client/test_mcp_server/test_server.rs | 354 -------- crates/q_cli/Cargo.toml | 1 - 17 files changed, 3 insertions(+), 2368 deletions(-) delete mode 100644 crates/mcp_client/Cargo.toml delete mode 100644 crates/mcp_client/src/client.rs delete mode 100644 crates/mcp_client/src/error.rs delete mode 100644 crates/mcp_client/src/facilitator_types.rs delete mode 100644 crates/mcp_client/src/lib.rs delete mode 100644 crates/mcp_client/src/messenger.rs delete mode 100644 crates/mcp_client/src/server.rs delete mode 100644 crates/mcp_client/src/transport/base_protocol.rs delete mode 100644 crates/mcp_client/src/transport/mod.rs delete mode 100644 crates/mcp_client/src/transport/stdio.rs delete mode 100644 crates/mcp_client/src/transport/websocket.rs delete mode 100644 crates/mcp_client/test_mcp_server/test_server.rs diff --git a/Cargo.lock b/Cargo.lock index e55148dc8b..7cbacdf4fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5273,20 +5273,6 @@ dependencies = [ "rayon", ] -[[package]] -name = "mcp_client" -version = "1.10.0" -dependencies = [ - "async-trait", - "nix 0.29.0", - "serde", - "serde_json", - "thiserror 2.0.12", - "tokio", - "tracing", - "uuid", -] - [[package]] name = "memchr" version = "2.7.4" @@ -7077,7 +7063,6 @@ dependencies = [ "indoc", "insta", "macos-utils", - "mcp_client", "mimalloc", "nix 0.29.0", "objc2 0.5.2", diff --git a/Cargo.toml b/Cargo.toml index 90ae9c21b2..1557c5a4f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -87,7 +87,6 @@ indicatif = "0.17.11" indoc = "2.0.6" insta = "1.43.1" libc = "0.2.172" -mcp_client = { path = "crates/mcp_client" } mimalloc = "0.1.46" nix = { version = "0.29.0", features = [ "feature", diff --git a/crates/chat-cli/src/cli/chat/tool_manager/server_messenger.rs b/crates/chat-cli/src/cli/chat/tool_manager/server_messenger.rs index dad1648d65..cdca50d8f7 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager/server_messenger.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager/server_messenger.rs @@ -27,7 +27,7 @@ pub enum UpdateEventMessage { server_name: String, result: ResourcesListResult, }, - ResouceTemplatesListResult { + ResourceTemplatesListResult { server_name: String, result: ResourceTemplatesListResult, }, @@ -103,7 +103,7 @@ impl Messenger for ServerMessenger { ) -> Result<(), MessengerError> { Ok(self .update_event_sender - .send(UpdateEventMessage::ResouceTemplatesListResult { + .send(UpdateEventMessage::ResourceTemplatesListResult { server_name: self.server_name.clone(), result, }) diff --git a/crates/chat-cli/src/cli/chat/tool_manager/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager/tool_manager.rs index 8ba5af6801..cafb4977c1 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager/tool_manager.rs @@ -373,7 +373,7 @@ impl ToolManagerBuilder { }, UpdateEventMessage::PromptsListResult { server_name, result } => {}, UpdateEventMessage::ResourcesListResult { server_name, result } => {}, - UpdateEventMessage::ResouceTemplatesListResult { server_name, result } => {}, + UpdateEventMessage::ResourceTemplatesListResult { server_name, result } => {}, UpdateEventMessage::DisplayTaskEnded => { is_in_display = false; }, diff --git a/crates/mcp_client/Cargo.toml b/crates/mcp_client/Cargo.toml deleted file mode 100644 index a95692d102..0000000000 --- a/crates/mcp_client/Cargo.toml +++ /dev/null @@ -1,30 +0,0 @@ -[package] -name = "mcp_client" -authors.workspace = true -edition.workspace = true -homepage.workspace = true -publish.workspace = true -version.workspace = true -license.workspace = true - -[lints] -workspace = true - -[features] -default = [] - -[[bin]] -name = "test_mcp_server" -path = "test_mcp_server/test_server.rs" -test = true -doc = false - -[dependencies] -tokio.workspace = true -serde.workspace = true -serde_json.workspace = true -async-trait.workspace = true -tracing.workspace = true -thiserror.workspace = true -uuid.workspace = true -nix.workspace = true diff --git a/crates/mcp_client/src/client.rs b/crates/mcp_client/src/client.rs deleted file mode 100644 index eae73de53a..0000000000 --- a/crates/mcp_client/src/client.rs +++ /dev/null @@ -1,833 +0,0 @@ -use std::collections::HashMap; -use std::process::Stdio; -use std::sync::atomic::{ - AtomicBool, - AtomicU64, - Ordering, -}; -use std::sync::{ - Arc, - RwLock as SyncRwLock, -}; -use std::time::Duration; - -use nix::sys::signal::Signal; -use nix::unistd::Pid; -use serde::{ - Deserialize, - Serialize, -}; -use thiserror::Error; -use tokio::time; -use tokio::time::error::Elapsed; - -use crate::transport::base_protocol::{ - JsonRpcMessage, - JsonRpcNotification, - JsonRpcRequest, - JsonRpcVersion, -}; -use crate::transport::stdio::JsonRpcStdioTransport; -use crate::transport::{ - self, - Transport, - TransportError, -}; -use crate::{ - JsonRpcResponse, - Listener as _, - LogListener, - Messenger, - PaginationSupportedOps, - PromptGet, - PromptsListResult, - ResourceTemplatesListResult, - ResourcesListResult, - ServerCapabilities, - ToolsListResult, -}; - -pub type ClientInfo = serde_json::Value; -pub type StdioTransport = JsonRpcStdioTransport; - -/// Represents the capabilities of a client in the Model Context Protocol. -/// This structure is sent to the server during initialization to communicate -/// what features the client supports and provide information about the client. -/// When features are added to the client, these should be declared in the [From] trait implemented -/// for the struct. -#[derive(Default, Debug, Serialize)] -#[serde(rename_all = "camelCase")] -struct ClientCapabilities { - protocol_version: JsonRpcVersion, - capabilities: HashMap, - client_info: serde_json::Value, -} - -impl From for ClientCapabilities { - fn from(client_info: ClientInfo) -> Self { - ClientCapabilities { - client_info, - ..Default::default() - } - } -} - -#[derive(Debug, Deserialize)] -pub struct ClientConfig { - pub server_name: String, - pub bin_path: String, - pub args: Vec, - pub timeout: u64, - pub client_info: serde_json::Value, - pub env: Option>, -} - -#[derive(Debug, Error)] -pub enum ClientError { - #[error(transparent)] - TransportError(#[from] TransportError), - #[error(transparent)] - Io(#[from] std::io::Error), - #[error(transparent)] - Serialization(#[from] serde_json::Error), - #[error("Operation timed out: {context}")] - RuntimeError { - #[source] - source: tokio::time::error::Elapsed, - context: String, - }, - #[error("Unexpected msg type encountered")] - UnexpectedMsgType, - #[error("{0}")] - NegotiationError(String), - #[error("Failed to obtain process id")] - MissingProcessId, - #[error("Invalid path received")] - InvalidPath, - #[error("{0}")] - ProcessKillError(String), - #[error("{0}")] - PoisonError(String), -} - -impl From<(tokio::time::error::Elapsed, String)> for ClientError { - fn from((error, context): (tokio::time::error::Elapsed, String)) -> Self { - ClientError::RuntimeError { source: error, context } - } -} - -#[derive(Debug)] -pub struct Client { - server_name: String, - transport: Arc, - timeout: u64, - server_process_id: Option, - client_info: serde_json::Value, - current_id: Arc, - pub messenger: Option>, - pub prompt_gets: Arc>>, - pub is_prompts_out_of_date: Arc, -} - -impl Clone for Client { - fn clone(&self) -> Self { - Self { - server_name: self.server_name.clone(), - transport: self.transport.clone(), - timeout: self.timeout, - // Note that we cannot have an id for the clone because we would kill the original - // process when we drop the clone - server_process_id: None, - client_info: self.client_info.clone(), - current_id: self.current_id.clone(), - messenger: None, - prompt_gets: self.prompt_gets.clone(), - is_prompts_out_of_date: self.is_prompts_out_of_date.clone(), - } - } -} - -impl Client { - pub fn from_config(config: ClientConfig) -> Result { - let ClientConfig { - server_name, - bin_path, - args, - timeout, - client_info, - env, - } = config; - let child = { - let mut command = tokio::process::Command::new(bin_path); - command - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .process_group(0) - .envs(std::env::vars()); - if let Some(env) = env { - for (env_name, env_value) in env { - command.env(env_name, env_value); - } - } - command.args(args).spawn()? - }; - let server_process_id = child.id().ok_or(ClientError::MissingProcessId)?; - #[allow(clippy::map_err_ignore)] - let server_process_id = Pid::from_raw( - server_process_id - .try_into() - .map_err(|_| ClientError::MissingProcessId)?, - ); - let server_process_id = Some(server_process_id); - let transport = Arc::new(transport::stdio::JsonRpcStdioTransport::client(child)?); - Ok(Self { - server_name, - transport, - timeout, - server_process_id, - client_info, - current_id: Arc::new(AtomicU64::new(0)), - messenger: None, - prompt_gets: Arc::new(SyncRwLock::new(HashMap::new())), - is_prompts_out_of_date: Arc::new(AtomicBool::new(false)), - }) - } -} - -impl Drop for Client -where - T: Transport, -{ - // IF the servers are implemented well, they will shutdown once the pipe closes. - // This drop trait is here as a fail safe to ensure we don't leave behind any orphans. - fn drop(&mut self) { - if let Some(process_id) = self.server_process_id { - let _ = nix::sys::signal::kill(process_id, Signal::SIGTERM); - } - } -} - -impl Client -where - T: Transport, -{ - /// Exchange of information specified as per https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization - /// - /// Also done are the following: - /// - Spawns task for listening to server driven workflows - /// - Spawns tasks to ask for relevant info such as tools and prompts in accordance to server - /// capabilities received - pub async fn init(&self) -> Result { - let transport_ref = self.transport.clone(); - let server_name = self.server_name.clone(); - - tokio::spawn(async move { - let mut listener = transport_ref.get_listener(); - loop { - match listener.recv().await { - Ok(msg) => { - match msg { - JsonRpcMessage::Request(_req) => {}, - JsonRpcMessage::Notification(notif) => { - let JsonRpcNotification { method, params, .. } = notif; - if method.as_str() == "notifications/message" || method.as_str() == "message" { - let level = params - .as_ref() - .and_then(|p| p.get("level")) - .and_then(|v| serde_json::to_string(v).ok()); - let data = params - .as_ref() - .and_then(|p| p.get("data")) - .and_then(|v| serde_json::to_string(v).ok()); - if let (Some(level), Some(data)) = (level, data) { - match level.to_lowercase().as_str() { - "error" => { - tracing::error!(target: "mcp", "{}: {}", server_name, data); - }, - "warn" => { - tracing::warn!(target: "mcp", "{}: {}", server_name, data); - }, - "info" => { - tracing::info!(target: "mcp", "{}: {}", server_name, data); - }, - "debug" => { - tracing::debug!(target: "mcp", "{}: {}", server_name, data); - }, - "trace" => { - tracing::trace!(target: "mcp", "{}: {}", server_name, data); - }, - _ => {}, - } - } - } - }, - JsonRpcMessage::Response(_resp) => { /* noop since direct response is handled inside the request api */ - }, - } - }, - Err(e) => { - tracing::error!("Background listening thread for client {}: {:?}", server_name, e); - }, - } - } - }); - - let transport_ref = self.transport.clone(); - let server_name = self.server_name.clone(); - - // Spawning a task to listen and log stderr output - tokio::spawn(async move { - let mut log_listener = transport_ref.get_log_listener(); - loop { - match log_listener.recv().await { - Ok(msg) => { - tracing::trace!(target: "mcp", "{server_name} logged {}", msg); - }, - Err(e) => { - tracing::error!( - "Error encountered while reading from stderr for {server_name}: {:?}\nEnding stderr listening task.", - e - ); - break; - }, - } - } - }); - - let init_params = Some({ - let client_cap = ClientCapabilities::from(self.client_info.clone()); - serde_json::json!(client_cap) - }); - let init_resp = self.request("initialize", init_params).await?; - if let Err(e) = examine_server_capabilities(&init_resp) { - return Err(ClientError::NegotiationError(format!( - "Client {} has failed to negotiate server capabilities with server: {:?}", - self.server_name, e - ))); - } - let cap = { - let result = init_resp.result.ok_or(ClientError::NegotiationError(format!( - "Server {} init resp is missing result", - self.server_name - )))?; - let cap = result - .get("capabilities") - .ok_or(ClientError::NegotiationError(format!( - "Server {} init resp result is missing capabilities", - self.server_name - )))? - .clone(); - serde_json::from_value::(cap)? - }; - self.notify("initialized", None).await?; - - // TODO: group this into examine_server_capabilities - // Prefetch prompts in the background. We should only do this after the server has been - // initialized - if cap.prompts.is_some() { - self.is_prompts_out_of_date.store(true, Ordering::Relaxed); - let client_ref = (*self).clone(); - tokio::spawn(async move { - let Ok(resp) = client_ref.request("prompts/list", None).await else { - tracing::error!("Prompt list query failed for {0}", client_ref.server_name); - return; - }; - let Some(result) = resp.result else { - tracing::warn!("Prompt list query returned no result for {0}", client_ref.server_name); - return; - }; - let Some(prompts) = result.get("prompts") else { - tracing::warn!( - "Prompt list query result contained no field named prompts for {0}", - client_ref.server_name - ); - return; - }; - let Ok(prompts) = serde_json::from_value::>(prompts.clone()) else { - tracing::error!( - "Prompt list query deserialization failed for {0}", - client_ref.server_name - ); - return; - }; - let Ok(mut lock) = client_ref.prompt_gets.write() else { - tracing::error!( - "Failed to obtain write lock for prompt list query for {0}", - client_ref.server_name - ); - return; - }; - for prompt in prompts { - let name = prompt.name.clone(); - lock.insert(name, prompt); - } - }); - } - if let (Some(_), Some(messenger)) = (&cap.tools, &self.messenger) { - tracing::error!( - "## background: {} is spawning background task to fetch tools", - self.server_name - ); - let client_ref = (*self).clone(); - let msger = messenger.duplicate(); - tokio::spawn(async move { - // TODO: decouple pagination logic from request and have page fetching logic here - // instead - let resp = match client_ref.request("tools/list", None).await { - Ok(resp) => resp, - Err(e) => { - tracing::error!("Failed to retrieve tool list from {}: {:?}", client_ref.server_name, e); - return; - }, - }; - if let Some(error) = resp.error { - let msg = format!( - "Failed to retrieve tool list for {}: {:?}", - client_ref.server_name, error - ); - tracing::error!("{}", &msg); - return; - } - let Some(result) = resp.result else { - tracing::error!("Tool list response from {} is missing result", client_ref.server_name); - return; - }; - let tool_list_result = match serde_json::from_value::(result) { - Ok(result) => result, - Err(e) => { - tracing::error!( - "Failed to deserialize tool result from {}: {:?}", - client_ref.server_name, - e - ); - return; - }, - }; - if let Err(e) = msger.send_tools_list_result(tool_list_result).await { - tracing::error!("Failed to send tool result through messenger {:?}", e); - } - }); - } - - Ok(cap) - } - - /// Sends a request to the server associated. - /// This call will yield until a response is received. - pub async fn request( - &self, - method: &str, - params: Option, - ) -> Result { - let send_map_err = |e: Elapsed| (e, method.to_string()); - let recv_map_err = |e: Elapsed| (e, format!("recv for {method}")); - let mut id = self.get_id(); - let request = JsonRpcRequest { - jsonrpc: JsonRpcVersion::default(), - id, - method: method.to_owned(), - params, - }; - tracing::trace!(target: "mcp", "To {}:\n{:#?}", self.server_name, request); - let msg = JsonRpcMessage::Request(request); - time::timeout(Duration::from_millis(self.timeout), self.transport.send(&msg)) - .await - .map_err(send_map_err)??; - let mut listener = self.transport.get_listener(); - let mut resp = time::timeout(Duration::from_millis(self.timeout), async { - // we want to ignore all other messages sent by the server at this point and let the - // background loop handle them - loop { - if let JsonRpcMessage::Response(resp) = listener.recv().await? { - if resp.id == id { - break Ok::(resp); - } - } - } - }) - .await - .map_err(recv_map_err)??; - // Pagination support: https://spec.modelcontextprotocol.io/specification/2024-11-05/server/utilities/pagination/#pagination-model - let mut next_cursor = resp.result.as_ref().and_then(|v| v.get("nextCursor")); - if next_cursor.is_some() { - let mut current_resp = resp.clone(); - let mut results = Vec::::new(); - let pagination_supported_ops = { - let maybe_pagination_supported_op: Result = method.try_into(); - maybe_pagination_supported_op.ok() - }; - if let Some(ops) = pagination_supported_ops { - loop { - let result = current_resp.result.as_ref().cloned().unwrap(); - let mut list: Vec = match ops { - PaginationSupportedOps::ResourcesList => { - let ResourcesListResult { resources: list, .. } = - serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - PaginationSupportedOps::ResourceTemplatesList => { - let ResourceTemplatesListResult { - resource_templates: list, - .. - } = serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - PaginationSupportedOps::PromptsList => { - let PromptsListResult { prompts: list, .. } = - serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - PaginationSupportedOps::ToolsList => { - let ToolsListResult { tools: list, .. } = serde_json::from_value::(result) - .map_err(ClientError::Serialization)?; - list - }, - }; - results.append(&mut list); - if next_cursor.is_none() { - break; - } - id = self.get_id(); - let next_request = JsonRpcRequest { - jsonrpc: JsonRpcVersion::default(), - id, - method: method.to_owned(), - params: Some(serde_json::json!({ - "cursor": next_cursor, - })), - }; - let msg = JsonRpcMessage::Request(next_request); - time::timeout(Duration::from_millis(self.timeout), self.transport.send(&msg)) - .await - .map_err(send_map_err)??; - let resp = time::timeout(Duration::from_millis(self.timeout), async { - // we want to ignore all other messages sent by the server at this point and let the - // background loop handle them - loop { - if let JsonRpcMessage::Response(resp) = listener.recv().await? { - if resp.id == id { - break Ok::(resp); - } - } - } - }) - .await - .map_err(recv_map_err)??; - current_resp = resp; - next_cursor = current_resp.result.as_ref().and_then(|v| v.get("nextCursor")); - } - resp.result = Some({ - let mut map = serde_json::Map::new(); - map.insert(ops.as_key().to_owned(), serde_json::to_value(results)?); - serde_json::to_value(map)? - }); - } - } - tracing::trace!(target: "mcp", "From {}:\n{:#?}", self.server_name, resp); - Ok(resp) - } - - /// Sends a notification to the server associated. - /// Notifications are requests that expect no responses. - pub async fn notify(&self, method: &str, params: Option) -> Result<(), ClientError> { - let send_map_err = |e: Elapsed| (e, method.to_string()); - let notification = JsonRpcNotification { - jsonrpc: JsonRpcVersion::default(), - method: format!("notifications/{}", method), - params, - }; - let msg = JsonRpcMessage::Notification(notification); - Ok( - time::timeout(Duration::from_millis(self.timeout), self.transport.send(&msg)) - .await - .map_err(send_map_err)??, - ) - } - - pub async fn shutdown(&self) -> Result<(), ClientError> { - Ok(self.transport.shutdown().await?) - } - - fn get_id(&self) -> u64 { - self.current_id.fetch_add(1, Ordering::SeqCst) - } -} - -fn examine_server_capabilities(ser_cap: &JsonRpcResponse) -> Result<(), ClientError> { - // Check the jrpc version. - // Currently we are only proceeding if the versions are EXACTLY the same. - let jrpc_version = ser_cap.jsonrpc.as_u32_vec(); - let client_jrpc_version = JsonRpcVersion::default().as_u32_vec(); - for (sv, cv) in jrpc_version.iter().zip(client_jrpc_version.iter()) { - if sv != cv { - return Err(ClientError::NegotiationError( - "Incompatible jrpc version between server and client".to_owned(), - )); - } - } - Ok(()) -} - -#[cfg(test)] -mod tests { - use std::path::PathBuf; - - use serde_json::Value; - - use super::*; - const TEST_BIN_OUT_DIR: &str = "target/debug"; - const TEST_SERVER_NAME: &str = "test_mcp_server"; - - fn get_workspace_root() -> PathBuf { - let output = std::process::Command::new("cargo") - .args(["metadata", "--format-version=1", "--no-deps"]) - .output() - .expect("Failed to execute cargo metadata"); - - let metadata: serde_json::Value = - serde_json::from_slice(&output.stdout).expect("Failed to parse cargo metadata"); - - let workspace_root = metadata["workspace_root"] - .as_str() - .expect("Failed to find workspace_root in metadata"); - - PathBuf::from(workspace_root) - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_client_stdio() { - std::process::Command::new("cargo") - .args(["build", "--bin", TEST_SERVER_NAME]) - .status() - .expect("Failed to build binary"); - let workspace_root = get_workspace_root(); - let bin_path = workspace_root.join(TEST_BIN_OUT_DIR).join(TEST_SERVER_NAME); - println!("bin path: {}", bin_path.to_str().unwrap_or("no path found")); - - // Testing 2 concurrent sessions to make sure transport layer does not overlap. - let client_info_one = serde_json::json!({ - "name": "TestClientOne", - "version": "1.0.0" - }); - let client_config_one = ClientConfig { - server_name: "test_tool".to_owned(), - bin_path: bin_path.to_str().unwrap().to_string(), - args: ["1".to_owned()].to_vec(), - timeout: 120 * 1000, - client_info: client_info_one.clone(), - env: { - let mut map = HashMap::::new(); - map.insert("ENV_ONE".to_owned(), "1".to_owned()); - map.insert("ENV_TWO".to_owned(), "2".to_owned()); - Some(map) - }, - }; - let client_info_two = serde_json::json!({ - "name": "TestClientTwo", - "version": "1.0.0" - }); - let client_config_two = ClientConfig { - server_name: "test_tool".to_owned(), - bin_path: bin_path.to_str().unwrap().to_string(), - args: ["2".to_owned()].to_vec(), - timeout: 120 * 1000, - client_info: client_info_two.clone(), - env: { - let mut map = HashMap::::new(); - map.insert("ENV_ONE".to_owned(), "1".to_owned()); - map.insert("ENV_TWO".to_owned(), "2".to_owned()); - Some(map) - }, - }; - let mut client_one = Client::::from_config(client_config_one).expect("Failed to create client"); - let mut client_two = Client::::from_config(client_config_two).expect("Failed to create client"); - let client_one_cap = ClientCapabilities::from(client_info_one); - let client_two_cap = ClientCapabilities::from(client_info_two); - - let (res_one, res_two) = tokio::join!( - time::timeout( - time::Duration::from_secs(5), - test_client_routine(&mut client_one, serde_json::json!(client_one_cap)) - ), - time::timeout( - time::Duration::from_secs(5), - test_client_routine(&mut client_two, serde_json::json!(client_two_cap)) - ) - ); - let res_one = res_one.expect("Client one timed out"); - let res_two = res_two.expect("Client two timed out"); - assert!(res_one.is_ok()); - assert!(res_two.is_ok()); - } - - async fn test_client_routine( - client: &mut Client, - cap_sent: serde_json::Value, - ) -> Result<(), Box> { - // Test init - let _ = client.init().await.expect("Client init failed"); - tokio::time::sleep(time::Duration::from_millis(1500)).await; - let client_capabilities_sent = client - .request("verify_init_ack_sent", None) - .await - .expect("Verify init ack mock request failed"); - let has_server_recvd_init_ack = client_capabilities_sent - .result - .expect("Failed to retrieve client capabilities sent."); - assert_eq!(has_server_recvd_init_ack.to_string(), "true"); - let cap_recvd = client - .request("verify_init_params_sent", None) - .await - .expect("Verify init params mock request failed"); - let cap_recvd = cap_recvd - .result - .expect("Verify init params mock request does not contain required field (result)"); - assert!(are_json_values_equal(&cap_sent, &cap_recvd)); - - // test list tools - let fake_tool_names = ["get_weather_one", "get_weather_two", "get_weather_three"]; - let mock_result_spec = fake_tool_names.map(create_fake_tool_spec); - let mock_tool_specs_for_verify = serde_json::json!(mock_result_spec.clone()); - let mock_tool_specs_prep_param = mock_result_spec - .iter() - .zip(fake_tool_names.iter()) - .map(|(v, n)| { - serde_json::json!({ - "key": (*n).to_string(), - "value": v - }) - }) - .collect::>(); - let mock_tool_specs_prep_param = - serde_json::to_value(mock_tool_specs_prep_param).expect("Failed to create mock tool specs prep param"); - let _ = client - .request("store_mock_tool_spec", Some(mock_tool_specs_prep_param)) - .await - .expect("Mock tool spec prep failed"); - let tool_spec_recvd = client.request("tools/list", None).await.expect("List tools failed"); - assert!(are_json_values_equal( - tool_spec_recvd - .result - .as_ref() - .and_then(|v| v.get("tools")) - .expect("Failed to retrieve tool specs from result received"), - &mock_tool_specs_for_verify - )); - - // Test list prompts directly - let fake_prompt_names = ["code_review_one", "code_review_two", "code_review_three"]; - let mock_result_prompts = fake_prompt_names.map(create_fake_prompts); - let mock_prompts_for_verify = serde_json::json!(mock_result_prompts.clone()); - let mock_prompts_prep_param = mock_result_prompts - .iter() - .zip(fake_prompt_names.iter()) - .map(|(v, n)| { - serde_json::json!({ - "key": (*n).to_string(), - "value": v - }) - }) - .collect::>(); - let mock_prompts_prep_param = - serde_json::to_value(mock_prompts_prep_param).expect("Failed to create mock prompts prep param"); - let _ = client - .request("store_mock_prompts", Some(mock_prompts_prep_param)) - .await - .expect("Mock prompt prep failed"); - let prompts_recvd = client.request("prompts/list", None).await.expect("List prompts failed"); - assert!(are_json_values_equal( - prompts_recvd - .result - .as_ref() - .and_then(|v| v.get("prompts")) - .expect("Failed to retrieve prompts from results received"), - &mock_prompts_for_verify - )); - - // Test env var inclusion - let env_vars = client.request("get_env_vars", None).await.expect("Get env vars failed"); - let env_one = env_vars - .result - .as_ref() - .expect("Failed to retrieve results from env var request") - .get("ENV_ONE") - .expect("Failed to retrieve env one from env var request"); - let env_two = env_vars - .result - .as_ref() - .expect("Failed to retrieve results from env var request") - .get("ENV_TWO") - .expect("Failed to retrieve env two from env var request"); - let env_one_as_str = serde_json::to_string(env_one).expect("Failed to convert env one to string"); - let env_two_as_str = serde_json::to_string(env_two).expect("Failed to convert env two to string"); - assert_eq!(env_one_as_str, "\"1\"".to_string()); - assert_eq!(env_two_as_str, "\"2\"".to_string()); - - let shutdown_result = client.shutdown().await; - assert!(shutdown_result.is_ok()); - Ok(()) - } - - fn are_json_values_equal(a: &Value, b: &Value) -> bool { - match (a, b) { - (Value::Null, Value::Null) => true, - (Value::Bool(a_val), Value::Bool(b_val)) => a_val == b_val, - (Value::Number(a_val), Value::Number(b_val)) => a_val == b_val, - (Value::String(a_val), Value::String(b_val)) => a_val == b_val, - (Value::Array(a_arr), Value::Array(b_arr)) => { - if a_arr.len() != b_arr.len() { - return false; - } - a_arr - .iter() - .zip(b_arr.iter()) - .all(|(a_item, b_item)| are_json_values_equal(a_item, b_item)) - }, - (Value::Object(a_obj), Value::Object(b_obj)) => { - if a_obj.len() != b_obj.len() { - return false; - } - a_obj.iter().all(|(key, a_value)| match b_obj.get(key) { - Some(b_value) => are_json_values_equal(a_value, b_value), - None => false, - }) - }, - _ => false, - } - } - - fn create_fake_tool_spec(name: &str) -> serde_json::Value { - serde_json::json!({ - "name": name, - "description": "Get current weather information for a location", - "inputSchema": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "City name or zip code" - } - }, - "required": ["location"] - } - }) - } - - fn create_fake_prompts(name: &str) -> serde_json::Value { - serde_json::json!({ - "name": name, - "description": "Asks the LLM to analyze code quality and suggest improvements", - "arguments": [ - { - "name": "code", - "description": "The code to review", - "required": true - } - ] - }) - } -} diff --git a/crates/mcp_client/src/error.rs b/crates/mcp_client/src/error.rs deleted file mode 100644 index d05e7efa4d..0000000000 --- a/crates/mcp_client/src/error.rs +++ /dev/null @@ -1,66 +0,0 @@ -/// Error codes as defined in the MCP protocol. -/// -/// These error codes are based on the JSON-RPC 2.0 specification with additional -/// MCP-specific error codes in the -32000 to -32099 range. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(i32)] -pub enum ErrorCode { - /// Invalid JSON was received by the server. - /// An error occurred on the server while parsing the JSON text. - ParseError = -32700, - - /// The JSON sent is not a valid Request object. - InvalidRequest = -32600, - - /// The method does not exist / is not available. - MethodNotFound = -32601, - - /// Invalid method parameter(s). - InvalidParams = -32602, - - /// Internal JSON-RPC error. - InternalError = -32603, - - /// Server has not been initialized. - /// This error is returned when a request is made before the server - /// has been properly initialized. - ServerNotInitialized = -32002, - - /// Unknown error code. - /// This error is returned when an error code is received that is not - /// recognized by the implementation. - UnknownErrorCode = -32001, - - /// Request failed. - /// This error is returned when a request fails for a reason not covered - /// by other error codes. - RequestFailed = -32000, -} - -impl From for ErrorCode { - fn from(code: i32) -> Self { - match code { - -32700 => ErrorCode::ParseError, - -32600 => ErrorCode::InvalidRequest, - -32601 => ErrorCode::MethodNotFound, - -32602 => ErrorCode::InvalidParams, - -32603 => ErrorCode::InternalError, - -32002 => ErrorCode::ServerNotInitialized, - -32001 => ErrorCode::UnknownErrorCode, - -32000 => ErrorCode::RequestFailed, - _ => ErrorCode::UnknownErrorCode, - } - } -} - -impl From for i32 { - fn from(code: ErrorCode) -> Self { - code as i32 - } -} - -impl std::fmt::Display for ErrorCode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) - } -} diff --git a/crates/mcp_client/src/facilitator_types.rs b/crates/mcp_client/src/facilitator_types.rs deleted file mode 100644 index 908f555bd2..0000000000 --- a/crates/mcp_client/src/facilitator_types.rs +++ /dev/null @@ -1,247 +0,0 @@ -use serde::{ - Deserialize, - Serialize, -}; -use thiserror::Error; - -/// https://spec.modelcontextprotocol.io/specification/2024-11-05/server/utilities/pagination/#operations-supporting-pagination -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum PaginationSupportedOps { - ResourcesList, - ResourceTemplatesList, - PromptsList, - ToolsList, -} - -impl PaginationSupportedOps { - pub fn as_key(&self) -> &str { - match self { - PaginationSupportedOps::ResourcesList => "resources", - PaginationSupportedOps::ResourceTemplatesList => "resourceTemplates", - PaginationSupportedOps::PromptsList => "prompts", - PaginationSupportedOps::ToolsList => "tools", - } - } -} - -impl TryFrom<&str> for PaginationSupportedOps { - type Error = OpsConversionError; - - fn try_from(value: &str) -> Result { - match value { - "resources/list" => Ok(PaginationSupportedOps::ResourcesList), - "resources/templates/list" => Ok(PaginationSupportedOps::ResourceTemplatesList), - "prompts/list" => Ok(PaginationSupportedOps::PromptsList), - "tools/list" => Ok(PaginationSupportedOps::ToolsList), - _ => Err(OpsConversionError::InvalidMethod), - } - } -} - -#[derive(Error, Debug)] -pub enum OpsConversionError { - #[error("Invalid method encountered")] - InvalidMethod, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] -#[serde(rename_all = "camelCase")] -/// Role assumed for a particular message -pub enum Role { - User, - Assistant, -} - -impl std::fmt::Display for Role { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Role::User => write!(f, "user"), - Role::Assistant => write!(f, "assistant"), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Result of listing resources operation -pub struct ResourcesListResult { - /// List of resources - pub resources: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -/// Result of listing resource templates operation -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ResourceTemplatesListResult { - /// List of resource templates - pub resource_templates: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Result of prompt listing query -pub struct PromptsListResult { - /// List of prompts - pub prompts: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Represents an argument to be supplied to a [PromptGet] -pub struct PromptGetArg { - /// The name identifier of the prompt - pub name: String, - /// Optional description providing context about the prompt - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// Indicates whether a response to this prompt is required - /// If not specified, defaults to false - #[serde(skip_serializing_if = "Option::is_none")] - pub required: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Represents a request to get a prompt from a mcp server -pub struct PromptGet { - /// Unique identifier for the prompt - pub name: String, - /// Optional description providing context about the prompt's purpose - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// Optional list of arguments that define the structure of information to be collected - #[serde(skip_serializing_if = "Option::is_none")] - pub arguments: Option>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// `result` field in [JsonRpcResponse] from a `prompts/get` request -pub struct PromptGetResult { - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - pub messages: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Completed prompt from `prompts/get` to be returned by a mcp server -pub struct Prompt { - pub role: Role, - pub content: MessageContent, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -/// Result of listing tools operation -pub struct ToolsListResult { - /// List of tools - pub tools: Vec, - /// Optional cursor for pagination - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ToolCallResult { - pub content: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub is_error: Option, -} - -/// Content of a message -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "camelCase")] -pub enum MessageContent { - /// Text content - Text { - /// The text content - text: String, - }, - /// Image content - #[serde(rename_all = "camelCase")] - Image { - /// base64-encoded-data - data: String, - mime_type: String, - }, - /// Resource content - Resource { - /// The resource - resource: Resource, - }, -} - -impl From for String { - fn from(val: MessageContent) -> Self { - match val { - MessageContent::Text { text } => text, - MessageContent::Image { data, mime_type } => serde_json::json!({ - "data": data, - "mime_type": mime_type - }) - .to_string(), - MessageContent::Resource { resource } => serde_json::json!(resource).to_string(), - } - } -} - -impl std::fmt::Display for MessageContent { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - MessageContent::Text { text } => write!(f, "{}", text), - MessageContent::Image { data: _, mime_type } => write!(f, "Image [base64-encoded-string] ({})", mime_type), - MessageContent::Resource { resource } => write!(f, "Resource: {} ({})", resource.title, resource.uri), - } - } -} - -/// Resource contents -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "camelCase")] -pub enum ResourceContents { - Text { text: String }, - Blob { data: Vec }, -} - -/// A resource in the system -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Resource { - /// Unique identifier for the resource - pub uri: String, - /// Human-readable title - pub title: String, - /// Optional description - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// Resource contents - pub contents: ResourceContents, -} - -/// Represents the capabilities supported by a Model Context Protocol server -/// This is the "capabilities" field in the result of a response for init -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ServerCapabilities { - /// Configuration for server logging capabilities - #[serde(skip_serializing_if = "Option::is_none")] - pub logging: Option, - /// Configuration for prompt-related capabilities - #[serde(skip_serializing_if = "Option::is_none")] - pub prompts: Option, - /// Configuration for resource management capabilities - #[serde(skip_serializing_if = "Option::is_none")] - pub resources: Option, - /// Configuration for tool integration capabilities - #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option, -} diff --git a/crates/mcp_client/src/lib.rs b/crates/mcp_client/src/lib.rs deleted file mode 100644 index 19f23b809a..0000000000 --- a/crates/mcp_client/src/lib.rs +++ /dev/null @@ -1,11 +0,0 @@ -pub mod client; -pub mod error; -pub mod facilitator_types; -pub mod messenger; -pub mod server; -pub mod transport; - -pub use client::*; -pub use facilitator_types::*; -pub use messenger::*; -pub use transport::*; diff --git a/crates/mcp_client/src/messenger.rs b/crates/mcp_client/src/messenger.rs deleted file mode 100644 index 14e519a6b2..0000000000 --- a/crates/mcp_client/src/messenger.rs +++ /dev/null @@ -1,73 +0,0 @@ -use thiserror::Error; - -use crate::{ - PromptsListResult, - ResourceTemplatesListResult, - ResourcesListResult, - ToolsListResult, -}; - -/// An interface that abstracts the implementation for information delivery from client and its -/// consumer. It is through this interface secondary information (i.e. information that are needed -/// to make requests to mcp servers) are obtained passively. Consumers of client can of course -/// choose to "actively" retrieve these information via explicitly making these requests. -#[async_trait::async_trait] -pub trait Messenger: std::fmt::Debug + Send + Sync + 'static { - /// Sends the result of a tools list operation to the consumer - /// This function is used to deliver information about available tools - async fn send_tools_list_result(&self, result: ToolsListResult) -> Result<(), MessengerError>; - - /// Sends the result of a prompts list operation to the consumer - /// This function is used to deliver information about available prompts - async fn send_prompts_list_result(&self, result: PromptsListResult) -> Result<(), MessengerError>; - - /// Sends the result of a resources list operation to the consumer - /// This function is used to deliver information about available resources - async fn send_resources_list_result(&self, result: ResourcesListResult) -> Result<(), MessengerError>; - - /// Sends the result of a resource templates list operation to the consumer - /// This function is used to deliver information about available resource templates - async fn send_resource_templates_list_result( - &self, - result: ResourceTemplatesListResult, - ) -> Result<(), MessengerError>; - - /// Creates a duplicate of the messenger object - /// This function is used to create a new instance of the messenger with the same configuration - fn duplicate(&self) -> Box; -} - -#[derive(Clone, Debug, Error)] -pub enum MessengerError { - #[error("{0}")] - Custom(String), -} - -#[derive(Clone, Debug)] -pub struct NullMessenger; - -#[async_trait::async_trait] -impl Messenger for NullMessenger { - async fn send_tools_list_result(&self, _result: ToolsListResult) -> Result<(), MessengerError> { - Ok(()) - } - - async fn send_prompts_list_result(&self, _result: PromptsListResult) -> Result<(), MessengerError> { - Ok(()) - } - - async fn send_resources_list_result(&self, _result: ResourcesListResult) -> Result<(), MessengerError> { - Ok(()) - } - - async fn send_resource_templates_list_result( - &self, - _result: ResourceTemplatesListResult, - ) -> Result<(), MessengerError> { - Ok(()) - } - - fn duplicate(&self) -> Box { - Box::new(NullMessenger) - } -} diff --git a/crates/mcp_client/src/server.rs b/crates/mcp_client/src/server.rs deleted file mode 100644 index 1ba92b154d..0000000000 --- a/crates/mcp_client/src/server.rs +++ /dev/null @@ -1,293 +0,0 @@ -use std::collections::HashMap; -use std::sync::atomic::{ - AtomicBool, - AtomicU64, - Ordering, -}; -use std::sync::{ - Arc, - Mutex, -}; - -use tokio::io::{ - Stdin, - Stdout, -}; -use tokio::task::JoinHandle; - -use crate::Listener as _; -use crate::client::StdioTransport; -use crate::error::ErrorCode; -use crate::transport::base_protocol::{ - JsonRpcError, - JsonRpcMessage, - JsonRpcNotification, - JsonRpcRequest, - JsonRpcResponse, -}; -use crate::transport::stdio::JsonRpcStdioTransport; -use crate::transport::{ - JsonRpcVersion, - Transport, - TransportError, -}; - -pub type Request = serde_json::Value; -pub type Response = Option; -pub type InitializedServer = JoinHandle>; - -pub trait PreServerRequestHandler { - fn register_pending_request_callback(&mut self, cb: impl Fn(u64) -> Option + Send + Sync + 'static); - fn register_send_request_callback( - &mut self, - cb: impl Fn(&str, Option) -> Result<(), ServerError> + Send + Sync + 'static, - ); -} - -#[async_trait::async_trait] -pub trait ServerRequestHandler: PreServerRequestHandler + Send + Sync + 'static { - async fn handle_initialize(&self, params: Option) -> Result; - async fn handle_incoming(&self, method: &str, params: Option) -> Result; - async fn handle_response(&self, resp: JsonRpcResponse) -> Result<(), ServerError>; - async fn handle_shutdown(&self) -> Result<(), ServerError>; -} - -pub struct Server { - transport: Option>, - handler: Option, - #[allow(dead_code)] - pending_requests: Arc>>, - #[allow(dead_code)] - current_id: Arc, -} - -#[derive(Debug, thiserror::Error)] -pub enum ServerError { - #[error(transparent)] - TransportError(#[from] TransportError), - #[error(transparent)] - Io(#[from] std::io::Error), - #[error(transparent)] - Serialization(#[from] serde_json::Error), - #[error("Unexpected msg type encountered")] - UnexpectedMsgType, - #[error("{0}")] - NegotiationError(String), - #[error(transparent)] - TokioJoinError(#[from] tokio::task::JoinError), - #[error("Failed to obtain mutex lock")] - MutexError, - #[error("Failed to obtain request method")] - MissingMethod, - #[error("Failed to obtain request id")] - MissingId, - #[error("Failed to initialize server. Missing transport")] - MissingTransport, - #[error("Failed to initialize server. Missing handler")] - MissingHandler, -} - -impl Server -where - H: ServerRequestHandler, -{ - pub fn new(mut handler: H, stdin: Stdin, stdout: Stdout) -> Result { - let transport = Arc::new(JsonRpcStdioTransport::server(stdin, stdout)?); - let pending_requests = Arc::new(Mutex::new(HashMap::::new())); - let pending_requests_clone_one = pending_requests.clone(); - let current_id = Arc::new(AtomicU64::new(0)); - let pending_request_getter = move |id: u64| -> Option { - match pending_requests_clone_one.lock() { - Ok(mut p) => p.remove(&id), - Err(_) => None, - } - }; - handler.register_pending_request_callback(pending_request_getter); - let transport_clone = transport.clone(); - let pending_request_clone_two = pending_requests.clone(); - let current_id_clone = current_id.clone(); - let request_sender = move |method: &str, params: Option| -> Result<(), ServerError> { - let id = current_id_clone.fetch_add(1, Ordering::SeqCst); - let request = JsonRpcRequest { - jsonrpc: JsonRpcVersion::default(), - id, - method: method.to_owned(), - params, - }; - let msg = JsonRpcMessage::Request(request.clone()); - let transport = transport_clone.clone(); - tokio::task::spawn(async move { - let _ = transport.send(&msg).await; - }); - #[allow(clippy::map_err_ignore)] - let mut pending_request = pending_request_clone_two.lock().map_err(|_| ServerError::MutexError)?; - pending_request.insert(id, request); - Ok(()) - }; - handler.register_send_request_callback(request_sender); - let server = Self { - transport: Some(transport), - handler: Some(handler), - pending_requests, - current_id, - }; - Ok(server) - } -} - -impl Server -where - T: Transport, - H: ServerRequestHandler, -{ - pub fn init(mut self) -> Result { - let transport = self.transport.take().ok_or(ServerError::MissingTransport)?; - let handler = Arc::new(self.handler.take().ok_or(ServerError::MissingHandler)?); - let has_initialized = Arc::new(AtomicBool::new(false)); - let listener = tokio::spawn(async move { - let mut listener = transport.get_listener(); - loop { - let request = listener.recv().await; - let transport_clone = transport.clone(); - let has_init_clone = has_initialized.clone(); - let handler_clone = handler.clone(); - tokio::task::spawn(async move { - process_request(has_init_clone, transport_clone, handler_clone, request).await; - }); - } - }); - Ok(listener) - } -} - -async fn process_request( - has_initialized: Arc, - transport: Arc, - handler: Arc, - request: Result, -) where - T: Transport, - H: ServerRequestHandler, -{ - match request { - Ok(msg) if msg.is_initialize() => { - let id = msg.id().unwrap_or_default(); - if has_initialized.load(Ordering::SeqCst) { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::InvalidRequest.into(), - message: "Server has already been initialized".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - return; - } - let JsonRpcMessage::Request(req) = msg else { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::InvalidRequest.into(), - message: "Invalid method for initialization (use request)".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - return; - }; - let JsonRpcRequest { params, .. } = req; - match handler.handle_initialize(params).await { - Ok(result) => { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - id, - result, - ..Default::default() - }); - let _ = transport.send(&resp).await; - has_initialized.store(true, Ordering::SeqCst); - }, - Err(_e) => { - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::InternalError.into(), - message: "Error producing initialization response".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - }, - } - }, - Ok(msg) if msg.is_shutdown() => { - // TODO: add shutdown routine - }, - Ok(msg) if has_initialized.load(Ordering::SeqCst) => match msg { - JsonRpcMessage::Request(req) => { - let JsonRpcRequest { - id, - jsonrpc, - params, - ref method, - } = req; - let resp = handler.handle_incoming(method, params).await.map_or_else( - |error| { - let err = JsonRpcError { - code: ErrorCode::InternalError.into(), - message: error.to_string(), - data: None, - }; - let resp = JsonRpcResponse { - jsonrpc: jsonrpc.clone(), - id, - result: None, - error: Some(err), - }; - JsonRpcMessage::Response(resp) - }, - |result| { - let resp = JsonRpcResponse { - jsonrpc: jsonrpc.clone(), - id, - result, - error: None, - }; - JsonRpcMessage::Response(resp) - }, - ); - let _ = transport.send(&resp).await; - }, - JsonRpcMessage::Notification(notif) => { - let JsonRpcNotification { ref method, params, .. } = notif; - let _ = handler.handle_incoming(method, params).await; - }, - JsonRpcMessage::Response(resp) => { - let _ = handler.handle_response(resp).await; - }, - }, - Ok(msg) => { - let id = msg.id().unwrap_or_default(); - let resp = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: JsonRpcVersion::default(), - id, - error: Some(JsonRpcError { - code: ErrorCode::ServerNotInitialized.into(), - message: "Server has not been initialized".to_owned(), - data: None, - }), - ..Default::default() - }); - let _ = transport.send(&resp).await; - }, - Err(_e) => { - // TODO: error handling - }, - } -} diff --git a/crates/mcp_client/src/transport/base_protocol.rs b/crates/mcp_client/src/transport/base_protocol.rs deleted file mode 100644 index b0394e6e0c..0000000000 --- a/crates/mcp_client/src/transport/base_protocol.rs +++ /dev/null @@ -1,108 +0,0 @@ -//! Referencing https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/messages/ -//! Protocol Revision 2024-11-05 -use serde::{ - Deserialize, - Serialize, -}; - -pub type RequestId = u64; - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct JsonRpcVersion(String); - -impl Default for JsonRpcVersion { - fn default() -> Self { - JsonRpcVersion("2.0".to_owned()) - } -} - -impl JsonRpcVersion { - pub fn as_u32_vec(&self) -> Vec { - self.0 - .split(".") - .map(|n| n.parse::().unwrap()) - .collect::>() - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -#[serde(untagged)] -#[serde(deny_unknown_fields)] -// DO NOT change the order of these variants. This body of json is [untagged](https://serde.rs/enum-representations.html#untagged) -// The categorization of the deserialization depends on the order in which the variants are -// declared. -pub enum JsonRpcMessage { - Response(JsonRpcResponse), - Notification(JsonRpcNotification), - Request(JsonRpcRequest), -} - -impl JsonRpcMessage { - pub fn is_initialize(&self) -> bool { - match self { - JsonRpcMessage::Request(req) => req.method == "initialize", - _ => false, - } - } - - pub fn is_shutdown(&self) -> bool { - match self { - JsonRpcMessage::Notification(notif) => notif.method == "notification/shutdown", - _ => false, - } - } - - pub fn id(&self) -> Option { - match self { - JsonRpcMessage::Request(req) => Some(req.id), - JsonRpcMessage::Response(resp) => Some(resp.id), - JsonRpcMessage::Notification(_) => None, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcRequest { - pub jsonrpc: JsonRpcVersion, - pub id: RequestId, - pub method: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcResponse { - pub jsonrpc: JsonRpcVersion, - pub id: RequestId, - #[serde(skip_serializing_if = "Option::is_none")] - pub result: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcNotification { - pub jsonrpc: JsonRpcVersion, - pub method: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -#[serde(default, deny_unknown_fields)] -pub struct JsonRpcError { - pub code: i32, - pub message: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub data: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] -pub enum TransportType { - #[default] - Stdio, - Websocket, -} diff --git a/crates/mcp_client/src/transport/mod.rs b/crates/mcp_client/src/transport/mod.rs deleted file mode 100644 index 5796ba5323..0000000000 --- a/crates/mcp_client/src/transport/mod.rs +++ /dev/null @@ -1,56 +0,0 @@ -pub mod base_protocol; -pub mod stdio; - -use std::fmt::Debug; - -pub use base_protocol::*; -pub use stdio::*; -use thiserror::Error; - -#[derive(Clone, Debug, Error)] -pub enum TransportError { - #[error("Serialization error: {0}")] - Serialization(String), - #[error("IO error: {0}")] - Stdio(String), - #[error("{0}")] - Custom(String), - #[error(transparent)] - RecvError(#[from] tokio::sync::broadcast::error::RecvError), -} - -impl From for TransportError { - fn from(err: serde_json::Error) -> Self { - TransportError::Serialization(err.to_string()) - } -} - -impl From for TransportError { - fn from(err: std::io::Error) -> Self { - TransportError::Stdio(err.to_string()) - } -} - -#[async_trait::async_trait] -pub trait Transport: Send + Sync + Debug + 'static { - /// Sends a message over the transport layer. - async fn send(&self, msg: &JsonRpcMessage) -> Result<(), TransportError>; - /// Listens to awaits for a response. This is a call that should be used after `send` is called - /// to listen for a response from the message recipient. - fn get_listener(&self) -> impl Listener; - /// Gracefully terminates the transport connection, cleaning up any resources. - /// This should be called when the transport is no longer needed to ensure proper cleanup. - async fn shutdown(&self) -> Result<(), TransportError>; - /// Listener that listens for logging messages. - fn get_log_listener(&self) -> impl LogListener; -} - -#[async_trait::async_trait] -pub trait Listener: Send + Sync + 'static { - async fn recv(&mut self) -> Result; -} - -#[async_trait::async_trait] -pub trait LogListener: Send + Sync + 'static { - async fn recv(&mut self) -> Result; -} diff --git a/crates/mcp_client/src/transport/stdio.rs b/crates/mcp_client/src/transport/stdio.rs deleted file mode 100644 index ab4c6a2a07..0000000000 --- a/crates/mcp_client/src/transport/stdio.rs +++ /dev/null @@ -1,277 +0,0 @@ -use std::sync::Arc; - -use tokio::io::{ - AsyncBufReadExt, - AsyncRead, - AsyncWriteExt as _, - BufReader, - Stdin, - Stdout, -}; -use tokio::process::{ - Child, - ChildStdin, -}; -use tokio::sync::{ - Mutex, - broadcast, -}; - -use super::base_protocol::JsonRpcMessage; -use super::{ - Listener, - LogListener, - Transport, - TransportError, -}; - -#[derive(Debug)] -pub enum JsonRpcStdioTransport { - Client { - stdin: Arc>, - receiver: broadcast::Receiver>, - log_receiver: broadcast::Receiver, - }, - Server { - stdout: Arc>, - receiver: broadcast::Receiver>, - }, -} - -impl JsonRpcStdioTransport { - fn spawn_reader( - reader: R, - tx: broadcast::Sender>, - ) { - tokio::spawn(async move { - let mut buffer = Vec::::new(); - let mut buf_reader = BufReader::new(reader); - loop { - buffer.clear(); - // Messages are delimited by newlines and assumed to contain no embedded newlines - // See https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio - match buf_reader.read_until(b'\n', &mut buffer).await { - Ok(0) => continue, - Ok(_) => match serde_json::from_slice::(buffer.as_slice()) { - Ok(msg) => { - let _ = tx.send(Ok(msg)); - }, - Err(e) => { - let _ = tx.send(Err(e.into())); - }, - }, - Err(e) => { - let _ = tx.send(Err(e.into())); - }, - } - } - }); - } - - pub fn client(child_process: Child) -> Result { - let (tx, receiver) = broadcast::channel::>(100); - let Some(stdout) = child_process.stdout else { - return Err(TransportError::Custom("No stdout found on child process".to_owned())); - }; - let Some(stdin) = child_process.stdin else { - return Err(TransportError::Custom("No stdin found on child process".to_owned())); - }; - let Some(stderr) = child_process.stderr else { - return Err(TransportError::Custom("No stderr found on child process".to_owned())); - }; - let (log_tx, log_receiver) = broadcast::channel::(100); - tokio::task::spawn(async move { - let stderr = tokio::io::BufReader::new(stderr); - let mut lines = stderr.lines(); - while let Ok(Some(line)) = lines.next_line().await { - let _ = log_tx.send(line); - } - }); - let stdin = Arc::new(Mutex::new(stdin)); - Self::spawn_reader(stdout, tx); - Ok(JsonRpcStdioTransport::Client { - stdin, - receiver, - log_receiver, - }) - } - - pub fn server(stdin: Stdin, stdout: Stdout) -> Result { - let (tx, receiver) = broadcast::channel::>(100); - Self::spawn_reader(stdin, tx); - let stdout = Arc::new(Mutex::new(stdout)); - Ok(JsonRpcStdioTransport::Server { stdout, receiver }) - } -} - -#[async_trait::async_trait] -impl Transport for JsonRpcStdioTransport { - async fn send(&self, msg: &JsonRpcMessage) -> Result<(), TransportError> { - match self { - JsonRpcStdioTransport::Client { stdin, .. } => { - let mut serialized = serde_json::to_vec(msg)?; - serialized.push(b'\n'); - let mut stdin = stdin.lock().await; - stdin - .write_all(&serialized) - .await - .map_err(|e| TransportError::Custom(format!("Error writing to server: {:?}", e)))?; - stdin - .flush() - .await - .map_err(|e| TransportError::Custom(format!("Error writing to server: {:?}", e)))?; - Ok(()) - }, - JsonRpcStdioTransport::Server { stdout, .. } => { - let mut serialized = serde_json::to_vec(msg)?; - serialized.push(b'\n'); - let mut stdout = stdout.lock().await; - stdout - .write_all(&serialized) - .await - .map_err(|e| TransportError::Custom(format!("Error writing to client: {:?}", e)))?; - stdout - .flush() - .await - .map_err(|e| TransportError::Custom(format!("Error writing to client: {:?}", e)))?; - Ok(()) - }, - } - } - - fn get_listener(&self) -> impl Listener { - match self { - JsonRpcStdioTransport::Client { receiver, .. } | JsonRpcStdioTransport::Server { receiver, .. } => { - StdioListener { - receiver: receiver.resubscribe(), - } - }, - } - } - - async fn shutdown(&self) -> Result<(), TransportError> { - match self { - JsonRpcStdioTransport::Client { stdin, .. } => { - let mut stdin = stdin.lock().await; - Ok(stdin.shutdown().await?) - }, - JsonRpcStdioTransport::Server { stdout, .. } => { - let mut stdout = stdout.lock().await; - Ok(stdout.shutdown().await?) - }, - } - } - - fn get_log_listener(&self) -> impl LogListener { - match self { - JsonRpcStdioTransport::Client { log_receiver, .. } => StdioLogListener { - receiver: log_receiver.resubscribe(), - }, - JsonRpcStdioTransport::Server { .. } => unreachable!("server does not need a log listener"), - } - } -} - -pub struct StdioListener { - pub receiver: broadcast::Receiver>, -} - -#[async_trait::async_trait] -impl Listener for StdioListener { - async fn recv(&mut self) -> Result { - self.receiver.recv().await? - } -} - -pub struct StdioLogListener { - pub receiver: broadcast::Receiver, -} - -#[async_trait::async_trait] -impl LogListener for StdioLogListener { - async fn recv(&mut self) -> Result { - Ok(self.receiver.recv().await?) - } -} - -#[cfg(test)] -mod tests { - use std::process::Stdio; - - use serde_json::{ - Value, - json, - }; - use tokio::process::Command; - - use crate::{ - JsonRpcMessage, - JsonRpcStdioTransport, - Listener, - Transport, - }; - - // Helpers for testing - fn create_test_message() -> JsonRpcMessage { - serde_json::from_value(json!({ - "jsonrpc": "2.0", - "id": 1, - "method": "test_method", - "params": { - "test_param": "test_value" - } - })) - .unwrap() - } - - #[tokio::test] - async fn test_client_transport() { - let mut cmd = Command::new("cat"); - cmd.stdin(Stdio::piped()).stdout(Stdio::piped()).stderr(Stdio::piped()); - - // Inject our mock transport instead - let child = cmd.spawn().expect("Failed to spawn command"); - let transport = JsonRpcStdioTransport::client(child).expect("Failed to create client transport"); - - let message = create_test_message(); - let result = transport.send(&message).await; - assert!(result.is_ok(), "Failed to send message: {:?}", result); - - let echo = transport - .get_listener() - .recv() - .await - .expect("Failed to receive message"); - let echo_value = serde_json::to_value(&echo).expect("Failed to convert echo to value"); - let message_value = serde_json::to_value(&message).expect("Failed to convert message to value"); - assert!(are_json_values_equal(&echo_value, &message_value)); - } - - fn are_json_values_equal(a: &Value, b: &Value) -> bool { - match (a, b) { - (Value::Null, Value::Null) => true, - (Value::Bool(a_val), Value::Bool(b_val)) => a_val == b_val, - (Value::Number(a_val), Value::Number(b_val)) => a_val == b_val, - (Value::String(a_val), Value::String(b_val)) => a_val == b_val, - (Value::Array(a_arr), Value::Array(b_arr)) => { - if a_arr.len() != b_arr.len() { - return false; - } - a_arr - .iter() - .zip(b_arr.iter()) - .all(|(a_item, b_item)| are_json_values_equal(a_item, b_item)) - }, - (Value::Object(a_obj), Value::Object(b_obj)) => { - if a_obj.len() != b_obj.len() { - return false; - } - a_obj.iter().all(|(key, a_value)| match b_obj.get(key) { - Some(b_value) => are_json_values_equal(a_value, b_value), - None => false, - }) - }, - _ => false, - } - } -} diff --git a/crates/mcp_client/src/transport/websocket.rs b/crates/mcp_client/src/transport/websocket.rs deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/crates/mcp_client/test_mcp_server/test_server.rs b/crates/mcp_client/test_mcp_server/test_server.rs deleted file mode 100644 index 486048bad2..0000000000 --- a/crates/mcp_client/test_mcp_server/test_server.rs +++ /dev/null @@ -1,354 +0,0 @@ -//! This is a bin used solely for testing the client -use std::collections::HashMap; -use std::str::FromStr; -use std::sync::atomic::{ - AtomicU8, - Ordering, -}; - -use mcp_client::server::{ - self, - PreServerRequestHandler, - Response, - ServerError, - ServerRequestHandler, -}; -use mcp_client::transport::{ - JsonRpcRequest, - JsonRpcResponse, - JsonRpcStdioTransport, -}; -use tokio::sync::Mutex; - -#[derive(Default)] -struct Handler { - pending_request: Option Option + Send + Sync>>, - #[allow(clippy::type_complexity)] - send_request: Option) -> Result<(), ServerError> + Send + Sync>>, - storage: Mutex>, - tool_spec: Mutex>, - tool_spec_key_list: Mutex>, - prompts: Mutex>, - prompt_key_list: Mutex>, - prompt_list_call_no: AtomicU8, -} - -impl PreServerRequestHandler for Handler { - fn register_pending_request_callback( - &mut self, - cb: impl Fn(u64) -> Option + Send + Sync + 'static, - ) { - self.pending_request = Some(Box::new(cb)); - } - - fn register_send_request_callback( - &mut self, - cb: impl Fn(&str, Option) -> Result<(), ServerError> + Send + Sync + 'static, - ) { - self.send_request = Some(Box::new(cb)); - } -} - -#[async_trait::async_trait] -impl ServerRequestHandler for Handler { - async fn handle_initialize(&self, params: Option) -> Result { - let mut storage = self.storage.lock().await; - if let Some(params) = params { - storage.insert("client_cap".to_owned(), params); - } - let capabilities = serde_json::json!({ - "protocolVersion": "2024-11-05", - "capabilities": { - "logging": {}, - "prompts": { - "listChanged": true - }, - "resources": { - "subscribe": true, - "listChanged": true - }, - "tools": { - "listChanged": true - } - }, - "serverInfo": { - "name": "TestServer", - "version": "1.0.0" - } - }); - Ok(Some(capabilities)) - } - - async fn handle_incoming(&self, method: &str, params: Option) -> Result { - match method { - "notifications/initialized" => { - { - let mut storage = self.storage.lock().await; - storage.insert( - "init_ack_sent".to_owned(), - serde_json::Value::from_str("true").expect("Failed to convert string to value"), - ); - } - Ok(None) - }, - "verify_init_params_sent" => { - let client_capabilities = { - let storage = self.storage.lock().await; - storage.get("client_cap").cloned() - }; - Ok(client_capabilities) - }, - "verify_init_ack_sent" => { - let result = { - let storage = self.storage.lock().await; - storage.get("init_ack_sent").cloned() - }; - Ok(result) - }, - "store_mock_tool_spec" => { - let Some(params) = params else { - eprintln!("Params missing from store mock tool spec"); - return Ok(None); - }; - // expecting a mock_specs: { key: String, value: serde_json::Value }[]; - let Ok(mock_specs) = serde_json::from_value::>(params) else { - eprintln!("Failed to convert to mock specs from value"); - return Ok(None); - }; - let self_tool_specs = self.tool_spec.lock().await; - let mut self_tool_spec_key_list = self.tool_spec_key_list.lock().await; - let _ = mock_specs.iter().fold(self_tool_specs, |mut acc, spec| { - let Some(key) = spec.get("key").cloned() else { - return acc; - }; - let Ok(key) = serde_json::from_value::(key) else { - eprintln!("Failed to convert serde value to string for key"); - return acc; - }; - self_tool_spec_key_list.push(key.clone()); - acc.insert(key, spec.get("value").cloned()); - acc - }); - Ok(None) - }, - "tools/list" => { - if let Some(params) = params { - if let Some(cursor) = params.get("cursor").cloned() { - let Ok(cursor) = serde_json::from_value::(cursor) else { - eprintln!("Failed to convert cursor to string: {:#?}", params); - return Ok(None); - }; - let self_tool_spec_key_list = self.tool_spec_key_list.lock().await; - let self_tool_spec = self.tool_spec.lock().await; - let (next_cursor, spec) = { - 'blk: { - for (i, item) in self_tool_spec_key_list.iter().enumerate() { - if item == &cursor { - break 'blk ( - self_tool_spec_key_list.get(i + 1).cloned(), - self_tool_spec.get(&cursor).cloned().unwrap(), - ); - } - } - (None, None) - } - }; - if let Some(next_cursor) = next_cursor { - return Ok(Some(serde_json::json!({ - "tools": [spec.unwrap()], - "nextCursor": next_cursor, - }))); - } else { - return Ok(Some(serde_json::json!({ - "tools": [spec.unwrap()], - }))); - } - } else { - eprintln!("Params exist but cursor is missing"); - return Ok(None); - } - } else { - let first_key = self - .tool_spec_key_list - .lock() - .await - .first() - .expect("First key missing from tool specs") - .clone(); - let first_value = self - .tool_spec - .lock() - .await - .get(&first_key) - .expect("First value missing from tool specs") - .clone(); - let second_key = self - .tool_spec_key_list - .lock() - .await - .get(1) - .expect("Second key missing from tool specs") - .clone(); - return Ok(Some(serde_json::json!({ - "tools": [first_value], - "nextCursor": second_key - }))); - }; - }, - "get_env_vars" => { - let kv = std::env::vars().fold(HashMap::::new(), |mut acc, (k, v)| { - acc.insert(k, v); - acc - }); - Ok(Some(serde_json::json!(kv))) - }, - // This is a test path relevant only to sampling - "trigger_server_request" => { - let Some(ref send_request) = self.send_request else { - return Err(ServerError::MissingMethod); - }; - let params = Some(serde_json::json!({ - "messages": [ - { - "role": "user", - "content": { - "type": "text", - "text": "What is the capital of France?" - } - } - ], - "modelPreferences": { - "hints": [ - { - "name": "claude-3-sonnet" - } - ], - "intelligencePriority": 0.8, - "speedPriority": 0.5 - }, - "systemPrompt": "You are a helpful assistant.", - "maxTokens": 100 - })); - send_request("sampling/createMessage", params)?; - Ok(None) - }, - "store_mock_prompts" => { - let Some(params) = params else { - eprintln!("Params missing from store mock prompts"); - return Ok(None); - }; - // expecting a mock_prompts: { key: String, value: serde_json::Value }[]; - let Ok(mock_prompts) = serde_json::from_value::>(params) else { - eprintln!("Failed to convert to mock specs from value"); - return Ok(None); - }; - let self_prompts = self.prompts.lock().await; - let mut self_prompt_key_list = self.prompt_key_list.lock().await; - let _ = mock_prompts.iter().fold(self_prompts, |mut acc, spec| { - let Some(key) = spec.get("key").cloned() else { - return acc; - }; - let Ok(key) = serde_json::from_value::(key) else { - eprintln!("Failed to convert serde value to string for key"); - return acc; - }; - self_prompt_key_list.push(key.clone()); - acc.insert(key, spec.get("value").cloned()); - acc - }); - Ok(None) - }, - "prompts/list" => { - self.prompt_list_call_no.fetch_add(1, Ordering::Relaxed); - if let Some(params) = params { - if let Some(cursor) = params.get("cursor").cloned() { - let Ok(cursor) = serde_json::from_value::(cursor) else { - eprintln!("Failed to convert cursor to string: {:#?}", params); - return Ok(None); - }; - let self_prompt_key_list = self.prompt_key_list.lock().await; - let self_prompts = self.prompts.lock().await; - let (next_cursor, spec) = { - 'blk: { - for (i, item) in self_prompt_key_list.iter().enumerate() { - if item == &cursor { - break 'blk ( - self_prompt_key_list.get(i + 1).cloned(), - self_prompts.get(&cursor).cloned().unwrap(), - ); - } - } - (None, None) - } - }; - if let Some(next_cursor) = next_cursor { - return Ok(Some(serde_json::json!({ - "prompts": [spec.unwrap()], - "nextCursor": next_cursor, - }))); - } else { - return Ok(Some(serde_json::json!({ - "prompts": [spec.unwrap()], - }))); - } - } else { - eprintln!("Params exist but cursor is missing"); - return Ok(None); - } - } else { - let first_key = self - .prompt_key_list - .lock() - .await - .first() - .expect("First key missing from prompts") - .clone(); - let first_value = self - .prompts - .lock() - .await - .get(&first_key) - .expect("First value missing from prompts") - .clone(); - let second_key = self - .prompt_key_list - .lock() - .await - .get(1) - .expect("Second key missing from prompts") - .clone(); - return Ok(Some(serde_json::json!({ - "prompts": [first_value], - "nextCursor": second_key - }))); - }; - }, - "get_prompt_list_call_no" => Ok(Some( - serde_json::to_value::(self.prompt_list_call_no.load(Ordering::Relaxed)) - .expect("Failed to convert list call no to u8"), - )), - _ => Err(ServerError::MissingMethod), - } - } - - // This is a test path relevant only to sampling - async fn handle_response(&self, resp: JsonRpcResponse) -> Result<(), ServerError> { - let JsonRpcResponse { id, .. } = resp; - let _pending = self.pending_request.as_ref().and_then(|f| f(id)); - Ok(()) - } - - async fn handle_shutdown(&self) -> Result<(), ServerError> { - Ok(()) - } -} - -#[tokio::main] -async fn main() { - let handler = Handler::default(); - let stdin = tokio::io::stdin(); - let stdout = tokio::io::stdout(); - let test_server = - server::Server::::new(handler, stdin, stdout).expect("Failed to create server"); - let _ = test_server.init().expect("Test server failed to init").await; -} diff --git a/crates/q_cli/Cargo.toml b/crates/q_cli/Cargo.toml index 4460fa8ce0..9401547269 100644 --- a/crates/q_cli/Cargo.toml +++ b/crates/q_cli/Cargo.toml @@ -56,7 +56,6 @@ glob.workspace = true globset.workspace = true indicatif.workspace = true indoc.workspace = true -mcp_client.workspace = true mimalloc.workspace = true owo-colors = "4.2.0" parking_lot.workspace = true From adf76a5fdd21d4ae30e8bbdd6112b7601a257647 Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Tue, 6 May 2025 17:57:53 -0700 Subject: [PATCH 08/26] makes initial server loading interruptable --- crates/chat-cli/src/cli/chat/mod.rs | 1 + .../{tool_manager => }/server_messenger.rs | 0 .../chat/{tool_manager => }/tool_manager.rs | 280 ++++++++++-------- .../chat-cli/src/cli/chat/tool_manager/mod.rs | 4 - 4 files changed, 150 insertions(+), 135 deletions(-) rename crates/chat-cli/src/cli/chat/{tool_manager => }/server_messenger.rs (100%) rename crates/chat-cli/src/cli/chat/{tool_manager => }/tool_manager.rs (86%) delete mode 100644 crates/chat-cli/src/cli/chat/tool_manager/mod.rs diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 42febea965..10c1e067cb 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -9,6 +9,7 @@ mod message; mod parse; mod parser; mod prompt; +mod server_messenger; #[cfg(unix)] mod skim_integration; mod token_counter; diff --git a/crates/chat-cli/src/cli/chat/tool_manager/server_messenger.rs b/crates/chat-cli/src/cli/chat/server_messenger.rs similarity index 100% rename from crates/chat-cli/src/cli/chat/tool_manager/server_messenger.rs rename to crates/chat-cli/src/cli/chat/server_messenger.rs diff --git a/crates/chat-cli/src/cli/chat/tool_manager/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs similarity index 86% rename from crates/chat-cli/src/cli/chat/tool_manager/tool_manager.rs rename to crates/chat-cli/src/cli/chat/tool_manager.rs index cafb4977c1..e9ef2c2ab3 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -1,10 +1,12 @@ use std::collections::HashMap; +use std::future::Future; use std::hash::{ DefaultHasher, Hasher, }; use std::io::Write; use std::path::PathBuf; +use std::pin::Pin; use std::sync::atomic::AtomicBool; use std::sync::mpsc::RecvTimeoutError; use std::sync::{ @@ -22,6 +24,7 @@ use crossterm::{ }; use futures::{ StreamExt, + future, stream, }; use regex::Regex; @@ -30,6 +33,7 @@ use serde::{ Serialize, }; use thiserror::Error; +use tokio::signal::ctrl_c; use tokio::sync::Mutex; use tracing::{ error, @@ -43,7 +47,7 @@ use crate::api_client::model::{ }; use crate::cli::chat::command::PromptsGetCommand; use crate::cli::chat::message::AssistantToolUse; -use crate::cli::chat::tool_manager::server_messenger::{ +use crate::cli::chat::server_messenger::{ ServerMessengerBuilder, UpdateEventMessage, }; @@ -124,6 +128,7 @@ pub enum LoadingMsg { /// * `init_time` - When initialization for this tool began, used to calculate load time struct StatusLine { init_time: std::time::Instant, + is_done: bool, } // This is to mirror claude's config set up @@ -249,7 +254,8 @@ impl ToolManagerBuilder { Ok(recv_result) => match recv_result { LoadingMsg::Add(name) => { let init_time = std::time::Instant::now(); - let status_line = StatusLine { init_time }; + let is_done = false; + let status_line = StatusLine { init_time, is_done }; execute!(stdout_lock, cursor::MoveToColumn(0))?; if !loading_servers.is_empty() { // TODO: account for terminal width @@ -262,7 +268,8 @@ impl ToolManagerBuilder { stdout_lock.flush()?; }, LoadingMsg::Done(name) => { - if let Some(status_line) = loading_servers.get(&name) { + if let Some(status_line) = loading_servers.get_mut(&name) { + status_line.is_done = true; complete += 1; let time_taken = (std::time::Instant::now() - status_line.init_time).as_secs_f64().abs(); @@ -278,41 +285,68 @@ impl ToolManagerBuilder { queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; stdout_lock.flush()?; } + if loading_servers.iter().all(|(_, status)| status.is_done) { + break; + } }, LoadingMsg::Error { name, msg } => { - failed += 1; - execute!( - stdout_lock, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - queue_failure_message(&name, &msg, &mut stdout_lock)?; - let total = loading_servers.len(); - queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; + if let Some(status_line) = loading_servers.get_mut(&name) { + status_line.is_done = true; + failed += 1; + execute!( + stdout_lock, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + queue_failure_message(&name, &msg, &mut stdout_lock)?; + let total = loading_servers.len(); + queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; + } + if loading_servers.iter().all(|(_, status)| status.is_done) { + break; + } }, LoadingMsg::Warn { name, msg } => { - complete += 1; - execute!( - stdout_lock, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - let msg = eyre::eyre!(msg.to_string()); - queue_warn_message(&name, &msg, &mut stdout_lock)?; - let total = loading_servers.len(); - queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; - stdout_lock.flush()?; + if let Some(status_line) = loading_servers.get_mut(&name) { + status_line.is_done = true; + complete += 1; + execute!( + stdout_lock, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + let msg = eyre::eyre!(msg.to_string()); + queue_warn_message(&name, &msg, &mut stdout_lock)?; + let total = loading_servers.len(); + queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; + stdout_lock.flush()?; + } + if loading_servers.iter().all(|(_, status)| status.is_done) { + break; + } }, LoadingMsg::Terminate => { - if !loading_servers.is_empty() { - let msg = loading_servers.iter().fold(String::new(), |mut acc, (server_name, _)| { - acc.push_str(format!("\n - {server_name}").as_str()); - acc - }); + if loading_servers.iter().any(|(_, status)| !status.is_done) { + execute!( + stdout_lock, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + let msg = + loading_servers + .iter() + .fold(String::new(), |mut acc, (server_name, status)| { + if !status.is_done { + acc.push_str(format!("\n - {server_name}").as_str()); + } + acc + }); let msg = eyre::eyre!(msg); queue_incomplete_load_message(&msg, &mut stdout_lock)?; + stdout_lock.flush()?; } break; }, @@ -334,12 +368,13 @@ impl ToolManagerBuilder { Ok::<_, eyre::Report>(()) }); let mut clients = HashMap::>::new(); - let load_msg_sender = tx.clone(); + let mut load_msg_sender = Some(tx.clone()); let conv_id_clone = conversation_id.clone(); let regex = Arc::new(Regex::new(VALID_TOOL_NAME)?); + let new_tool_specs = Arc::new(Mutex::new(HashMap::new())); + let new_tool_specs_clone = new_tool_specs.clone(); let (mut msg_rx, messenger_builder) = ServerMessengerBuilder::new(20); tokio::spawn(async move { - let mut is_in_display = true; while let Some(msg) = msg_rx.recv().await { // For now we will treat every list result as if they contain the // complete set of tools. This is not necessarily true in the future when @@ -347,7 +382,6 @@ impl ToolManagerBuilder { // list calls. match msg { UpdateEventMessage::ToolsListResult { server_name, result } => { - error!("## background: from {server_name}: {:?}", result); let mut specs = result .tools .into_iter() @@ -357,25 +391,44 @@ impl ToolManagerBuilder { if let Some(load_msg) = process_tool_specs( conv_id_clone.as_str(), &server_name, - is_in_display, + load_msg_sender.is_some(), &mut specs, &mut sanitized_mapping, ®ex, ) { - if let Err(e) = load_msg_sender.send(load_msg) { - warn!( - "Error sending update message to display task: {:?}\nAssume display task has completed", - e - ); - is_in_display = false; + let mut has_errored = false; + if let Some(sender) = &load_msg_sender { + if let Err(e) = sender.send(load_msg) { + warn!( + "Error sending update message to display task: {:?}\nAssume display task has completed", + e + ); + has_errored = true; + } + } + if has_errored { + load_msg_sender.take(); } } + new_tool_specs_clone + .lock() + .await + .insert(server_name, (sanitized_mapping, specs)); }, - UpdateEventMessage::PromptsListResult { server_name, result } => {}, - UpdateEventMessage::ResourcesListResult { server_name, result } => {}, - UpdateEventMessage::ResourceTemplatesListResult { server_name, result } => {}, + UpdateEventMessage::PromptsListResult { + server_name: _, + result: _, + } => {}, + UpdateEventMessage::ResourcesListResult { + server_name: _, + result: _, + } => {}, + UpdateEventMessage::ResourceTemplatesListResult { + server_name: _, + result: _, + } => {}, UpdateEventMessage::DisplayTaskEnded => { - is_in_display = false; + load_msg_sender.take(); }, } } @@ -503,6 +556,7 @@ impl ToolManagerBuilder { prompts, loading_display_task, loading_status_sender, + new_tool_specs, ..Default::default() }) } @@ -579,17 +633,14 @@ impl ToolManager { pub async fn load_tools(&mut self) -> eyre::Result> { let tx = self.loading_status_sender.take(); let display_task = self.loading_display_task.take(); - let tool_specs = { + let mut tool_specs = { let mut tool_specs = - serde_json::from_str::>(include_str!("../tools/tool_index.json"))?; + serde_json::from_str::>(include_str!("tools/tool_index.json"))?; if !crate::cli::chat::tools::thinking::Thinking::is_enabled() { tool_specs.remove("q_think_tool"); } - Arc::new(Mutex::new(tool_specs)) + tool_specs }; - let conversation_id = self.conversation_id.clone(); - let regex = Arc::new(regex::Regex::new(VALID_TOOL_NAME)?); - self.new_tool_specs = Arc::new(Mutex::new(HashMap::new())); let load_tools = self .clients .values() @@ -598,90 +649,57 @@ impl ToolManager { async move { clone.init().await } }) .collect::>(); - let some = stream::iter(load_tools) + let initial_poll = stream::iter(load_tools) .map(|async_closure| tokio::spawn(async_closure)) - .buffer_unordered(20) - .collect::>() - .await; - // let load_tool = self - // .clients - // .iter() - // .map(|(server_name, client)| { - // let client_clone = client.clone(); - // let server_name_clone = server_name.clone(); - // let tx_clone = tx.clone(); - // let regex_clone = regex.clone(); - // let tool_specs_clone = tool_specs.clone(); - // let conversation_id = conversation_id.clone(); - // async move { - // let tool_spec = client_clone.init().await; - // let mut sanitized_mapping = HashMap::::new(); - // match tool_spec { - // Ok((server_name, mut specs)) => { - // let msg = process_tool_specs( - // conversation_id.as_str(), - // &server_name, - // true, - // &mut specs, - // &mut sanitized_mapping, - // ®ex_clone, - // ); - // for spec in specs { - // tool_specs_clone.lock().await.insert(spec.name.clone(), spec); - // } - // if let (Some(msg), Some(tx)) = (msg, &tx_clone) { - // let _ = tx.send(msg); - // } - // }, - // Err(e) => { - // error!("Error obtaining tool spec for {}: {:?}", server_name_clone, e); - // let init_failure_reason = Some(e.to_string()); - // tokio::spawn(async move { - // let event = fig_telemetry::EventType::McpServerInit { - // conversation_id, - // init_failure_reason, - // number_of_tools: 0, - // }; - // let app_event = fig_telemetry::AppTelemetryEvent::new(event).await; - // fig_telemetry::dispatch_or_send_event(app_event).await; - // }); - // if let Some(tx_clone) = &tx_clone { - // if let Err(e) = tx_clone.send(LoadingMsg::Error { - // name: server_name_clone, - // msg: e, - // }) { - // error!("Error while sending status update to display task: {:?}", e); - // } - // } - // }, - // } - // Ok::<_, eyre::Report>(Some(sanitized_mapping)) - // } - // }) - // .collect::>(); - // // TODO: do we want to introduce a timeout here? - // self.tn_map = stream::iter(load_tool) - // .map(|async_closure| tokio::task::spawn(async_closure)) - // .buffer_unordered(20) - // .collect::>() - // .await - // .into_iter() - // .filter_map(|r| r.ok()) - // .filter_map(|r| r.ok()) - // .flatten() - // .flatten() - // .collect::>(); - drop(tx); - if let Some(display_task) = display_task { - if let Err(e) = display_task.await { - error!("Error while joining status display task: {:?}", e); + .buffer_unordered(20); + tokio::spawn(async move { + initial_poll.collect::>().await; + }); + // We need to cast it to erase the type otherwise the compiler will default to static + // dispatch, which would result in an error of inconsistent match arm return type. + let display_future: Pin>> = match display_task { + Some(display_task) => { + let fut = async move { + if let Err(e) = display_task.await { + error!("Error while joining status display task: {:?}", e); + } + }; + Box::pin(fut) + }, + None => { + let fut = async { future::pending::<()>().await }; + Box::pin(fut) + }, + }; + tokio::select! { + _ = display_future => {}, + // TODO: make this timeout configurable + _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => { + if let Some(tx) = tx { + let _ = tx.send(LoadingMsg::Terminate); + } + }, + _ = ctrl_c() => { + if let Some(tx) = tx { + let _ = tx.send(LoadingMsg::Terminate); + } } } - let tool_specs = { - let mutex = - Arc::try_unwrap(tool_specs).map_err(|e| eyre::eyre!("Error unwrapping arc for tool specs {:?}", e))?; - mutex.into_inner() + let new_tools = { + let mut new_tool_specs = self.new_tool_specs.lock().await; + new_tool_specs.drain().fold(HashMap::new(), |mut acc, (k, v)| { + acc.insert(k, v); + acc + }) }; + for (_server_name, (tool_name_map, specs)) in new_tools { + for (k, v) in tool_name_map { + self.tn_map.insert(k, v); + } + for spec in specs { + tool_specs.insert(spec.name.clone(), spec); + } + } // caching the tool names for skim operations for tool_name in tool_specs.keys() { if !self.tn_map.contains_key(tool_name) { @@ -1104,7 +1122,7 @@ fn queue_init_message( style::SetForegroundColor(style::Color::Blue), style::Print(format!("{} ", total)), style::ResetColor, - style::Print("mcp servers initialized\n"), + style::Print("mcp servers initialized. Press ctrl-c to load the remaining servers in the background\n"), )?) } diff --git a/crates/chat-cli/src/cli/chat/tool_manager/mod.rs b/crates/chat-cli/src/cli/chat/tool_manager/mod.rs deleted file mode 100644 index 6251b7fb77..0000000000 --- a/crates/chat-cli/src/cli/chat/tool_manager/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod server_messenger; -pub mod tool_manager; - -pub use tool_manager::*; From d6ca3a1525fa3b0d423277f262fc25c59a8bc2db Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Wed, 7 May 2025 10:29:30 -0700 Subject: [PATCH 09/26] formats --- crates/chat-cli/src/mcp_client/error.rs | 6 +++--- crates/chat-cli/src/mcp_client/mod.rs | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/crates/chat-cli/src/mcp_client/error.rs b/crates/chat-cli/src/mcp_client/error.rs index d05e7efa4d..01f77cfa8b 100644 --- a/crates/chat-cli/src/mcp_client/error.rs +++ b/crates/chat-cli/src/mcp_client/error.rs @@ -29,7 +29,7 @@ pub enum ErrorCode { /// Unknown error code. /// This error is returned when an error code is received that is not /// recognized by the implementation. - UnknownErrorCode = -32001, + Unknown = -32001, /// Request failed. /// This error is returned when a request fails for a reason not covered @@ -46,9 +46,9 @@ impl From for ErrorCode { -32602 => ErrorCode::InvalidParams, -32603 => ErrorCode::InternalError, -32002 => ErrorCode::ServerNotInitialized, - -32001 => ErrorCode::UnknownErrorCode, + -32001 => ErrorCode::Unknown, -32000 => ErrorCode::RequestFailed, - _ => ErrorCode::UnknownErrorCode, + _ => ErrorCode::Unknown, } } } diff --git a/crates/chat-cli/src/mcp_client/mod.rs b/crates/chat-cli/src/mcp_client/mod.rs index 465dcf6cec..19f23b809a 100644 --- a/crates/chat-cli/src/mcp_client/mod.rs +++ b/crates/chat-cli/src/mcp_client/mod.rs @@ -8,5 +8,4 @@ pub mod transport; pub use client::*; pub use facilitator_types::*; pub use messenger::*; -pub use server::*; pub use transport::*; From 53ee062b1073eff91ad95f33ed7083158b5b50a1 Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Thu, 8 May 2025 10:53:41 -0700 Subject: [PATCH 10/26] adds atomic bool to signal when new things are added --- crates/chat-cli/src/cli/chat/tool_manager.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index e9ef2c2ab3..96be8c7745 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -7,7 +7,10 @@ use std::hash::{ use std::io::Write; use std::path::PathBuf; use std::pin::Pin; -use std::sync::atomic::AtomicBool; +use std::sync::atomic::{ + AtomicBool, + Ordering, +}; use std::sync::mpsc::RecvTimeoutError; use std::sync::{ Arc, @@ -373,6 +376,8 @@ impl ToolManagerBuilder { let regex = Arc::new(Regex::new(VALID_TOOL_NAME)?); let new_tool_specs = Arc::new(Mutex::new(HashMap::new())); let new_tool_specs_clone = new_tool_specs.clone(); + let has_new_stuff = Arc::new(AtomicBool::new(false)); + let has_new_stuff_clone = has_new_stuff.clone(); let (mut msg_rx, messenger_builder) = ServerMessengerBuilder::new(20); tokio::spawn(async move { while let Some(msg) = msg_rx.recv().await { @@ -414,6 +419,10 @@ impl ToolManagerBuilder { .lock() .await .insert(server_name, (sanitized_mapping, specs)); + // We only want to set this flag when the display task has ended + if load_msg_sender.is_none() { + has_new_stuff_clone.store(true, Ordering::Relaxed); + } }, UpdateEventMessage::PromptsListResult { server_name: _, @@ -557,6 +566,7 @@ impl ToolManagerBuilder { loading_display_task, loading_status_sender, new_tool_specs, + has_new_stuff, ..Default::default() }) } From f9b0891e1aea2c2c5d6e847fd9b59acb9a69e796 Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Thu, 8 May 2025 12:07:09 -0700 Subject: [PATCH 11/26] moves tool manager to conversation state --- .../src/cli/chat/conversation_state.rs | 15 ++++++++-- crates/chat-cli/src/cli/chat/mod.rs | 30 +++++++++++-------- .../chat-cli/src/cli/chat/server_messenger.rs | 1 + crates/chat-cli/src/cli/chat/tool_manager.rs | 17 ++++++++++- .../src/cli/chat/tools/custom_tool.rs | 23 +++----------- crates/chat-cli/src/mcp_client/client.rs | 5 +--- .../src/mcp_client/facilitator_types.rs | 1 + crates/chat-cli/src/mcp_client/messenger.rs | 1 + crates/chat-cli/src/mcp_client/server.rs | 1 + .../chat-cli/src/mcp_client/transport/mod.rs | 1 + 10 files changed, 57 insertions(+), 38 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/conversation_state.rs b/crates/chat-cli/src/cli/chat/conversation_state.rs index 1f22c4013f..876c6baece 100644 --- a/crates/chat-cli/src/cli/chat/conversation_state.rs +++ b/crates/chat-cli/src/cli/chat/conversation_state.rs @@ -37,6 +37,7 @@ use super::token_counter::{ CharCount, CharCounter, }; +use super::tool_manager::ToolManager; use super::tools::{ InputSchema, QueuedTool, @@ -85,6 +86,8 @@ pub struct ConversationState { pub tools: HashMap>, /// Context manager for handling sticky context files pub context_manager: Option, + /// Tool manager for handling tool and mcp related activities + pub tool_manager: ToolManager, /// Cached value representing the length of the user context message. context_message_length: Option, /// Stores the latest conversation summary created by /compact @@ -99,6 +102,7 @@ impl ConversationState { tool_config: HashMap, profile: Option, updates: Option, + tool_manager: ToolManager, ) -> Self { // Initialize context manager let context_manager = match ContextManager::new(ctx, None).await { @@ -137,6 +141,7 @@ impl ConversationState { acc }), context_manager, + tool_manager, context_message_length: None, latest_summary: None, updates, @@ -926,6 +931,7 @@ mod tests { tool_manager.load_tools().await.unwrap(), None, None, + tool_manager, ) .await; @@ -944,12 +950,14 @@ mod tests { async fn test_conversation_state_history_handling_with_tool_results() { // Build a long conversation history of tool use results. let mut tool_manager = ToolManager::default(); + let tool_config = tool_manager.load_tools().await.unwrap(); let mut conversation_state = ConversationState::new( Context::new(), "fake_conv_id", - tool_manager.load_tools().await.unwrap(), + tool_config.clone(), None, None, + tool_manager.clone(), ) .await; conversation_state.set_next_user_message("start".to_string()).await; @@ -975,9 +983,10 @@ mod tests { let mut conversation_state = ConversationState::new( Context::new(), "fake_conv_id", - tool_manager.load_tools().await.unwrap(), + tool_config.clone(), None, None, + tool_manager.clone(), ) .await; conversation_state.set_next_user_message("start".to_string()).await; @@ -1016,6 +1025,7 @@ mod tests { tool_manager.load_tools().await.unwrap(), None, None, + tool_manager, ) .await; @@ -1081,6 +1091,7 @@ mod tests { tool_manager.load_tools().await.unwrap(), None, Some(SharedWriter::stdout()), + tool_manager, ) .await; diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index e15b5ba97a..d989968c07 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -505,8 +505,6 @@ pub struct ChatContext { tool_use_telemetry_events: HashMap, /// State used to keep track of tool use relation tool_use_status: ToolUseStatus, - /// Abstraction that consolidates custom tools with native ones - tool_manager: ToolManager, /// Any failed requests that could be useful for error report/debugging failed_request_ids: Vec, /// Pending prompts to be sent @@ -533,8 +531,15 @@ impl ChatContext { ) -> Result { let ctx_clone = Arc::clone(&ctx); let output_clone = output.clone(); - let conversation_state = - ConversationState::new(ctx_clone, conversation_id, tool_config, profile, Some(output_clone)).await; + let conversation_state = ConversationState::new( + ctx_clone, + conversation_id, + tool_config, + profile, + Some(output_clone), + tool_manager, + ) + .await; Ok(Self { ctx, settings, @@ -550,7 +555,6 @@ impl ChatContext { conversation_state, tool_use_telemetry_events: HashMap::new(), tool_use_status: ToolUseStatus::Idle, - tool_manager, failed_request_ids: Vec::new(), pending_prompts: VecDeque::new(), }) @@ -1217,6 +1221,7 @@ impl ChatContext { #[cfg(unix)] if let Some(ref context_manager) = self.conversation_state.context_manager { let tool_names = self + .conversation_state .tool_manager .tn_map .keys() @@ -2152,9 +2157,10 @@ impl ChatContext { match subcommand { Some(ToolsSubcommand::Schema) => { - let schema_json = serde_json::to_string_pretty(&self.tool_manager.schema).map_err(|e| { - ChatError::Custom(format!("Error converting tool schema to string: {e}").into()) - })?; + let schema_json = serde_json::to_string_pretty(&self.conversation_state.tool_manager.schema) + .map_err(|e| { + ChatError::Custom(format!("Error converting tool schema to string: {e}").into()) + })?; queue!(self.output, style::Print(schema_json), style::Print("\n"))?; }, Some(ToolsSubcommand::Trust { tool_names }) => { @@ -2368,7 +2374,7 @@ impl ChatContext { }, Some(PromptsSubcommand::Get { mut get_command }) => { let orig_input = get_command.orig_input.take(); - let prompts = match self.tool_manager.get_prompt(get_command).await { + let prompts = match self.conversation_state.tool_manager.get_prompt(get_command).await { Ok(resp) => resp, Err(e) => { match e { @@ -2455,12 +2461,12 @@ impl ChatContext { _ => None, }; let terminal_width = self.terminal_width(); - let mut prompts_wl = self.tool_manager.prompts.write().map_err(|e| { + let mut prompts_wl = self.conversation_state.tool_manager.prompts.write().map_err(|e| { ChatError::Custom( format!("Poison error encountered while retrieving prompts: {}", e).into(), ) })?; - self.tool_manager.refresh_prompts(&mut prompts_wl)?; + self.conversation_state.tool_manager.refresh_prompts(&mut prompts_wl)?; let mut longest_name = ""; let arg_pos = { let optimal_case = UnicodeWidthStr::width(longest_name) + terminal_width / 4; @@ -3126,7 +3132,7 @@ impl ChatContext { .set_tool_use_id(tool_use_id.clone()) .set_tool_name(tool_use.name.clone()) .utterance_id(self.conversation_state.message_id().map(|s| s.to_string())); - match self.tool_manager.get_tool_from_tool_use(tool_use) { + match self.conversation_state.tool_manager.get_tool_from_tool_use(tool_use) { Ok(mut tool) => { // Apply non-Q-generated context to tools self.contextualize_tool(&mut tool); diff --git a/crates/chat-cli/src/cli/chat/server_messenger.rs b/crates/chat-cli/src/cli/chat/server_messenger.rs index cdca50d8f7..e56cbd0715 100644 --- a/crates/chat-cli/src/cli/chat/server_messenger.rs +++ b/crates/chat-cli/src/cli/chat/server_messenger.rs @@ -13,6 +13,7 @@ use crate::mcp_client::{ ToolsListResult, }; +#[allow(dead_code)] #[derive(Clone, Debug)] pub enum UpdateEventMessage { ToolsListResult { diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index d73eaaf484..a46eefccff 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -596,7 +596,7 @@ enum OutOfSpecName { type NewToolSpecs = Arc, Vec)>>>; -#[derive(Default)] +#[derive(Default, Debug)] /// Manages the lifecycle and interactions with tools from various sources, including MCP servers. /// This struct is responsible for initializing tools, handling tool requests, and maintaining /// a cache of available prompts from connected servers. @@ -639,6 +639,21 @@ pub struct ToolManager { pub schema: HashMap, } +impl Clone for ToolManager { + fn clone(&self) -> Self { + Self { + conversation_id: self.conversation_id.clone(), + clients: self.clients.clone(), + has_new_stuff: self.has_new_stuff.clone(), + new_tool_specs: self.new_tool_specs.clone(), + prompts: self.prompts.clone(), + tn_map: self.tn_map.clone(), + schema: self.schema.clone(), + ..Default::default() + } + } +} + impl ToolManager { pub async fn load_tools(&mut self) -> eyre::Result> { let tx = self.loading_status_sender.take(); diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index a6fbca2586..24f6d0f364 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -15,10 +15,7 @@ use serde::{ use tokio::sync::RwLock; use tracing::warn; -use super::{ - InvokeOutput, - ToolSpec, -}; +use super::InvokeOutput; use crate::cli::chat::CONTINUATION_LINE; use crate::cli::chat::token_counter::TokenCounter; use crate::mcp_client::{ @@ -88,32 +85,20 @@ impl CustomToolClient { }) } - pub async fn init(&self) -> Result<(String, Vec)> { + pub async fn init(&self) -> Result<()> { match self { CustomToolClient::Stdio { client, - server_name, server_capabilities, + .. } => { // We'll need to first initialize. This is the handshake every client and server // needs to do before proceeding to anything else let cap = client.init().await?; // We'll be scrapping this for background server load: https://github.com/aws/amazon-q-developer-cli/issues/1466 // So don't worry about the tidiness for now - let is_tool_supported = cap.tools.is_some(); server_capabilities.write().await.replace(cap); - // Assuming a shape of return as per https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#listing-tools - // let tools = if is_tool_supported { - // // And now we make the server tell us what tools they have - // let resp = client.request("tools/list", None).await?; - // match resp.result.and_then(|r| r.get("tools").cloned()) { - // Some(value) => serde_json::from_value::>(value)?, - // None => Default::default(), - // } - // } else { - // Default::default() - // }; - Ok((server_name.clone(), vec![])) + Ok(()) }, } } diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs index 15ce904b91..978edc8d5d 100644 --- a/crates/chat-cli/src/mcp_client/client.rs +++ b/crates/chat-cli/src/mcp_client/client.rs @@ -82,6 +82,7 @@ pub struct ClientConfig { pub env: Option>, } +#[allow(dead_code)] #[derive(Debug, Error)] pub enum ClientError { #[error(transparent)] @@ -548,10 +549,6 @@ where ) } - pub async fn shutdown(&self) -> Result<(), ClientError> { - Ok(self.transport.shutdown().await?) - } - fn get_id(&self) -> u64 { self.current_id.fetch_add(1, Ordering::SeqCst) } diff --git a/crates/chat-cli/src/mcp_client/facilitator_types.rs b/crates/chat-cli/src/mcp_client/facilitator_types.rs index 908f555bd2..87fbd79b27 100644 --- a/crates/chat-cli/src/mcp_client/facilitator_types.rs +++ b/crates/chat-cli/src/mcp_client/facilitator_types.rs @@ -5,6 +5,7 @@ use serde::{ use thiserror::Error; /// https://spec.modelcontextprotocol.io/specification/2024-11-05/server/utilities/pagination/#operations-supporting-pagination +#[allow(clippy::enum_variant_names)] #[derive(Debug, Clone, PartialEq, Eq)] pub enum PaginationSupportedOps { ResourcesList, diff --git a/crates/chat-cli/src/mcp_client/messenger.rs b/crates/chat-cli/src/mcp_client/messenger.rs index caa6cf20e2..1d4f361445 100644 --- a/crates/chat-cli/src/mcp_client/messenger.rs +++ b/crates/chat-cli/src/mcp_client/messenger.rs @@ -11,6 +11,7 @@ use super::{ /// consumer. It is through this interface secondary information (i.e. information that are needed /// to make requests to mcp servers) are obtained passively. Consumers of client can of course /// choose to "actively" retrieve these information via explicitly making these requests. +#[allow(dead_code)] #[async_trait::async_trait] pub trait Messenger: std::fmt::Debug + Send + Sync + 'static { /// Sends the result of a tools list operation to the consumer diff --git a/crates/chat-cli/src/mcp_client/server.rs b/crates/chat-cli/src/mcp_client/server.rs index 0b251f1ccf..73dfae2dd3 100644 --- a/crates/chat-cli/src/mcp_client/server.rs +++ b/crates/chat-cli/src/mcp_client/server.rs @@ -1,3 +1,4 @@ +#![allow(dead_code)] use std::collections::HashMap; use std::sync::atomic::{ AtomicBool, diff --git a/crates/chat-cli/src/mcp_client/transport/mod.rs b/crates/chat-cli/src/mcp_client/transport/mod.rs index 5796ba5323..f752b1675a 100644 --- a/crates/chat-cli/src/mcp_client/transport/mod.rs +++ b/crates/chat-cli/src/mcp_client/transport/mod.rs @@ -31,6 +31,7 @@ impl From for TransportError { } } +#[allow(dead_code)] #[async_trait::async_trait] pub trait Transport: Send + Sync + Debug + 'static { /// Sends a message over the transport layer. From 6dbca6c03814d6bcf9b8dea1fc4940448351d8e6 Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Thu, 8 May 2025 18:09:27 -0700 Subject: [PATCH 12/26] makes main chat loop update state if applicable --- .../src/cli/chat/conversation_state.rs | 26 +++++++ crates/chat-cli/src/cli/chat/mod.rs | 11 ++- crates/chat-cli/src/cli/chat/tool_manager.rs | 76 +++++++++++++------ crates/chat-cli/src/cli/chat/tools/mod.rs | 15 +++- 4 files changed, 93 insertions(+), 35 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/conversation_state.rs b/crates/chat-cli/src/cli/chat/conversation_state.rs index 876c6baece..20b0c64f1b 100644 --- a/crates/chat-cli/src/cli/chat/conversation_state.rs +++ b/crates/chat-cli/src/cli/chat/conversation_state.rs @@ -3,6 +3,7 @@ use std::collections::{ VecDeque, }; use std::sync::Arc; +use std::sync::atomic::Ordering; use crossterm::style::Color; use crossterm::{ @@ -354,6 +355,7 @@ impl ConversationState { /// - `run_hooks` - whether hooks should be executed and included as context pub async fn as_sendable_conversation_state(&mut self, run_hooks: bool) -> FigConversationState { debug_assert!(self.next_message.is_some()); + self.update_state().await; self.enforce_conversation_invariants(); self.history.drain(self.valid_history_range.1..); self.history.drain(..self.valid_history_range.0); @@ -379,6 +381,30 @@ impl ConversationState { .expect("unable to construct conversation state") } + pub async fn update_state(&mut self) { + let needs_update = self.tool_manager.has_new_stuff.load(Ordering::Acquire); + if !needs_update { + return; + } + self.tool_manager.update().await; + self.tools = self + .tool_manager + .schema + .values() + .fold(HashMap::>::new(), |mut acc, v| { + let tool = Tool::ToolSpecification(ToolSpecification { + name: v.name.clone(), + description: v.description.clone(), + input_schema: v.input_schema.clone().into(), + }); + acc.entry(v.tool_origin.clone()) + .and_modify(|tools| tools.push(tool.clone())) + .or_insert(vec![tool]); + acc + }); + self.tool_manager.has_new_stuff.store(false, Ordering::Release); + } + /// Returns a conversation state representation which reflects the exact conversation to send /// back to the model. pub async fn backend_conversation_state(&mut self, run_hooks: bool, quiet: bool) -> BackendConversationState<'_> { diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index d989968c07..8423c7aed8 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -401,6 +401,7 @@ pub async fn chat( let tool_config = tool_manager.load_tools().await?; let mut tool_permissions = ToolPermissions::new(tool_config.len()); if accept_all || trust_all_tools { + tool_permissions.trust_all = true; for tool in tool_config.values() { tool_permissions.trust_tool(&tool.name); } @@ -2259,7 +2260,7 @@ impl ChatContext { )?; }, Some(ToolsSubcommand::ResetSingle { tool_name }) => { - if self.tool_permissions.has(&tool_name) { + if self.tool_permissions.has(&tool_name) || self.tool_permissions.trust_all { self.tool_permissions.reset_tool(&tool_name); queue!( self.output, @@ -2735,11 +2736,9 @@ impl ChatContext { } // If there is an override, we will use it. Otherwise fall back to Tool's default. - let allowed = if self.tool_permissions.has(&tool.name) { - self.tool_permissions.is_trusted(&tool.name) - } else { - !tool.tool.requires_acceptance(&self.ctx) - }; + let allowed = self.tool_permissions.trust_all + || (self.tool_permissions.has(&tool.name) && self.tool_permissions.is_trusted(&tool.name)) + || !tool.tool.requires_acceptance(&self.ctx); if self.settings.get_bool_or("chat.enableNotifications", false) { play_notification_bell(!allowed); diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index a46eefccff..00dd2176ac 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -1,4 +1,7 @@ -use std::collections::HashMap; +use std::collections::{ + HashMap, + HashSet, +}; use std::future::Future; use std::hash::{ DefaultHasher, @@ -103,7 +106,7 @@ pub enum GetPromptError { /// Messages used for communication between the tool initialization thread and the loading /// display thread. These messages control the visual loading indicators shown to /// the user during tool initialization. -pub enum LoadingMsg { +enum LoadingMsg { /// Indicates a new tool is being initialized and should be added to the loading /// display. The String parameter is the name of the tool being initialized. Add(String), @@ -421,7 +424,7 @@ impl ToolManagerBuilder { .insert(server_name, (sanitized_mapping, specs)); // We only want to set this flag when the display task has ended if load_msg_sender.is_none() { - has_new_stuff_clone.store(true, Ordering::Relaxed); + has_new_stuff_clone.store(true, Ordering::Release); } }, UpdateEventMessage::PromptsListResult { @@ -710,28 +713,8 @@ impl ToolManager { } } } - let new_tools = { - let mut new_tool_specs = self.new_tool_specs.lock().await; - new_tool_specs.drain().fold(HashMap::new(), |mut acc, (k, v)| { - acc.insert(k, v); - acc - }) - }; - for (_server_name, (tool_name_map, specs)) in new_tools { - for (k, v) in tool_name_map { - self.tn_map.insert(k, v); - } - for spec in specs { - tool_specs.insert(spec.name.clone(), spec); - } - } - // caching the tool names for skim operations - for tool_name in tool_specs.keys() { - if !self.tn_map.contains_key(tool_name) { - self.tn_map.insert(tool_name.clone(), tool_name.clone()); - } - } - self.schema = tool_specs.clone(); + self.update().await; + tool_specs.extend(self.schema.clone()); Ok(tool_specs) } @@ -831,6 +814,49 @@ impl ToolManager { }) } + /// Updates tool managers various states with new information + pub async fn update(&mut self) { + // A hashmap of + let mut tool_specs = HashMap::::new(); + let new_tools = { + let mut new_tool_specs = self.new_tool_specs.lock().await; + new_tool_specs.drain().fold(HashMap::new(), |mut acc, (k, v)| { + acc.insert(k, v); + acc + }) + }; + let mut updated_servers = HashSet::::new(); + for (_server_name, (tool_name_map, specs)) in new_tools { + // In a populated tn map (i.e. a partially initialized or outdated fleet of servers) there + // will be incoming tools with names that are already in the tn map, we will be writing + // over them (perhaps with the same information that they already had), and that's okay. + // In an event where a server has removed tools, the tools that are no longer available + // will linger in this map. This is also okay to not clean up as it does not affect the + // look up of tool names that are still active. + for (k, v) in tool_name_map { + self.tn_map.insert(k, v); + } + if let Some(spec) = specs.first() { + updated_servers.insert(spec.tool_origin.clone()); + } + for spec in specs { + tool_specs.insert(spec.name.clone(), spec); + } + } + // Caching the tool names for skim operations + for tool_name in tool_specs.keys() { + if !self.tn_map.contains_key(tool_name) { + self.tn_map.insert(tool_name.clone(), tool_name.clone()); + } + } + // Update schema + // As we are writing over the ensemble of tools in a given server, we will need to first + // remove everything that it has. + self.schema + .retain(|_tool_name, spec| !updated_servers.contains(&spec.tool_origin)); + self.schema.extend(tool_specs); + } + #[allow(clippy::await_holding_lock)] pub async fn get_prompt(&self, get_command: PromptsGetCommand) -> Result { let (server_name, prompt_name) = match get_command.params.name.split_once('/') { diff --git a/crates/chat-cli/src/cli/chat/tools/mod.rs b/crates/chat-cli/src/cli/chat/tools/mod.rs index e558e10bea..0e0dafc101 100644 --- a/crates/chat-cli/src/cli/chat/tools/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/mod.rs @@ -126,30 +126,33 @@ pub struct ToolPermission { /// Tools that do not have an associated ToolPermission should use /// their default logic to determine to permission. pub struct ToolPermissions { + // We need this field for any stragglers + pub trust_all: bool, pub permissions: HashMap, } impl ToolPermissions { pub fn new(capacity: usize) -> Self { Self { + trust_all: false, permissions: HashMap::with_capacity(capacity), } } pub fn is_trusted(&self, tool_name: &str) -> bool { - self.permissions.get(tool_name).is_some_and(|perm| perm.trusted) + self.trust_all || self.permissions.get(tool_name).is_some_and(|perm| perm.trusted) } /// Returns a label to describe the permission status for a given tool. pub fn display_label(&self, tool_name: &str) -> String { - if self.has(tool_name) { + if self.has(tool_name) || self.trust_all { if self.is_trusted(tool_name) { format!(" {}", "trusted".dark_green().bold()) } else { format!(" {}", "not trusted".dark_grey()) } } else { - Self::default_permission_label(tool_name) + self.default_permission_label(tool_name) } } @@ -159,15 +162,18 @@ impl ToolPermissions { } pub fn untrust_tool(&mut self, tool_name: &str) { + self.trust_all = false; self.permissions .insert(tool_name.to_string(), ToolPermission { trusted: false }); } pub fn reset(&mut self) { + self.trust_all = false; self.permissions.clear(); } pub fn reset_tool(&mut self, tool_name: &str) { + self.trust_all = false; self.permissions.remove(tool_name); } @@ -178,7 +184,7 @@ impl ToolPermissions { /// Provide default permission labels for the built-in set of tools. /// Unknown tools are assumed to be "Per-request" // This "static" way avoids needing to construct a tool instance. - fn default_permission_label(tool_name: &str) -> String { + fn default_permission_label(&self, tool_name: &str) -> String { let label = match tool_name { "fs_read" => "trusted".dark_green().bold(), "fs_write" => "not trusted".dark_grey(), @@ -186,6 +192,7 @@ impl ToolPermissions { "use_aws" => "trust read-only commands".dark_grey(), "report_issue" => "trusted".dark_green().bold(), "thinking" => "trusted (prerelease)".dark_green().bold(), + _ if self.trust_all => "trusted".dark_grey().bold(), _ => "not trusted".dark_grey(), }; From 2391800118c2b37dc241b1ae5bebda3cf814315f Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Thu, 8 May 2025 19:01:41 -0700 Subject: [PATCH 13/26] enables list changed for prompts and tools --- crates/chat-cli/src/lib.rs | 1 + crates/chat-cli/src/mcp_client/client.rs | 329 +++++++++++------- crates/chat-cli/src/mcp_client/server.rs | 43 ++- .../chat-cli/test_mcp_server/test_server.rs | 56 ++- 4 files changed, 252 insertions(+), 177 deletions(-) diff --git a/crates/chat-cli/src/lib.rs b/crates/chat-cli/src/lib.rs index 2b584b4c47..d8bfa3209c 100644 --- a/crates/chat-cli/src/lib.rs +++ b/crates/chat-cli/src/lib.rs @@ -1,3 +1,4 @@ +#![cfg(not(test))] //! This lib.rs is only here for testing purposes. //! `test_mcp_server/test_server.rs` is declared as a separate binary and would need a way to //! reference types defined inside of this crate, hence the export. diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs index 978edc8d5d..9012f32e1a 100644 --- a/crates/chat-cli/src/mcp_client/client.rs +++ b/crates/chat-cli/src/mcp_client/client.rs @@ -126,6 +126,7 @@ pub struct Client { client_info: serde_json::Value, current_id: Arc, pub messenger: Option>, + // TODO: move this to tool manager that way all the assets are treated equally pub prompt_gets: Arc>>, pub is_prompts_out_of_date: Arc, } @@ -223,60 +224,6 @@ where let transport_ref = self.transport.clone(); let server_name = self.server_name.clone(); - tokio::spawn(async move { - let mut listener = transport_ref.get_listener(); - loop { - match listener.recv().await { - Ok(msg) => { - match msg { - JsonRpcMessage::Request(_req) => {}, - JsonRpcMessage::Notification(notif) => { - let JsonRpcNotification { method, params, .. } = notif; - if method.as_str() == "notifications/message" || method.as_str() == "message" { - let level = params - .as_ref() - .and_then(|p| p.get("level")) - .and_then(|v| serde_json::to_string(v).ok()); - let data = params - .as_ref() - .and_then(|p| p.get("data")) - .and_then(|v| serde_json::to_string(v).ok()); - if let (Some(level), Some(data)) = (level, data) { - match level.to_lowercase().as_str() { - "error" => { - tracing::error!(target: "mcp", "{}: {}", server_name, data); - }, - "warn" => { - tracing::warn!(target: "mcp", "{}: {}", server_name, data); - }, - "info" => { - tracing::info!(target: "mcp", "{}: {}", server_name, data); - }, - "debug" => { - tracing::debug!(target: "mcp", "{}: {}", server_name, data); - }, - "trace" => { - tracing::trace!(target: "mcp", "{}: {}", server_name, data); - }, - _ => {}, - } - } - } - }, - JsonRpcMessage::Response(_resp) => { /* noop since direct response is handled inside the request api */ - }, - } - }, - Err(e) => { - tracing::error!("Background listening thread for client {}: {:?}", server_name, e); - }, - } - } - }); - - let transport_ref = self.transport.clone(); - let server_name = self.server_name.clone(); - // Spawning a task to listen and log stderr output tokio::spawn(async move { let mut log_listener = transport_ref.get_log_listener(); @@ -329,87 +276,95 @@ where if cap.prompts.is_some() { self.is_prompts_out_of_date.store(true, Ordering::Relaxed); let client_ref = (*self).clone(); + let messenger_ref = self.messenger.as_ref().map(|m| m.duplicate()); tokio::spawn(async move { - let Ok(resp) = client_ref.request("prompts/list", None).await else { - tracing::error!("Prompt list query failed for {0}", client_ref.server_name); - return; - }; - let Some(result) = resp.result else { - tracing::warn!("Prompt list query returned no result for {0}", client_ref.server_name); - return; - }; - let Some(prompts) = result.get("prompts") else { - tracing::warn!( - "Prompt list query result contained no field named prompts for {0}", - client_ref.server_name - ); - return; - }; - let Ok(prompts) = serde_json::from_value::>(prompts.clone()) else { - tracing::error!( - "Prompt list query deserialization failed for {0}", - client_ref.server_name - ); - return; - }; - let Ok(mut lock) = client_ref.prompt_gets.write() else { - tracing::error!( - "Failed to obtain write lock for prompt list query for {0}", - client_ref.server_name - ); - return; - }; - for prompt in prompts { - let name = prompt.name.clone(); - lock.insert(name, prompt); - } + fetch_prompts_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()).await; }); } - if let (Some(_), Some(messenger)) = (&cap.tools, &self.messenger) { - tracing::error!( - "## background: {} is spawning background task to fetch tools", - self.server_name - ); + if cap.tools.is_some() { let client_ref = (*self).clone(); - let msger = messenger.duplicate(); + let messenger_ref = self.messenger.as_ref().map(|m| m.duplicate()); tokio::spawn(async move { - // TODO: decouple pagination logic from request and have page fetching logic here - // instead - let resp = match client_ref.request("tools/list", None).await { - Ok(resp) => resp, - Err(e) => { - tracing::error!("Failed to retrieve tool list from {}: {:?}", client_ref.server_name, e); - return; + fetch_tools_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()).await; + }); + } + + let transport_ref = self.transport.clone(); + let server_name = self.server_name.clone(); + let messenger_ref = self.messenger.as_ref().map(|m| m.duplicate()); + let client_ref = (*self).clone(); + + let prompts_list_changed_supported = cap.prompts.as_ref().is_some_and(|p| p.get("listChanged").is_some()); + let tools_list_changed_supported = cap.tools.as_ref().is_some_and(|t| t.get("listChanged").is_some()); + tokio::spawn(async move { + let mut listener = transport_ref.get_listener(); + loop { + match listener.recv().await { + Ok(msg) => { + match msg { + JsonRpcMessage::Request(_req) => {}, + JsonRpcMessage::Notification(notif) => { + let JsonRpcNotification { method, params, .. } = notif; + match method.as_str() { + "notifications/message" | "message" => { + let level = params + .as_ref() + .and_then(|p| p.get("level")) + .and_then(|v| serde_json::to_string(v).ok()); + let data = params + .as_ref() + .and_then(|p| p.get("data")) + .and_then(|v| serde_json::to_string(v).ok()); + if let (Some(level), Some(data)) = (level, data) { + match level.to_lowercase().as_str() { + "error" => { + tracing::error!(target: "mcp", "{}: {}", server_name, data); + }, + "warn" => { + tracing::warn!(target: "mcp", "{}: {}", server_name, data); + }, + "info" => { + tracing::info!(target: "mcp", "{}: {}", server_name, data); + }, + "debug" => { + tracing::debug!(target: "mcp", "{}: {}", server_name, data); + }, + "trace" => { + tracing::trace!(target: "mcp", "{}: {}", server_name, data); + }, + _ => {}, + } + } + }, + "notifications/prompts/list_changed" | "prompts/list_changed" + if prompts_list_changed_supported => + { + // TODO: after we have moved the prompts to the tool + // manager we follow the same workflow as the list changed + // for tools + fetch_prompts_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()) + .await; + client_ref.is_prompts_out_of_date.store(true, Ordering::Release); + }, + "notifications/tools/list_changed" | "tools/list_changed" + if tools_list_changed_supported => + { + fetch_tools_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()) + .await; + }, + _ => {}, + } + }, + JsonRpcMessage::Response(_resp) => { /* noop since direct response is handled inside the request api */ + }, + } }, - }; - if let Some(error) = resp.error { - let msg = format!( - "Failed to retrieve tool list for {}: {:?}", - client_ref.server_name, error - ); - tracing::error!("{}", &msg); - return; - } - let Some(result) = resp.result else { - tracing::error!("Tool list response from {} is missing result", client_ref.server_name); - return; - }; - let tool_list_result = match serde_json::from_value::(result) { - Ok(result) => result, Err(e) => { - tracing::error!( - "Failed to deserialize tool result from {}: {:?}", - client_ref.server_name, - e - ); - return; + tracing::error!("Background listening thread for client {}: {:?}", server_name, e); }, - }; - if let Err(e) = msger.send_tools_list_result(tool_list_result).await { - tracing::error!("Failed to send tool result through messenger {:?}", e); } - }); - } + } + }); Ok(cap) } @@ -569,6 +524,85 @@ fn examine_server_capabilities(ser_cap: &JsonRpcResponse) -> Result<(), ClientEr Ok(()) } +// TODO: after we move prompts to tool manager, use the messenger to notify the listener spawned by +// tool manager to update its own field. Currently this function does not make use of the +// messesnger. +#[allow(clippy::borrowed_box)] +async fn fetch_prompts_and_notify_with_messenger(client: &Client, _messenger: Option<&Box>) +where + T: Transport, +{ + let Ok(resp) = client.request("prompts/list", None).await else { + tracing::error!("Prompt list query failed for {0}", client.server_name); + return; + }; + let Some(result) = resp.result else { + tracing::warn!("Prompt list query returned no result for {0}", client.server_name); + return; + }; + let Some(prompts) = result.get("prompts") else { + tracing::warn!( + "Prompt list query result contained no field named prompts for {0}", + client.server_name + ); + return; + }; + let Ok(prompts) = serde_json::from_value::>(prompts.clone()) else { + tracing::error!("Prompt list query deserialization failed for {0}", client.server_name); + return; + }; + let Ok(mut lock) = client.prompt_gets.write() else { + tracing::error!( + "Failed to obtain write lock for prompt list query for {0}", + client.server_name + ); + return; + }; + lock.clear(); + for prompt in prompts { + let name = prompt.name.clone(); + lock.insert(name, prompt); + } +} + +#[allow(clippy::borrowed_box)] +async fn fetch_tools_and_notify_with_messenger(client: &Client, messenger: Option<&Box>) +where + T: Transport, +{ + // TODO: decouple pagination logic from request and have page fetching logic here + // instead + let resp = match client.request("tools/list", None).await { + Ok(resp) => resp, + Err(e) => { + tracing::error!("Failed to retrieve tool list from {}: {:?}", client.server_name, e); + return; + }, + }; + if let Some(error) = resp.error { + let msg = format!("Failed to retrieve tool list for {}: {:?}", client.server_name, error); + tracing::error!("{}", &msg); + return; + } + let Some(result) = resp.result else { + tracing::error!("Tool list response from {} is missing result", client.server_name); + return; + }; + let tool_list_result = match serde_json::from_value::(result) { + Ok(result) => result, + Err(e) => { + tracing::error!("Failed to deserialize tool result from {}: {:?}", client.server_name, e); + return; + }, + }; + if let Some(messenger) = messenger { + let _ = messenger + .send_tools_list_result(tool_list_result) + .await + .map_err(|e| tracing::error!("Failed to send tool result through messenger {:?}", e)); + } +} + #[cfg(test)] mod tests { use std::path::PathBuf; @@ -647,11 +681,11 @@ mod tests { let (res_one, res_two) = tokio::join!( time::timeout( - time::Duration::from_secs(5), + time::Duration::from_secs(10), test_client_routine(&mut client_one, serde_json::json!(client_one_cap)) ), time::timeout( - time::Duration::from_secs(5), + time::Duration::from_secs(10), test_client_routine(&mut client_two, serde_json::json!(client_two_cap)) ) ); @@ -661,6 +695,7 @@ mod tests { assert!(res_two.is_ok()); } + #[allow(clippy::await_holding_lock)] async fn test_client_routine( client: &mut Client, cap_sent: serde_json::Value, @@ -736,6 +771,7 @@ mod tests { .await .expect("Mock prompt prep failed"); let prompts_recvd = client.request("prompts/list", None).await.expect("List prompts failed"); + client.is_prompts_out_of_date.store(false, Ordering::Release); assert!(are_json_values_equal( prompts_recvd .result @@ -745,6 +781,41 @@ mod tests { &mock_prompts_for_verify )); + // Test prompts list changed + let fake_prompt_names = ["code_review_four", "code_review_five", "code_review_six"]; + let mock_result_prompts = fake_prompt_names.map(create_fake_prompts); + let mock_prompts_prep_param = mock_result_prompts + .iter() + .zip(fake_prompt_names.iter()) + .map(|(v, n)| { + serde_json::json!({ + "key": (*n).to_string(), + "value": v + }) + }) + .collect::>(); + let mock_prompts_prep_param = + serde_json::to_value(mock_prompts_prep_param).expect("Failed to create mock prompts prep param"); + let _ = client + .request("store_mock_prompts", Some(mock_prompts_prep_param)) + .await + .expect("Mock new prompt request failed"); + // After we send the signal for the server to clear prompts, we should be receiving signal + // to fetch for new prompts, after which we should be getting no prompts. + let is_prompts_out_of_date = client.is_prompts_out_of_date.clone(); + let wait_for_new_prompts = async move { + while !is_prompts_out_of_date.load(Ordering::Acquire) { + tokio::time::sleep(time::Duration::from_millis(100)).await; + } + }; + time::timeout(time::Duration::from_secs(5), wait_for_new_prompts) + .await + .expect("Timed out while waiting for new prompts"); + let new_prompts = client.prompt_gets.read().expect("Failed to read new prompts"); + for k in new_prompts.keys() { + assert!(fake_prompt_names.contains(&k.as_str())); + } + // Test env var inclusion let env_vars = client.request("get_env_vars", None).await.expect("Get env vars failed"); let env_one = env_vars @@ -764,8 +835,6 @@ mod tests { assert_eq!(env_one_as_str, "\"1\"".to_string()); assert_eq!(env_two_as_str, "\"2\"".to_string()); - let shutdown_result = client.shutdown().await; - assert!(shutdown_result.is_ok()); Ok(()) } diff --git a/crates/chat-cli/src/mcp_client/server.rs b/crates/chat-cli/src/mcp_client/server.rs index 73dfae2dd3..7b320a2c6e 100644 --- a/crates/chat-cli/src/mcp_client/server.rs +++ b/crates/chat-cli/src/mcp_client/server.rs @@ -109,20 +109,37 @@ where let current_id_clone = current_id.clone(); let request_sender = move |method: &str, params: Option| -> Result<(), ServerError> { let id = current_id_clone.fetch_add(1, Ordering::SeqCst); - let request = JsonRpcRequest { - jsonrpc: JsonRpcVersion::default(), - id, - method: method.to_owned(), - params, + let msg = match method.split_once("/") { + Some(("request", _)) => { + let request = JsonRpcRequest { + jsonrpc: JsonRpcVersion::default(), + id, + method: method.to_owned(), + params, + }; + let msg = JsonRpcMessage::Request(request.clone()); + #[allow(clippy::map_err_ignore)] + let mut pending_request = pending_request_clone_two.lock().map_err(|_| ServerError::MutexError)?; + pending_request.insert(id, request); + Some(msg) + }, + Some(("notifications", _)) => { + let notif = JsonRpcNotification { + jsonrpc: JsonRpcVersion::default(), + method: method.to_owned(), + params, + }; + let msg = JsonRpcMessage::Notification(notif); + Some(msg) + }, + _ => None, }; - let msg = JsonRpcMessage::Request(request.clone()); - let transport = transport_clone.clone(); - tokio::task::spawn(async move { - let _ = transport.send(&msg).await; - }); - #[allow(clippy::map_err_ignore)] - let mut pending_request = pending_request_clone_two.lock().map_err(|_| ServerError::MutexError)?; - pending_request.insert(id, request); + if let Some(msg) = msg { + let transport = transport_clone.clone(); + tokio::task::spawn(async move { + let _ = transport.send(&msg).await; + }); + } Ok(()) }; handler.register_send_request_callback(request_sender); diff --git a/crates/chat-cli/test_mcp_server/test_server.rs b/crates/chat-cli/test_mcp_server/test_server.rs index 46c2f235ea..970157f96b 100644 --- a/crates/chat-cli/test_mcp_server/test_server.rs +++ b/crates/chat-cli/test_mcp_server/test_server.rs @@ -167,24 +167,17 @@ impl ServerRequestHandler for Handler { return Ok(None); } } else { - let first_key = self - .tool_spec_key_list - .lock() - .await + let tool_spec_key_list = self.tool_spec_key_list.lock().await; + let tool_spec = self.tool_spec.lock().await; + let first_key = tool_spec_key_list .first() .expect("First key missing from tool specs") .clone(); - let first_value = self - .tool_spec - .lock() - .await + let first_value = tool_spec .get(&first_key) .expect("First value missing from tool specs") .clone(); - let second_key = self - .tool_spec_key_list - .lock() - .await + let second_key = tool_spec_key_list .get(1) .expect("Second key missing from tool specs") .clone(); @@ -241,8 +234,11 @@ impl ServerRequestHandler for Handler { eprintln!("Failed to convert to mock specs from value"); return Ok(None); }; - let self_prompts = self.prompts.lock().await; + let mut self_prompts = self.prompts.lock().await; let mut self_prompt_key_list = self.prompt_key_list.lock().await; + let is_first_mock = self_prompts.is_empty(); + self_prompts.clear(); + self_prompt_key_list.clear(); let _ = mock_prompts.iter().fold(self_prompts, |mut acc, spec| { let Some(key) = spec.get("key").cloned() else { return acc; @@ -255,9 +251,16 @@ impl ServerRequestHandler for Handler { acc.insert(key, spec.get("value").cloned()); acc }); + if !is_first_mock { + if let Some(sender) = &self.send_request { + let _ = sender("notifications/prompts/list_changed", None); + } + } Ok(None) }, "prompts/list" => { + // We expect this method to be called after the mock prompts have already been + // stored. self.prompt_list_call_no.fetch_add(1, Ordering::Relaxed); if let Some(params) = params { if let Some(cursor) = params.get("cursor").cloned() { @@ -295,27 +298,12 @@ impl ServerRequestHandler for Handler { return Ok(None); } } else { - let first_key = self - .prompt_key_list - .lock() - .await - .first() - .expect("First key missing from prompts") - .clone(); - let first_value = self - .prompts - .lock() - .await - .get(&first_key) - .expect("First value missing from prompts") - .clone(); - let second_key = self - .prompt_key_list - .lock() - .await - .get(1) - .expect("Second key missing from prompts") - .clone(); + // If there is no parameter, this is the request to retrieve the first page + let prompt_key_list = self.prompt_key_list.lock().await; + let prompts = self.prompts.lock().await; + let first_key = prompt_key_list.first().expect("first key missing"); + let first_value = prompts.get(first_key).cloned().unwrap().unwrap(); + let second_key = prompt_key_list.get(1).expect("second key missing"); return Ok(Some(serde_json::json!({ "prompts": [first_value], "nextCursor": second_key From 9de421adfc8106320587910790c188cfe29e5a99 Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Fri, 9 May 2025 19:24:27 -0700 Subject: [PATCH 14/26] adds copy change to server loading task --- .../src/cli/chat/conversation_state.rs | 1 + crates/chat-cli/src/cli/chat/mod.rs | 12 +- crates/chat-cli/src/cli/chat/tool_manager.rs | 129 +++++++++++++----- .../src/cli/chat/tools/custom_tool.rs | 12 +- 4 files changed, 115 insertions(+), 39 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/conversation_state.rs b/crates/chat-cli/src/cli/chat/conversation_state.rs index 20b0c64f1b..1e06c108a8 100644 --- a/crates/chat-cli/src/cli/chat/conversation_state.rs +++ b/crates/chat-cli/src/cli/chat/conversation_state.rs @@ -322,6 +322,7 @@ impl ConversationState { let tool_name = tool_use.name.as_str(); if !tool_name_list.contains(&tool_name) { tool_use.name = DUMMY_TOOL_NAME.to_string(); + tool_use.args = serde_json::json!({}); } }) .collect::>(); diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 8423c7aed8..db13180574 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -86,7 +86,10 @@ use rand::distr::{ SampleString, }; use tokio::signal::ctrl_c; -use util::shared_writer::SharedWriter; +use util::shared_writer::{ + NullWriter, + SharedWriter, +}; use util::ui::draw_box; use crate::api_client::StreamingClient; @@ -396,7 +399,12 @@ pub async fn chat( .prompt_list_sender(prompt_response_sender) .prompt_list_receiver(prompt_request_receiver) .conversation_id(&conversation_id) - .build() + .interactive(interactive) + .build(if interactive { + Box::new(output.clone()) + } else { + Box::new(NullWriter {}) + }) .await?; let tool_config = tool_manager.load_tools().await?; let mut tool_permissions = ToolPermissions::new(tool_config.len()); diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index 00dd2176ac..f7a5dceeea 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -205,6 +205,7 @@ pub struct ToolManagerBuilder { prompt_list_sender: Option>>, prompt_list_receiver: Option>>, conversation_id: Option, + is_interactive: bool, } impl ToolManagerBuilder { @@ -228,12 +229,19 @@ impl ToolManagerBuilder { self } - pub async fn build(mut self) -> eyre::Result { + #[allow(dead_code)] + pub fn interactive(mut self, is_interactive: bool) -> Self { + self.is_interactive = is_interactive; + self + } + + pub async fn build(mut self, mut output: Box) -> eyre::Result { let McpServerConfig { mcp_servers } = self.mcp_server_config.ok_or(eyre::eyre!("Missing mcp server config"))?; debug_assert!(self.conversation_id.is_some()); let conversation_id = self.conversation_id.ok_or(eyre::eyre!("Missing conversation id"))?; let regex = regex::Regex::new(VALID_TOOL_NAME)?; let mut hasher = DefaultHasher::new(); + let is_interactive = self.is_interactive; let pre_initialized = mcp_servers .into_iter() .map(|(server_name, server_config)| { @@ -246,11 +254,7 @@ impl ToolManagerBuilder { // Send up task to update user on server loading status let (tx, rx) = std::sync::mpsc::channel::(); - // Using a hand rolled thread because it's just easier to do this than do deal with the Send - // requirements that comes with holding onto the stdout lock. let loading_display_task = tokio::task::spawn_blocking(move || { - let stdout = std::io::stdout(); - let mut stdout_lock = stdout.lock(); let mut loading_servers = HashMap::::new(); let mut spinner_logo_idx: usize = 0; let mut complete: usize = 0; @@ -262,16 +266,16 @@ impl ToolManagerBuilder { let init_time = std::time::Instant::now(); let is_done = false; let status_line = StatusLine { init_time, is_done }; - execute!(stdout_lock, cursor::MoveToColumn(0))?; + execute!(output, cursor::MoveToColumn(0))?; if !loading_servers.is_empty() { // TODO: account for terminal width - execute!(stdout_lock, cursor::MoveUp(1))?; + execute!(output, cursor::MoveUp(1))?; } loading_servers.insert(name.clone(), status_line); let total = loading_servers.len(); - execute!(stdout_lock, terminal::Clear(terminal::ClearType::CurrentLine))?; - queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; - stdout_lock.flush()?; + execute!(output, terminal::Clear(terminal::ClearType::CurrentLine))?; + queue_init_message(spinner_logo_idx, complete, failed, total, is_interactive, &mut output)?; + output.flush()?; }, LoadingMsg::Done(name) => { if let Some(status_line) = loading_servers.get_mut(&name) { @@ -281,15 +285,22 @@ impl ToolManagerBuilder { (std::time::Instant::now() - status_line.init_time).as_secs_f64().abs(); let time_taken = format!("{:.2}", time_taken); execute!( - stdout_lock, + output, cursor::MoveToColumn(0), cursor::MoveUp(1), terminal::Clear(terminal::ClearType::CurrentLine), )?; - queue_success_message(&name, &time_taken, &mut stdout_lock)?; + queue_success_message(&name, &time_taken, &mut output)?; let total = loading_servers.len(); - queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; - stdout_lock.flush()?; + queue_init_message( + spinner_logo_idx, + complete, + failed, + total, + is_interactive, + &mut output, + )?; + output.flush()?; } if loading_servers.iter().all(|(_, status)| status.is_done) { break; @@ -300,14 +311,21 @@ impl ToolManagerBuilder { status_line.is_done = true; failed += 1; execute!( - stdout_lock, + output, cursor::MoveToColumn(0), cursor::MoveUp(1), terminal::Clear(terminal::ClearType::CurrentLine), )?; - queue_failure_message(&name, &msg, &mut stdout_lock)?; + queue_failure_message(&name, &msg, &mut output)?; let total = loading_servers.len(); - queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; + queue_init_message( + spinner_logo_idx, + complete, + failed, + total, + is_interactive, + &mut output, + )?; } if loading_servers.iter().all(|(_, status)| status.is_done) { break; @@ -318,16 +336,23 @@ impl ToolManagerBuilder { status_line.is_done = true; complete += 1; execute!( - stdout_lock, + output, cursor::MoveToColumn(0), cursor::MoveUp(1), terminal::Clear(terminal::ClearType::CurrentLine), )?; let msg = eyre::eyre!(msg.to_string()); - queue_warn_message(&name, &msg, &mut stdout_lock)?; + queue_warn_message(&name, &msg, &mut output)?; let total = loading_servers.len(); - queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; - stdout_lock.flush()?; + queue_init_message( + spinner_logo_idx, + complete, + failed, + total, + is_interactive, + &mut output, + )?; + output.flush()?; } if loading_servers.iter().all(|(_, status)| status.is_done) { break; @@ -336,7 +361,7 @@ impl ToolManagerBuilder { LoadingMsg::Terminate => { if loading_servers.iter().any(|(_, status)| !status.is_done) { execute!( - stdout_lock, + output, cursor::MoveToColumn(0), cursor::MoveUp(1), terminal::Clear(terminal::ClearType::CurrentLine), @@ -351,8 +376,9 @@ impl ToolManagerBuilder { acc }); let msg = eyre::eyre!(msg); - queue_incomplete_load_message(&msg, &mut stdout_lock)?; - stdout_lock.flush()?; + let total = loading_servers.len(); + queue_incomplete_load_message(complete, total, &msg, &mut output)?; + output.flush()?; } break; }, @@ -360,7 +386,7 @@ impl ToolManagerBuilder { Err(RecvTimeoutError::Timeout) => { spinner_logo_idx = (spinner_logo_idx + 1) % SPINNER_CHARS.len(); execute!( - stdout_lock, + output, cursor::SavePosition, cursor::MoveToColumn(0), cursor::MoveUp(1), @@ -570,6 +596,7 @@ impl ToolManagerBuilder { loading_status_sender, new_tool_specs, has_new_stuff, + is_interactive, ..Default::default() }) } @@ -640,6 +667,8 @@ pub struct ToolManager { /// This is mainly used to show the user what the tools look like from the perspective of the /// model. pub schema: HashMap, + + is_interactive: bool, } impl Clone for ToolManager { @@ -652,6 +681,7 @@ impl Clone for ToolManager { prompts: self.prompts.clone(), tn_map: self.tn_map.clone(), schema: self.schema.clone(), + is_interactive: self.is_interactive, ..Default::default() } } @@ -708,8 +738,12 @@ impl ToolManager { } }, _ = ctrl_c() => { - if let Some(tx) = tx { - let _ = tx.send(LoadingMsg::Terminate); + if self.is_interactive { + if let Some(tx) = tx { + let _ = tx.send(LoadingMsg::Terminate); + } + } else { + return Err(eyre::eyre!("User interrupted mcp server loading in non-interactive mode. Ending.")); } } } @@ -1186,6 +1220,7 @@ fn queue_init_message( complete: usize, failed: usize, total: usize, + is_interactive: bool, output: &mut impl Write, ) -> eyre::Result<()> { if total == complete { @@ -1205,7 +1240,7 @@ fn queue_init_message( } else { queue!(output, style::Print(SPINNER_CHARS[spinner_logo_idx]))?; } - Ok(queue!( + queue!( output, style::SetForegroundColor(style::Color::Blue), style::Print(format!(" {}", complete)), @@ -1214,11 +1249,22 @@ fn queue_init_message( style::SetForegroundColor(style::Color::Blue), style::Print(format!("{} ", total)), style::ResetColor, - style::Print("mcp servers initialized. Press ctrl-c to load the remaining servers in the background\n"), - )?) + style::Print("mcp servers initialized."), + )?; + if is_interactive { + queue!( + output, + style::SetForegroundColor(style::Color::Blue), + style::Print(" ctrl-c "), + style::ResetColor, + style::Print("to start chatting now") + )?; + } + Ok(queue!(output, style::Print("\n"))?) } fn queue_failure_message(name: &str, fail_load_msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { + use crate::util::CHAT_BINARY_NAME; Ok(queue!( output, style::SetForegroundColor(style::Color::Red), @@ -1229,7 +1275,9 @@ fn queue_failure_message(name: &str, fail_load_msg: &eyre::Report, output: &mut style::Print(" has failed to load:\n- "), style::Print(fail_load_msg), style::Print("\n"), - style::Print("- run with Q_LOG_LEVEL=trace and see $TMPDIR/qlog for detail\n"), + style::Print(format!( + "- run with Q_LOG_LEVEL=trace and see $TMPDIR/{CHAT_BINARY_NAME} for detail\n" + )), style::ResetColor, )?) } @@ -1248,14 +1296,27 @@ fn queue_warn_message(name: &str, msg: &eyre::Report, output: &mut impl Write) - )?) } -fn queue_incomplete_load_message(msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { +fn queue_incomplete_load_message( + complete: usize, + total: usize, + msg: &eyre::Report, + output: &mut impl Write, +) -> eyre::Result<()> { Ok(queue!( output, style::SetForegroundColor(style::Color::Yellow), - style::Print("⚠ "), + style::Print("⚠"), + style::SetForegroundColor(style::Color::Blue), + style::Print(format!(" {}", complete)), + style::ResetColor, + style::Print(" of "), + style::SetForegroundColor(style::Color::Blue), + style::Print(format!("{} ", total)), + style::ResetColor, + style::Print("mcp servers initialized."), style::ResetColor, // We expect the message start with a newline - style::Print("following servers are still loading:"), + style::Print(" Servers still loading:"), style::Print(msg), style::ResetColor, )?) diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index 24f6d0f364..dae155078e 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -170,9 +170,15 @@ impl CustomTool { pub async fn invoke(&self, _ctx: &Context, _updates: &mut impl Write) -> Result { // Assuming a response shape as per https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#calling-tools let resp = self.client.request(self.method.as_str(), self.params.clone()).await?; - let result = resp - .result - .ok_or(eyre::eyre!("{} invocation failed to produce a result", self.name))?; + let result = match resp.result { + Some(result) => result, + None => { + let failure = resp.error.map_or("Unknown error encountered".to_string(), |err| { + serde_json::to_string(&err).unwrap_or_default() + }); + return Err(eyre::eyre!(failure)); + }, + }; match serde_json::from_value::(result.clone()) { Ok(mut de_result) => { From e05fe3ec723d40ba45e5f5951e11008607f9cc2c Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Sat, 10 May 2025 16:24:01 -0700 Subject: [PATCH 15/26] makes server init timeout configurable --- crates/chat-cli/src/cli/chat/tool_manager.rs | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index f7a5dceeea..eb8d2ce035 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -229,7 +229,6 @@ impl ToolManagerBuilder { self } - #[allow(dead_code)] pub fn interactive(mut self, is_interactive: bool) -> Self { self.is_interactive = is_interactive; self @@ -715,7 +714,7 @@ impl ToolManager { }); // We need to cast it to erase the type otherwise the compiler will default to static // dispatch, which would result in an error of inconsistent match arm return type. - let display_future: Pin>> = match display_task { + let display_fut: Pin>> = match display_task { Some(display_task) => { let fut = async move { if let Err(e) = display_task.await { @@ -729,10 +728,19 @@ impl ToolManager { Box::pin(fut) }, }; + // TODO: make this timeout configurable + let timeout_fut: Pin>> = if self.is_interactive { + let init_timeout = crate::settings::settings::get_int("mcp.initTimeout") + .map_or(5000_u64, |s| s.map_or(5000_u64, |n| n as u64)); + error!("## timeout: {init_timeout}"); + Box::pin(tokio::time::sleep(std::time::Duration::from_millis(init_timeout))) + } else { + let fut = async { future::pending::<()>().await }; + Box::pin(fut) + }; tokio::select! { - _ = display_future => {}, - // TODO: make this timeout configurable - _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => { + _ = display_fut => {}, + _ = timeout_fut => { if let Some(tx) = tx { let _ = tx.send(LoadingMsg::Terminate); } From af0235e530b887eb73b2ce8e86e38f36d83467e9 Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Sat, 10 May 2025 17:02:06 -0700 Subject: [PATCH 16/26] uses tn map keys as list of tools for dummy substitute --- crates/chat-cli/src/cli/chat/conversation_state.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/conversation_state.rs b/crates/chat-cli/src/cli/chat/conversation_state.rs index 1e06c108a8..7dea7bd960 100644 --- a/crates/chat-cli/src/cli/chat/conversation_state.rs +++ b/crates/chat-cli/src/cli/chat/conversation_state.rs @@ -306,12 +306,11 @@ impl ConversationState { // do this if the last message is a tool call that has failed. let tool_use_results = user_msg.tool_use_results(); if let Some(tool_use_results) = tool_use_results { - let tool_name_list = self - .tools - .values() - .flatten() - .map(|Tool::ToolSpecification(spec)| spec.name.as_str()) - .collect::>(); + // Note that we need to use the keys in tool manager's tn_map as the keys are the + // actual tool names as exposed to the model and the backend. If we use the actual + // names as they are recognized by their respective servers, we risk concluding + // with false positives. + let tool_name_list = self.tool_manager.tn_map.keys().map(String::as_str).collect::>(); for result in tool_use_results { if let ToolResultStatus::Error = result.status { let tool_use_id = result.tool_use_id.as_str(); From 0808038a5fbc147b7576aecc6b93bff7182e9f10 Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Sat, 10 May 2025 17:11:45 -0700 Subject: [PATCH 17/26] adds tip for background loading and init timeout --- crates/chat-cli/src/cli/chat/mod.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index db13180574..58cdf3ec49 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -210,7 +210,7 @@ const SMALL_SCREEN_WECLOME_TEXT: &str = color_print::cstr! {" Welcome to Amazon Q! "}; -const ROTATING_TIPS: [&str; 9] = [ +const ROTATING_TIPS: [&str; 10] = [ color_print::cstr! {"Get notified whenever Q CLI finishes responding. Just run q settings chat.enableNotifications true"}, color_print::cstr! {"You can use /editor to edit your prompt with a vim-like experience"}, color_print::cstr! {"You can execute bash commands by typing ! followed by the command"}, @@ -220,6 +220,7 @@ const ROTATING_TIPS: [&str; 9] = [ color_print::cstr! {"/usage shows you a visual breakdown of your current context window usage"}, color_print::cstr! {"If you want to file an issue to the Q CLI team, just tell me, or run q issue"}, color_print::cstr! {"You can enable custom tools with MCP servers. Learn more with /help"}, + color_print::cstr! {"You can specify wait time (in ms) for mcp server loading with q settings mcp.initTimeout {timeout in int}. Servers that takes longer than the specified time will continue to load in the background. Use /tools to see pending servers."}, ]; const GREETING_BREAK_POINT: usize = 67; From a00d22a54ce18fecb297bbfa57acaaf57168f67f Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Sat, 10 May 2025 21:03:15 -0700 Subject: [PATCH 18/26] updates tools info per try chat loop --- crates/chat-cli/src/cli/chat/mod.rs | 11 ++++++ .../chat-cli/src/cli/chat/server_messenger.rs | 14 +++++++- crates/chat-cli/src/cli/chat/tool_manager.rs | 34 ++++++++++++++++--- .../src/cli/chat/tools/custom_tool.rs | 3 ++ crates/chat-cli/src/mcp_client/messenger.rs | 7 ++++ 5 files changed, 63 insertions(+), 6 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 58cdf3ec49..abb9e33c3e 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -33,6 +33,7 @@ use std::process::{ ExitCode, }; use std::sync::Arc; +use std::sync::atomic::Ordering; use std::time::Duration; use std::{ env, @@ -795,6 +796,16 @@ impl ChatContext { let ctrl_c_stream = ctrl_c(); debug!(?chat_state, "changing to state"); + // Update conversation state with new tool information + if self + .conversation_state + .tool_manager + .has_new_stuff + .load(Ordering::Relaxed) + { + self.conversation_state.update_state().await; + } + let result = match chat_state { ChatState::PromptUser { tool_uses, diff --git a/crates/chat-cli/src/cli/chat/server_messenger.rs b/crates/chat-cli/src/cli/chat/server_messenger.rs index e56cbd0715..3adc665d15 100644 --- a/crates/chat-cli/src/cli/chat/server_messenger.rs +++ b/crates/chat-cli/src/cli/chat/server_messenger.rs @@ -32,7 +32,9 @@ pub enum UpdateEventMessage { server_name: String, result: ResourceTemplatesListResult, }, - DisplayTaskEnded, + InitStart { + server_name: String, + }, } #[derive(Clone, Debug)] @@ -112,6 +114,16 @@ impl Messenger for ServerMessenger { .map_err(|e| MessengerError::Custom(e.to_string()))?) } + async fn send_init_msg(&self) -> Result<(), MessengerError> { + Ok(self + .update_event_sender + .send(UpdateEventMessage::InitStart { + server_name: self.server_name.clone(), + }) + .await + .map_err(|e| MessengerError::Custom(e.to_string()))?) + } + fn duplicate(&self) -> Box { Box::new(self.clone()) } diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index eb8d2ce035..253b0ad1a8 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -40,7 +40,10 @@ use serde::{ }; use thiserror::Error; use tokio::signal::ctrl_c; -use tokio::sync::Mutex; +use tokio::sync::{ + Mutex, + RwLock, +}; use tracing::{ error, warn, @@ -406,6 +409,8 @@ impl ToolManagerBuilder { let new_tool_specs_clone = new_tool_specs.clone(); let has_new_stuff = Arc::new(AtomicBool::new(false)); let has_new_stuff_clone = has_new_stuff.clone(); + let pending = Arc::new(RwLock::new(HashSet::::new())); + let pending_clone = pending.clone(); let (mut msg_rx, messenger_builder) = ServerMessengerBuilder::new(20); tokio::spawn(async move { while let Some(msg) = msg_rx.recv().await { @@ -415,6 +420,7 @@ impl ToolManagerBuilder { // list calls. match msg { UpdateEventMessage::ToolsListResult { server_name, result } => { + pending_clone.write().await.remove(&server_name); let mut specs = result .tools .into_iter() @@ -464,14 +470,16 @@ impl ToolManagerBuilder { server_name: _, result: _, } => {}, - UpdateEventMessage::DisplayTaskEnded => { - load_msg_sender.take(); + UpdateEventMessage::InitStart { server_name } => { + pending_clone.write().await.insert(server_name.clone()); + if let Some(sender) = &load_msg_sender { + let _ = sender.send(LoadingMsg::Add(server_name)); + } }, } } }); for (mut name, init_res) in pre_initialized { - let _ = tx.send(LoadingMsg::Add(name.clone())); match init_res { Ok(mut client) => { let messenger = messenger_builder.build_with_name(client.get_server_name().to_owned()); @@ -592,6 +600,7 @@ impl ToolManagerBuilder { clients, prompts, loading_display_task, + pending_clients: pending, loading_status_sender, new_tool_specs, has_new_stuff, @@ -638,8 +647,19 @@ pub struct ToolManager { /// These clients are used to communicate with MCP servers. pub clients: HashMap>, + #[allow(dead_code)] + /// A list of client names that are still in the process of being initialized + pub pending_clients: Arc>>, + + /// Flag indicating whether new tool specifications have been added since the last update. + /// When set to true, it signals that the tool manager needs to refresh its internal state + /// to incorporate newly available tools from MCP servers. pub has_new_stuff: Arc, + /// Storage for newly discovered tool specifications from MCP servers that haven't yet been + /// integrated into the main tool registry. This field holds a thread-safe reference to a map + /// of server names to their tool specifications and name mappings, allowing concurrent updates + /// from server initialization processes. new_tool_specs: NewToolSpecs, /// Cache for prompts collected from different servers. @@ -1059,6 +1079,10 @@ impl ToolManager { ); Ok(()) } + + pub async fn _pending_clients(&self) -> Vec { + self.pending_clients.read().await.iter().cloned().collect::>() + } } #[inline] @@ -1150,7 +1174,7 @@ fn process_tool_specs( server_name, msg ); if is_in_display { - Some(LoadingMsg::Error { + Some(LoadingMsg::Warn { name: server_name.to_string(), msg: eyre::eyre!(msg), }) diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index dae155078e..0ac886f519 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -92,6 +92,9 @@ impl CustomToolClient { server_capabilities, .. } => { + if let Some(messenger) = &client.messenger { + let _ = messenger.send_init_msg().await; + } // We'll need to first initialize. This is the handshake every client and server // needs to do before proceeding to anything else let cap = client.init().await?; diff --git a/crates/chat-cli/src/mcp_client/messenger.rs b/crates/chat-cli/src/mcp_client/messenger.rs index 1d4f361445..efd49617ab 100644 --- a/crates/chat-cli/src/mcp_client/messenger.rs +++ b/crates/chat-cli/src/mcp_client/messenger.rs @@ -33,6 +33,9 @@ pub trait Messenger: std::fmt::Debug + Send + Sync + 'static { result: ResourceTemplatesListResult, ) -> Result<(), MessengerError>; + /// Signals to the orchestrator that a server has started initializing + async fn send_init_msg(&self) -> Result<(), MessengerError>; + /// Creates a duplicate of the messenger object /// This function is used to create a new instance of the messenger with the same configuration fn duplicate(&self) -> Box; @@ -68,6 +71,10 @@ impl Messenger for NullMessenger { Ok(()) } + async fn send_init_msg(&self) -> Result<(), MessengerError> { + Ok(()) + } + fn duplicate(&self) -> Box { Box::new(NullMessenger) } From 0e8243891246fd335355e5a1f0c938ee98486e00 Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Sun, 11 May 2025 12:47:01 -0700 Subject: [PATCH 19/26] shows servers still loading in /tools --- crates/chat-cli/src/cli/chat/mod.rs | 15 +++++++++++++++ crates/chat-cli/src/cli/chat/tool_manager.rs | 6 +++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index abb9e33c3e..14a1623abb 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -2362,6 +2362,21 @@ impl ChatContext { ); }); + let loading = self.conversation_state.tool_manager.pending_clients().await; + if !loading.is_empty() { + queue!( + self.output, + style::SetAttribute(Attribute::Bold), + style::Print("Servers still loading"), + style::SetAttribute(Attribute::Reset), + style::Print("\n"), + style::Print("▔".repeat(terminal_width)), + )?; + for client in loading { + queue!(self.output, style::Print(format!(" - {client}")), style::Print("\n"))?; + } + } + queue!( self.output, style::Print("\nTrusted tools can be run without confirmation\n"), diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index 253b0ad1a8..97956366bd 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -256,6 +256,10 @@ impl ToolManagerBuilder { // Send up task to update user on server loading status let (tx, rx) = std::sync::mpsc::channel::(); + // TODO: rather than using it as an "anchor" to determine the progress of server loads, we + // should make this task optional (and it is defined as an optional right now. There is + // just no code path with it being None). When ran with no-interactive mode, we really do + // not have a need to run this task. let loading_display_task = tokio::task::spawn_blocking(move || { let mut loading_servers = HashMap::::new(); let mut spinner_logo_idx: usize = 0; @@ -1080,7 +1084,7 @@ impl ToolManager { Ok(()) } - pub async fn _pending_clients(&self) -> Vec { + pub async fn pending_clients(&self) -> Vec { self.pending_clients.read().await.iter().cloned().collect::>() } } From d1d2005069defe0542757bcc9084a4a2545a1fb4 Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Sun, 11 May 2025 13:48:44 -0700 Subject: [PATCH 20/26] makes timeout fut resolve immediately for tests --- crates/chat-cli/src/cli/chat/tool_manager.rs | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index 97956366bd..655d21b12f 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -747,20 +747,17 @@ impl ToolManager { }; Box::pin(fut) }, - None => { - let fut = async { future::pending::<()>().await }; - Box::pin(fut) - }, + None => Box::pin(future::pending()), }; - // TODO: make this timeout configurable - let timeout_fut: Pin>> = if self.is_interactive { + let timeout_fut: Pin>> = if self.clients.is_empty() { + // If there is no server loaded, we want to resolve immediately + Box::pin(future::ready(())) + } else if self.is_interactive { let init_timeout = crate::settings::settings::get_int("mcp.initTimeout") .map_or(5000_u64, |s| s.map_or(5000_u64, |n| n as u64)); - error!("## timeout: {init_timeout}"); Box::pin(tokio::time::sleep(std::time::Duration::from_millis(init_timeout))) } else { - let fut = async { future::pending::<()>().await }; - Box::pin(fut) + Box::pin(future::pending()) }; tokio::select! { _ = display_fut => {}, From 05ea3f2eb2b0e6e7e05e4f35dcc4fb64cb8c7ede Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Sun, 11 May 2025 17:25:13 -0700 Subject: [PATCH 21/26] refines conversation invariant logic with regards to tool calls with wrong names --- .../src/cli/chat/conversation_state.rs | 50 +++++++++++++------ crates/chat-cli/src/cli/chat/mod.rs | 11 +--- crates/chat-cli/src/cli/chat/tool_manager.rs | 37 +++----------- 3 files changed, 44 insertions(+), 54 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/conversation_state.rs b/crates/chat-cli/src/cli/chat/conversation_state.rs index 7dea7bd960..c56ff229f3 100644 --- a/crates/chat-cli/src/cli/chat/conversation_state.rs +++ b/crates/chat-cli/src/cli/chat/conversation_state.rs @@ -302,8 +302,17 @@ impl ConversationState { } // Here we also need to make sure that the tool result corresponds to one of the tools - // in the list. Otherwise we will see validation error from the backend. We would only - // do this if the last message is a tool call that has failed. + // in the list. Otherwise we will see validation error from the backend. There are three + // such circumstances where intervention would be needed: + // 1. The model had decided to call a tool with its partial name AND there is only one such tool, in + // which case we would automatically resolve this tool call to its correct name. This will NOT + // result in an error in its tool result. The intervention here is to substitute the partial name + // with its full name. + // 2. The model had decided to call a tool with its partial name AND there are multiple tools it + // could be referring to, in which case we WILL return an error in the tool result. The + // intervention here is to substitute the ambiguous, partial name with a dummy. + // 3. The model had decided to call a tool that does not exist. The intervention here is to + // substitute the non-existent tool name with a dummy. let tool_use_results = user_msg.tool_use_results(); if let Some(tool_use_results) = tool_use_results { // Note that we need to use the keys in tool manager's tn_map as the keys are the @@ -312,19 +321,30 @@ impl ConversationState { // with false positives. let tool_name_list = self.tool_manager.tn_map.keys().map(String::as_str).collect::>(); for result in tool_use_results { - if let ToolResultStatus::Error = result.status { - let tool_use_id = result.tool_use_id.as_str(); - let _ = tool_uses - .iter_mut() - .filter(|tool_use| tool_use.id == tool_use_id) - .map(|tool_use| { - let tool_name = tool_use.name.as_str(); - if !tool_name_list.contains(&tool_name) { - tool_use.name = DUMMY_TOOL_NAME.to_string(); - tool_use.args = serde_json::json!({}); - } - }) - .collect::>(); + let tool_use_id = result.tool_use_id.as_str(); + let corresponding_tool_use = tool_uses.iter_mut().find(|tool_use| tool_use_id == tool_use.id); + if let Some(tool_use) = corresponding_tool_use { + if tool_name_list.contains(&tool_use.name.as_str()) { + // If this tool matches of the tools in our list, this is not our + // concern, error or not. + continue; + } + if let ToolResultStatus::Error = result.status { + // case 2 and 3 + tool_use.name = DUMMY_TOOL_NAME.to_string(); + tool_use.args = serde_json::json!({}); + } else { + // case 1 + let full_name = tool_name_list.iter().find(|name| name.ends_with(&tool_use.name)); + // We should be able to find a match but if not we'll just treat it as + // a dummy and move on + if let Some(full_name) = full_name { + tool_use.name = (*full_name).to_string(); + } else { + tool_use.name = DUMMY_TOOL_NAME.to_string(); + tool_use.args = serde_json::json!({}); + } + } } } } diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 14a1623abb..95908d2f66 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -33,7 +33,6 @@ use std::process::{ ExitCode, }; use std::sync::Arc; -use std::sync::atomic::Ordering; use std::time::Duration; use std::{ env, @@ -409,6 +408,7 @@ pub async fn chat( }) .await?; let tool_config = tool_manager.load_tools().await?; + error!("## tool config: {:#?}", tool_config); let mut tool_permissions = ToolPermissions::new(tool_config.len()); if accept_all || trust_all_tools { tool_permissions.trust_all = true; @@ -797,14 +797,7 @@ impl ChatContext { debug!(?chat_state, "changing to state"); // Update conversation state with new tool information - if self - .conversation_state - .tool_manager - .has_new_stuff - .load(Ordering::Relaxed) - { - self.conversation_state.update_state().await; - } + self.conversation_state.update_state().await; let result = match chat_state { ChatState::PromptUser { diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index 655d21b12f..b19cbbcfe3 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -280,7 +280,7 @@ impl ToolManagerBuilder { loading_servers.insert(name.clone(), status_line); let total = loading_servers.len(); execute!(output, terminal::Clear(terminal::ClearType::CurrentLine))?; - queue_init_message(spinner_logo_idx, complete, failed, total, is_interactive, &mut output)?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; output.flush()?; }, LoadingMsg::Done(name) => { @@ -298,14 +298,7 @@ impl ToolManagerBuilder { )?; queue_success_message(&name, &time_taken, &mut output)?; let total = loading_servers.len(); - queue_init_message( - spinner_logo_idx, - complete, - failed, - total, - is_interactive, - &mut output, - )?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; output.flush()?; } if loading_servers.iter().all(|(_, status)| status.is_done) { @@ -324,14 +317,7 @@ impl ToolManagerBuilder { )?; queue_failure_message(&name, &msg, &mut output)?; let total = loading_servers.len(); - queue_init_message( - spinner_logo_idx, - complete, - failed, - total, - is_interactive, - &mut output, - )?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; } if loading_servers.iter().all(|(_, status)| status.is_done) { break; @@ -350,14 +336,7 @@ impl ToolManagerBuilder { let msg = eyre::eyre!(msg.to_string()); queue_warn_message(&name, &msg, &mut output)?; let total = loading_servers.len(); - queue_init_message( - spinner_logo_idx, - complete, - failed, - total, - is_interactive, - &mut output, - )?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; output.flush()?; } if loading_servers.iter().all(|(_, status)| status.is_done) { @@ -714,7 +693,7 @@ impl ToolManager { pub async fn load_tools(&mut self) -> eyre::Result> { let tx = self.loading_status_sender.take(); let display_task = self.loading_display_task.take(); - let mut tool_specs = { + self.schema = { let mut tool_specs = serde_json::from_str::>(include_str!("tools/tool_index.json"))?; if !crate::cli::chat::tools::thinking::Thinking::is_enabled() { @@ -777,8 +756,7 @@ impl ToolManager { } } self.update().await; - tool_specs.extend(self.schema.clone()); - Ok(tool_specs) + Ok(self.schema.clone()) } pub fn get_tool_from_tool_use(&self, value: AssistantToolUse) -> Result { @@ -1253,7 +1231,6 @@ fn queue_init_message( complete: usize, failed: usize, total: usize, - is_interactive: bool, output: &mut impl Write, ) -> eyre::Result<()> { if total == complete { @@ -1284,7 +1261,7 @@ fn queue_init_message( style::ResetColor, style::Print("mcp servers initialized."), )?; - if is_interactive { + if total > complete + failed { queue!( output, style::SetForegroundColor(style::Color::Blue), From c3900d9b3740c98370e51d0de4de8f8cb28a6e6c Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Mon, 12 May 2025 13:02:50 -0700 Subject: [PATCH 22/26] alias pkce to all uppercase --- crates/chat-cli/src/auth/builder_id.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/chat-cli/src/auth/builder_id.rs b/crates/chat-cli/src/auth/builder_id.rs index e277c1410b..beb99c768c 100644 --- a/crates/chat-cli/src/auth/builder_id.rs +++ b/crates/chat-cli/src/auth/builder_id.rs @@ -62,7 +62,7 @@ use crate::database::secret_store::{ pub enum OAuthFlow { DeviceCode, // This must remain backwards compatible - #[serde(rename = "PKCE")] + #[serde(alias = "PKCE")] Pkce, } From f36c0bf38b007dfb862e5c0daee3c65bd8e2d235 Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Mon, 12 May 2025 14:14:34 -0700 Subject: [PATCH 23/26] fixes test for oauth ser deser --- crates/chat-cli/src/auth/builder_id.rs | 22 ++++++++++++++++++- .../src/cli/chat/conversation_state.rs | 1 - 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/crates/chat-cli/src/auth/builder_id.rs b/crates/chat-cli/src/auth/builder_id.rs index beb99c768c..1ccb624784 100644 --- a/crates/chat-cli/src/auth/builder_id.rs +++ b/crates/chat-cli/src/auth/builder_id.rs @@ -58,7 +58,7 @@ use crate::database::secret_store::{ SecretStore, }; -#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize)] pub enum OAuthFlow { DeviceCode, // This must remain backwards compatible @@ -66,6 +66,26 @@ pub enum OAuthFlow { Pkce, } +// Implement Serialize manually to ensure proper serialization +impl serde::Serialize for OAuthFlow { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match *self { + OAuthFlow::DeviceCode => serializer.serialize_str("DeviceCode"), + OAuthFlow::Pkce => serialize_pkce(serializer), + } + } +} + +fn serialize_pkce(serializer: S) -> Result +where + S: serde::Serializer, +{ + serializer.serialize_str("PKCE") +} + impl std::fmt::Display for OAuthFlow { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match *self { diff --git a/crates/chat-cli/src/cli/chat/conversation_state.rs b/crates/chat-cli/src/cli/chat/conversation_state.rs index 1f36055917..6d5aae8fb1 100644 --- a/crates/chat-cli/src/cli/chat/conversation_state.rs +++ b/crates/chat-cli/src/cli/chat/conversation_state.rs @@ -895,7 +895,6 @@ mod tests { }; use crate::cli::chat::tool_manager::ToolManager; use crate::database::Database; - use crate::platform::Env; fn assert_conversation_state_invariants(state: FigConversationState, assertion_iteration: usize) { if let Some(Some(msg)) = state.history.as_ref().map(|h| h.first()) { From 9d3d96877ecd31f43000e6cf2750bf906e938f87 Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Mon, 12 May 2025 14:45:55 -0700 Subject: [PATCH 24/26] puts timeout on telemetry finish --- crates/chat-cli/src/telemetry/mod.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/crates/chat-cli/src/telemetry/mod.rs b/crates/chat-cli/src/telemetry/mod.rs index 430a327204..2dac77cbcc 100644 --- a/crates/chat-cli/src/telemetry/mod.rs +++ b/crates/chat-cli/src/telemetry/mod.rs @@ -36,6 +36,7 @@ pub use install_method::{ }; use tokio::sync::mpsc; use tokio::task::JoinHandle; +use tokio::time::error::Elapsed; use tracing::{ debug, error, @@ -75,6 +76,8 @@ pub enum TelemetryError { Join(#[from] tokio::task::JoinError), #[error(transparent)] Database(#[from] DatabaseError), + #[error(transparent)] + Timeout(#[from] Elapsed), } impl From for TelemetryError { @@ -159,7 +162,7 @@ impl TelemetryThread { pub async fn finish(self) -> Result<(), TelemetryError> { drop(self.tx); if let Some(handle) = self.handle { - handle.await?; + let _ = tokio::time::timeout(std::time::Duration::from_millis(100), handle).await; } Ok(()) From 192fb985443028b5bb63c7c814bd888148ba6b8f Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Mon, 12 May 2025 15:31:54 -0700 Subject: [PATCH 25/26] bumps telemetry finish timeout to 1 second and surface errors other than timeout --- crates/chat-cli/src/cli/chat/tools/custom_tool.rs | 1 - crates/chat-cli/src/telemetry/mod.rs | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index 0ac886f519..95fe96d01c 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -107,7 +107,6 @@ impl CustomToolClient { } pub fn assign_messenger(&mut self, messenger: Box) { - tracing::error!("## background: assigned {} with messenger", self.get_server_name()); match self { CustomToolClient::Stdio { client, .. } => { client.messenger = Some(messenger); diff --git a/crates/chat-cli/src/telemetry/mod.rs b/crates/chat-cli/src/telemetry/mod.rs index 2dac77cbcc..84082b6c9e 100644 --- a/crates/chat-cli/src/telemetry/mod.rs +++ b/crates/chat-cli/src/telemetry/mod.rs @@ -162,7 +162,9 @@ impl TelemetryThread { pub async fn finish(self) -> Result<(), TelemetryError> { drop(self.tx); if let Some(handle) = self.handle { - let _ = tokio::time::timeout(std::time::Duration::from_millis(100), handle).await; + if let Err(e) = tokio::time::timeout(std::time::Duration::from_millis(1000), handle).await { + return Err(TelemetryError::Timeout(e)); + } } Ok(()) From b7b80ba4efa45fe0f5c9d5ee7a3bee299290f2dd Mon Sep 17 00:00:00 2001 From: Felix dingfeli Date: Mon, 12 May 2025 16:30:01 -0700 Subject: [PATCH 26/26] only surface error for telemetry finish if it's not a timeout --- crates/chat-cli/src/telemetry/mod.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/crates/chat-cli/src/telemetry/mod.rs b/crates/chat-cli/src/telemetry/mod.rs index 84082b6c9e..b49f4da592 100644 --- a/crates/chat-cli/src/telemetry/mod.rs +++ b/crates/chat-cli/src/telemetry/mod.rs @@ -162,8 +162,15 @@ impl TelemetryThread { pub async fn finish(self) -> Result<(), TelemetryError> { drop(self.tx); if let Some(handle) = self.handle { - if let Err(e) = tokio::time::timeout(std::time::Duration::from_millis(1000), handle).await { - return Err(TelemetryError::Timeout(e)); + match tokio::time::timeout(std::time::Duration::from_millis(1000), handle).await { + Ok(result) => { + if let Err(e) = result { + return Err(TelemetryError::Join(e)); + } + }, + Err(_) => { + // Ignore timeout errors + }, } }