diff --git a/crates/chat-cli/src/cli/chat/conversation_state.rs b/crates/chat-cli/src/cli/chat/conversation_state.rs index bb269c286c..6d5aae8fb1 100644 --- a/crates/chat-cli/src/cli/chat/conversation_state.rs +++ b/crates/chat-cli/src/cli/chat/conversation_state.rs @@ -3,6 +3,7 @@ use std::collections::{ VecDeque, }; use std::sync::Arc; +use std::sync::atomic::Ordering; use crossterm::style::Color; use crossterm::{ @@ -41,6 +42,7 @@ use super::token_counter::{ CharCount, CharCounter, }; +use super::tool_manager::ToolManager; use super::tools::{ InputSchema, QueuedTool, @@ -90,6 +92,9 @@ pub struct ConversationState { pub tools: HashMap>, /// Context manager for handling sticky context files pub context_manager: Option, + /// Tool manager for handling tool and mcp related activities + #[serde(skip)] + pub tool_manager: ToolManager, /// Cached value representing the length of the user context message. context_message_length: Option, /// Stores the latest conversation summary created by /compact @@ -105,6 +110,7 @@ impl ConversationState { tool_config: HashMap, profile: Option, updates: Option, + tool_manager: ToolManager, ) -> Self { // Initialize context manager let context_manager = match ContextManager::new(ctx, None).await { @@ -143,6 +149,7 @@ impl ConversationState { acc }), context_manager, + tool_manager, context_message_length: None, latest_summary: None, updates, @@ -310,29 +317,49 @@ impl ConversationState { } // Here we also need to make sure that the tool result corresponds to one of the tools - // in the list. Otherwise we will see validation error from the backend. We would only - // do this if the last message is a tool call that has failed. + // in the list. Otherwise we will see validation error from the backend. There are three + // such circumstances where intervention would be needed: + // 1. The model had decided to call a tool with its partial name AND there is only one such tool, in + // which case we would automatically resolve this tool call to its correct name. This will NOT + // result in an error in its tool result. The intervention here is to substitute the partial name + // with its full name. + // 2. The model had decided to call a tool with its partial name AND there are multiple tools it + // could be referring to, in which case we WILL return an error in the tool result. The + // intervention here is to substitute the ambiguous, partial name with a dummy. + // 3. The model had decided to call a tool that does not exist. The intervention here is to + // substitute the non-existent tool name with a dummy. let tool_use_results = user_msg.tool_use_results(); if let Some(tool_use_results) = tool_use_results { - let tool_name_list = self - .tools - .values() - .flatten() - .map(|Tool::ToolSpecification(spec)| spec.name.as_str()) - .collect::>(); + // Note that we need to use the keys in tool manager's tn_map as the keys are the + // actual tool names as exposed to the model and the backend. If we use the actual + // names as they are recognized by their respective servers, we risk concluding + // with false positives. + let tool_name_list = self.tool_manager.tn_map.keys().map(String::as_str).collect::>(); for result in tool_use_results { - if let ToolResultStatus::Error = result.status { - let tool_use_id = result.tool_use_id.as_str(); - let _ = tool_uses - .iter_mut() - .filter(|tool_use| tool_use.id == tool_use_id) - .map(|tool_use| { - let tool_name = tool_use.name.as_str(); - if !tool_name_list.contains(&tool_name) { - tool_use.name = DUMMY_TOOL_NAME.to_string(); - } - }) - .collect::>(); + let tool_use_id = result.tool_use_id.as_str(); + let corresponding_tool_use = tool_uses.iter_mut().find(|tool_use| tool_use_id == tool_use.id); + if let Some(tool_use) = corresponding_tool_use { + if tool_name_list.contains(&tool_use.name.as_str()) { + // If this tool matches of the tools in our list, this is not our + // concern, error or not. + continue; + } + if let ToolResultStatus::Error = result.status { + // case 2 and 3 + tool_use.name = DUMMY_TOOL_NAME.to_string(); + tool_use.args = serde_json::json!({}); + } else { + // case 1 + let full_name = tool_name_list.iter().find(|name| name.ends_with(&tool_use.name)); + // We should be able to find a match but if not we'll just treat it as + // a dummy and move on + if let Some(full_name) = full_name { + tool_use.name = (*full_name).to_string(); + } else { + tool_use.name = DUMMY_TOOL_NAME.to_string(); + tool_use.args = serde_json::json!({}); + } + } } } } @@ -363,6 +390,7 @@ impl ConversationState { /// - `run_hooks` - whether hooks should be executed and included as context pub async fn as_sendable_conversation_state(&mut self, run_hooks: bool) -> FigConversationState { debug_assert!(self.next_message.is_some()); + self.update_state().await; self.enforce_conversation_invariants(); self.history.drain(self.valid_history_range.1..); self.history.drain(..self.valid_history_range.0); @@ -388,6 +416,30 @@ impl ConversationState { .expect("unable to construct conversation state") } + pub async fn update_state(&mut self) { + let needs_update = self.tool_manager.has_new_stuff.load(Ordering::Acquire); + if !needs_update { + return; + } + self.tool_manager.update().await; + self.tools = self + .tool_manager + .schema + .values() + .fold(HashMap::>::new(), |mut acc, v| { + let tool = Tool::ToolSpecification(ToolSpecification { + name: v.name.clone(), + description: v.description.clone(), + input_schema: v.input_schema.clone().into(), + }); + acc.entry(v.tool_origin.clone()) + .and_modify(|tools| tools.push(tool.clone())) + .or_insert(vec![tool]); + acc + }); + self.tool_manager.has_new_stuff.store(false, Ordering::Release); + } + /// Returns a conversation state representation which reflects the exact conversation to send /// back to the model. pub async fn backend_conversation_state(&mut self, run_hooks: bool, quiet: bool) -> BackendConversationState<'_> { @@ -843,8 +895,6 @@ mod tests { }; use crate::cli::chat::tool_manager::ToolManager; use crate::database::Database; - use crate::platform::Env; - use crate::telemetry::TelemetryThread; fn assert_conversation_state_invariants(state: FigConversationState, assertion_iteration: usize) { if let Some(Some(msg)) = state.history.as_ref().map(|h| h.first()) { @@ -936,17 +986,16 @@ mod tests { #[tokio::test] async fn test_conversation_state_history_handling_truncation() { - let env = Env::new(); let mut database = Database::new().await.unwrap(); - let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap(); let mut tool_manager = ToolManager::default(); let mut conversation_state = ConversationState::new( Context::new(), "fake_conv_id", - tool_manager.load_tools(&database, &telemetry).await.unwrap(), + tool_manager.load_tools(&database).await.unwrap(), None, None, + tool_manager, ) .await; @@ -964,18 +1013,18 @@ mod tests { #[tokio::test] async fn test_conversation_state_history_handling_with_tool_results() { - let env = Env::new(); let mut database = Database::new().await.unwrap(); - let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap(); // Build a long conversation history of tool use results. let mut tool_manager = ToolManager::default(); + let tool_config = tool_manager.load_tools(&database).await.unwrap(); let mut conversation_state = ConversationState::new( Context::new(), "fake_conv_id", - tool_manager.load_tools(&database, &telemetry).await.unwrap(), + tool_config.clone(), None, None, + tool_manager.clone(), ) .await; conversation_state.set_next_user_message("start".to_string()).await; @@ -1002,9 +1051,10 @@ mod tests { let mut conversation_state = ConversationState::new( Context::new(), "fake_conv_id", - tool_manager.load_tools(&database, &telemetry).await.unwrap(), + tool_config.clone(), None, None, + tool_manager.clone(), ) .await; conversation_state.set_next_user_message("start".to_string()).await; @@ -1035,9 +1085,7 @@ mod tests { #[tokio::test] async fn test_conversation_state_with_context_files() { - let env = Env::new(); let mut database = Database::new().await.unwrap(); - let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap(); let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); ctx.fs().write(AMAZONQ_FILENAME, "test context").await.unwrap(); @@ -1046,9 +1094,10 @@ mod tests { let mut conversation_state = ConversationState::new( ctx, "fake_conv_id", - tool_manager.load_tools(&database, &telemetry).await.unwrap(), + tool_manager.load_tools(&database).await.unwrap(), None, None, + tool_manager, ) .await; @@ -1085,9 +1134,7 @@ mod tests { async fn test_conversation_state_additional_context() { // tracing_subscriber::fmt::try_init().ok(); - let env = Env::new(); let mut database = Database::new().await.unwrap(); - let telemetry = TelemetryThread::new(&env, &mut database).await.unwrap(); let mut tool_manager = ToolManager::default(); let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); @@ -1116,9 +1163,10 @@ mod tests { let mut conversation_state = ConversationState::new( ctx, "fake_conv_id", - tool_manager.load_tools(&database, &telemetry).await.unwrap(), + tool_manager.load_tools(&database).await.unwrap(), None, Some(SharedWriter::stdout()), + tool_manager, ) .await; diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 68a62e088c..520bb4795c 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -9,6 +9,7 @@ mod message; mod parse; mod parser; mod prompt; +mod server_messenger; #[cfg(unix)] mod skim_integration; mod token_counter; @@ -83,7 +84,10 @@ use rand::distr::{ SampleString, }; use tokio::signal::ctrl_c; -use util::shared_writer::SharedWriter; +use util::shared_writer::{ + NullWriter, + SharedWriter, +}; use util::ui::draw_box; use crate::api_client::StreamingClient; @@ -198,7 +202,8 @@ const WELCOME_TEXT: &str = color_print::cstr! {" const SMALL_SCREEN_WELCOME_TEXT: &str = color_print::cstr! {"Welcome to Amazon Q!"}; const RESUME_TEXT: &str = color_print::cstr! {"Picking up where we left off..."}; -const ROTATING_TIPS: [&str; 9] = [ +const ROTATING_TIPS: [&str; 11] = [ + color_print::cstr! {"Get notified whenever Q CLI finishes responding. Just run q settings chat.enableNotifications true"}, color_print::cstr! {"You can use /editor to edit your prompt with a vim-like experience"}, color_print::cstr! {"/usage shows you a visual breakdown of your current context window usage"}, color_print::cstr! {"Get notified whenever Q CLI finishes responding. Just run q settings chat.enableNotifications true"}, @@ -207,7 +212,8 @@ const ROTATING_TIPS: [&str; 9] = [ color_print::cstr! {"You can programmatically inject context to your prompts by using hooks. Check out /context hooks help"}, color_print::cstr! {"You can use /compact to replace the conversation history with its summary to free up the context space"}, color_print::cstr! {"If you want to file an issue to the Q CLI team, just tell me, or run q issue"}, - color_print::cstr! {"You can enable custom tools with MCP servers. Learn more with /help"}, + color_print::cstr! {"You can enable custom tools with MCP servers. Learn more with /help"}, + color_print::cstr! {"You can specify wait time (in ms) for mcp server loading with q settings mcp.initTimeout {timeout in int}. Servers that takes longer than the specified time will continue to load in the background. Use /tools to see pending servers."}, ]; const GREETING_BREAK_POINT: usize = 80; @@ -381,16 +387,23 @@ pub async fn chat( info!(?conversation_id, "Generated new conversation id"); let (prompt_request_sender, prompt_request_receiver) = std::sync::mpsc::channel::>(); let (prompt_response_sender, prompt_response_receiver) = std::sync::mpsc::channel::>(); + let tool_manager_output: Box = if interactive { + Box::new(output.clone()) + } else { + Box::new(NullWriter {}) + }; let mut tool_manager = ToolManagerBuilder::default() .mcp_server_config(mcp_server_configs) .prompt_list_sender(prompt_response_sender) .prompt_list_receiver(prompt_request_receiver) .conversation_id(&conversation_id) - .build(telemetry) + .interactive(interactive) + .build(telemetry, tool_manager_output) .await?; - let tool_config = tool_manager.load_tools(database, telemetry).await?; + let tool_config = tool_manager.load_tools(database).await?; let mut tool_permissions = ToolPermissions::new(tool_config.len()); if accept_all || trust_all_tools { + tool_permissions.trust_all = true; for tool in tool_config.values() { tool_permissions.trust_tool(&tool.name); } @@ -493,8 +506,6 @@ pub struct ChatContext { tool_use_telemetry_events: HashMap, /// State used to keep track of tool use relation tool_use_status: ToolUseStatus, - /// Abstraction that consolidates custom tools with native ones - tool_manager: ToolManager, /// Any failed requests that could be useful for error report/debugging failed_request_ids: Vec, /// Pending prompts to be sent @@ -527,12 +538,23 @@ impl ChatContext { .and_then(|cwd| database.get_conversation_by_path(cwd).ok()) .flatten() { - Some(prior) => { + Some(mut prior) => { existing_conversation = true; input = Some(input.unwrap_or("In a few words, summarize our conversation so far.".to_owned())); + prior.tool_manager = tool_manager; prior }, - None => ConversationState::new(ctx_clone, conversation_id, tool_config, profile, Some(output_clone)).await, + None => { + ConversationState::new( + ctx_clone, + conversation_id, + tool_config, + profile, + Some(output_clone), + tool_manager, + ) + .await + }, }; Ok(Self { @@ -549,7 +571,6 @@ impl ChatContext { conversation_state, tool_use_telemetry_events: HashMap::new(), tool_use_status: ToolUseStatus::Idle, - tool_manager, failed_request_ids: Vec::new(), pending_prompts: VecDeque::new(), }) @@ -764,6 +785,9 @@ impl ChatContext { let ctrl_c_stream = ctrl_c(); debug!(?chat_state, "changing to state"); + // Update conversation state with new tool information + self.conversation_state.update_state().await; + let result = match chat_state { ChatState::PromptUser { tool_uses, @@ -1206,6 +1230,7 @@ impl ChatContext { #[cfg(unix)] if let Some(ref context_manager) = self.conversation_state.context_manager { let tool_names = self + .conversation_state .tool_manager .tn_map .keys() @@ -2176,9 +2201,10 @@ impl ChatContext { match subcommand { Some(ToolsSubcommand::Schema) => { - let schema_json = serde_json::to_string_pretty(&self.tool_manager.schema).map_err(|e| { - ChatError::Custom(format!("Error converting tool schema to string: {e}").into()) - })?; + let schema_json = serde_json::to_string_pretty(&self.conversation_state.tool_manager.schema) + .map_err(|e| { + ChatError::Custom(format!("Error converting tool schema to string: {e}").into()) + })?; queue!(self.output, style::Print(schema_json), style::Print("\n"))?; }, Some(ToolsSubcommand::Trust { tool_names }) => { @@ -2277,7 +2303,7 @@ impl ChatContext { )?; }, Some(ToolsSubcommand::ResetSingle { tool_name }) => { - if self.tool_permissions.has(&tool_name) { + if self.tool_permissions.has(&tool_name) || self.tool_permissions.trust_all { self.tool_permissions.reset_tool(&tool_name); queue!( self.output, @@ -2359,6 +2385,21 @@ impl ChatContext { ); }); + let loading = self.conversation_state.tool_manager.pending_clients().await; + if !loading.is_empty() { + queue!( + self.output, + style::SetAttribute(Attribute::Bold), + style::Print("Servers still loading"), + style::SetAttribute(Attribute::Reset), + style::Print("\n"), + style::Print("▔".repeat(terminal_width)), + )?; + for client in loading { + queue!(self.output, style::Print(format!(" - {client}")), style::Print("\n"))?; + } + } + queue!( self.output, style::Print("\nTrusted tools can be run without confirmation\n"), @@ -2392,7 +2433,7 @@ impl ChatContext { }, Some(PromptsSubcommand::Get { mut get_command }) => { let orig_input = get_command.orig_input.take(); - let prompts = match self.tool_manager.get_prompt(get_command).await { + let prompts = match self.conversation_state.tool_manager.get_prompt(get_command).await { Ok(resp) => resp, Err(e) => { match e { @@ -2479,12 +2520,12 @@ impl ChatContext { _ => None, }; let terminal_width = self.terminal_width(); - let mut prompts_wl = self.tool_manager.prompts.write().map_err(|e| { + let mut prompts_wl = self.conversation_state.tool_manager.prompts.write().map_err(|e| { ChatError::Custom( format!("Poison error encountered while retrieving prompts: {}", e).into(), ) })?; - self.tool_manager.refresh_prompts(&mut prompts_wl)?; + self.conversation_state.tool_manager.refresh_prompts(&mut prompts_wl)?; let mut longest_name = ""; let arg_pos = { let optimal_case = UnicodeWidthStr::width(longest_name) + terminal_width / 4; @@ -2852,11 +2893,9 @@ impl ChatContext { } // If there is an override, we will use it. Otherwise fall back to Tool's default. - let allowed = if self.tool_permissions.has(&tool.name) { - self.tool_permissions.is_trusted(&tool.name) - } else { - !tool.tool.requires_acceptance(&self.ctx) - }; + let allowed = self.tool_permissions.trust_all + || (self.tool_permissions.has(&tool.name) && self.tool_permissions.is_trusted(&tool.name)) + || !tool.tool.requires_acceptance(&self.ctx); if database .settings @@ -3273,7 +3312,7 @@ impl ChatContext { .set_tool_use_id(tool_use_id.clone()) .set_tool_name(tool_use.name.clone()) .utterance_id(self.conversation_state.message_id().map(|s| s.to_string())); - match self.tool_manager.get_tool_from_tool_use(tool_use) { + match self.conversation_state.tool_manager.get_tool_from_tool_use(tool_use) { Ok(mut tool) => { // Apply non-Q-generated context to tools self.contextualize_tool(&mut tool); diff --git a/crates/chat-cli/src/cli/chat/server_messenger.rs b/crates/chat-cli/src/cli/chat/server_messenger.rs new file mode 100644 index 0000000000..3adc665d15 --- /dev/null +++ b/crates/chat-cli/src/cli/chat/server_messenger.rs @@ -0,0 +1,130 @@ +use tokio::sync::mpsc::{ + Receiver, + Sender, + channel, +}; + +use crate::mcp_client::{ + Messenger, + MessengerError, + PromptsListResult, + ResourceTemplatesListResult, + ResourcesListResult, + ToolsListResult, +}; + +#[allow(dead_code)] +#[derive(Clone, Debug)] +pub enum UpdateEventMessage { + ToolsListResult { + server_name: String, + result: ToolsListResult, + }, + PromptsListResult { + server_name: String, + result: PromptsListResult, + }, + ResourcesListResult { + server_name: String, + result: ResourcesListResult, + }, + ResourceTemplatesListResult { + server_name: String, + result: ResourceTemplatesListResult, + }, + InitStart { + server_name: String, + }, +} + +#[derive(Clone, Debug)] +pub struct ServerMessengerBuilder { + pub update_event_sender: Sender, +} + +impl ServerMessengerBuilder { + pub fn new(capacity: usize) -> (Receiver, Self) { + let (tx, rx) = channel::(capacity); + let this = Self { + update_event_sender: tx, + }; + (rx, this) + } + + pub fn build_with_name(&self, server_name: String) -> ServerMessenger { + ServerMessenger { + server_name, + update_event_sender: self.update_event_sender.clone(), + } + } +} + +#[derive(Clone, Debug)] +pub struct ServerMessenger { + pub server_name: String, + pub update_event_sender: Sender, +} + +#[async_trait::async_trait] +impl Messenger for ServerMessenger { + async fn send_tools_list_result(&self, result: ToolsListResult) -> Result<(), MessengerError> { + Ok(self + .update_event_sender + .send(UpdateEventMessage::ToolsListResult { + server_name: self.server_name.clone(), + result, + }) + .await + .map_err(|e| MessengerError::Custom(e.to_string()))?) + } + + async fn send_prompts_list_result(&self, result: PromptsListResult) -> Result<(), MessengerError> { + Ok(self + .update_event_sender + .send(UpdateEventMessage::PromptsListResult { + server_name: self.server_name.clone(), + result, + }) + .await + .map_err(|e| MessengerError::Custom(e.to_string()))?) + } + + async fn send_resources_list_result(&self, result: ResourcesListResult) -> Result<(), MessengerError> { + Ok(self + .update_event_sender + .send(UpdateEventMessage::ResourcesListResult { + server_name: self.server_name.clone(), + result, + }) + .await + .map_err(|e| MessengerError::Custom(e.to_string()))?) + } + + async fn send_resource_templates_list_result( + &self, + result: ResourceTemplatesListResult, + ) -> Result<(), MessengerError> { + Ok(self + .update_event_sender + .send(UpdateEventMessage::ResourceTemplatesListResult { + server_name: self.server_name.clone(), + result, + }) + .await + .map_err(|e| MessengerError::Custom(e.to_string()))?) + } + + async fn send_init_msg(&self) -> Result<(), MessengerError> { + Ok(self + .update_event_sender + .send(UpdateEventMessage::InitStart { + server_name: self.server_name.clone(), + }) + .await + .map_err(|e| MessengerError::Custom(e.to_string()))?) + } + + fn duplicate(&self) -> Box { + Box::new(self.clone()) + } +} diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index e7bc910ac9..cbaf9be1ec 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -1,10 +1,19 @@ -use std::collections::HashMap; +use std::collections::{ + HashMap, + HashSet, +}; +use std::future::Future; use std::hash::{ DefaultHasher, Hasher, }; use std::io::Write; use std::path::PathBuf; +use std::pin::Pin; +use std::sync::atomic::{ + AtomicBool, + Ordering, +}; use std::sync::mpsc::RecvTimeoutError; use std::sync::{ Arc, @@ -21,40 +30,54 @@ use crossterm::{ }; use futures::{ StreamExt, + future, stream, }; +use regex::Regex; use serde::{ Deserialize, Serialize, }; use thiserror::Error; -use tokio::sync::Mutex; -use tracing::error; +use tokio::signal::ctrl_c; +use tokio::sync::{ + Mutex, + RwLock, +}; +use tracing::{ + error, + warn, +}; -use super::command::PromptsGetCommand; -use super::message::AssistantToolUse; -use super::tools::custom_tool::{ +use crate::api_client::model::{ + ToolResult, + ToolResultContentBlock, + ToolResultStatus, +}; +use crate::cli::chat::command::PromptsGetCommand; +use crate::cli::chat::message::AssistantToolUse; +use crate::cli::chat::server_messenger::{ + ServerMessengerBuilder, + UpdateEventMessage, +}; +use crate::cli::chat::tools::custom_tool::{ CustomTool, CustomToolClient, CustomToolConfig, }; -use super::tools::execute_bash::ExecuteBash; -use super::tools::fs_read::FsRead; -use super::tools::fs_write::FsWrite; -use super::tools::gh_issue::GhIssue; -use super::tools::thinking::Thinking; -use super::tools::use_aws::UseAws; -use super::tools::{ +use crate::cli::chat::tools::execute_bash::ExecuteBash; +use crate::cli::chat::tools::fs_read::FsRead; +use crate::cli::chat::tools::fs_write::FsWrite; +use crate::cli::chat::tools::gh_issue::GhIssue; +use crate::cli::chat::tools::thinking::Thinking; +use crate::cli::chat::tools::use_aws::UseAws; +use crate::cli::chat::tools::{ Tool, ToolOrigin, ToolSpec, }; -use crate::api_client::model::{ - ToolResult, - ToolResultContentBlock, - ToolResultStatus, -}; use crate::database::Database; +use crate::database::settings::Setting; use crate::mcp_client::{ JsonRpcResponse, PromptGet, @@ -102,6 +125,10 @@ enum LoadingMsg { /// Represents a warning that occurred during tool initialization. /// Contains the name of the server that generated the warning and the warning message. Warn { name: String, msg: eyre::Report }, + /// Signals that the loading display thread should terminate. + /// This is sent when all tool initialization is complete or when the application is shutting + /// down. + Terminate, } /// Represents the state of a loading indicator for a tool being initialized. @@ -112,6 +139,7 @@ enum LoadingMsg { /// * `init_time` - When initialization for this tool began, used to calculate load time struct StatusLine { init_time: std::time::Instant, + is_done: bool, } // This is to mirror claude's config set up @@ -182,6 +210,7 @@ pub struct ToolManagerBuilder { prompt_list_sender: Option>>, prompt_list_receiver: Option>>, conversation_id: Option, + is_interactive: bool, } impl ToolManagerBuilder { @@ -205,12 +234,22 @@ impl ToolManagerBuilder { self } - pub async fn build(mut self, telemetry: &TelemetryThread) -> eyre::Result { + pub fn interactive(mut self, is_interactive: bool) -> Self { + self.is_interactive = is_interactive; + self + } + + pub async fn build( + mut self, + telemetry: &TelemetryThread, + mut output: Box, + ) -> eyre::Result { let McpServerConfig { mcp_servers } = self.mcp_server_config.ok_or(eyre::eyre!("Missing mcp server config"))?; debug_assert!(self.conversation_id.is_some()); let conversation_id = self.conversation_id.ok_or(eyre::eyre!("Missing conversation id"))?; let regex = regex::Regex::new(VALID_TOOL_NAME)?; let mut hasher = DefaultHasher::new(); + let is_interactive = self.is_interactive; let pre_initialized = mcp_servers .into_iter() .map(|(server_name, server_config)| { @@ -223,11 +262,11 @@ impl ToolManagerBuilder { // Send up task to update user on server loading status let (tx, rx) = std::sync::mpsc::channel::(); - // Using a hand rolled thread because it's just easier to do this than do deal with the Send - // requirements that comes with holding onto the stdout lock. - let loading_display_task = std::thread::spawn(move || { - let stdout = std::io::stdout(); - let mut stdout_lock = stdout.lock(); + // TODO: rather than using it as an "anchor" to determine the progress of server loads, we + // should make this task optional (and it is defined as an optional right now. There is + // just no code path with it being None). When ran with no-interactive mode, we really do + // not have a need to run this task. + let loading_display_task = tokio::task::spawn_blocking(move || { let mut loading_servers = HashMap::::new(); let mut spinner_logo_idx: usize = 0; let mut complete: usize = 0; @@ -237,67 +276,108 @@ impl ToolManagerBuilder { Ok(recv_result) => match recv_result { LoadingMsg::Add(name) => { let init_time = std::time::Instant::now(); - let status_line = StatusLine { init_time }; - execute!(stdout_lock, cursor::MoveToColumn(0))?; + let is_done = false; + let status_line = StatusLine { init_time, is_done }; + execute!(output, cursor::MoveToColumn(0))?; if !loading_servers.is_empty() { // TODO: account for terminal width - execute!(stdout_lock, cursor::MoveUp(1))?; + execute!(output, cursor::MoveUp(1))?; } loading_servers.insert(name.clone(), status_line); let total = loading_servers.len(); - execute!(stdout_lock, terminal::Clear(terminal::ClearType::CurrentLine))?; - queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; - stdout_lock.flush()?; + execute!(output, terminal::Clear(terminal::ClearType::CurrentLine))?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; + output.flush()?; }, LoadingMsg::Done(name) => { - if let Some(status_line) = loading_servers.get(&name) { + if let Some(status_line) = loading_servers.get_mut(&name) { + status_line.is_done = true; complete += 1; let time_taken = (std::time::Instant::now() - status_line.init_time).as_secs_f64().abs(); let time_taken = format!("{:.2}", time_taken); execute!( - stdout_lock, + output, cursor::MoveToColumn(0), cursor::MoveUp(1), terminal::Clear(terminal::ClearType::CurrentLine), )?; - queue_success_message(&name, &time_taken, &mut stdout_lock)?; + queue_success_message(&name, &time_taken, &mut output)?; let total = loading_servers.len(); - queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; - stdout_lock.flush()?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; + output.flush()?; + } + if loading_servers.iter().all(|(_, status)| status.is_done) { + break; } }, LoadingMsg::Error { name, msg } => { - failed += 1; - execute!( - stdout_lock, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - queue_failure_message(&name, &msg, &mut stdout_lock)?; - let total = loading_servers.len(); - queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; + if let Some(status_line) = loading_servers.get_mut(&name) { + status_line.is_done = true; + failed += 1; + execute!( + output, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + queue_failure_message(&name, &msg, &mut output)?; + let total = loading_servers.len(); + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; + } + if loading_servers.iter().all(|(_, status)| status.is_done) { + break; + } }, LoadingMsg::Warn { name, msg } => { - complete += 1; - execute!( - stdout_lock, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - let msg = eyre::eyre!(msg.to_string()); - queue_warn_message(&name, &msg, &mut stdout_lock)?; - let total = loading_servers.len(); - queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; - stdout_lock.flush()?; + if let Some(status_line) = loading_servers.get_mut(&name) { + status_line.is_done = true; + complete += 1; + execute!( + output, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + let msg = eyre::eyre!(msg.to_string()); + queue_warn_message(&name, &msg, &mut output)?; + let total = loading_servers.len(); + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; + output.flush()?; + } + if loading_servers.iter().all(|(_, status)| status.is_done) { + break; + } + }, + LoadingMsg::Terminate => { + if loading_servers.iter().any(|(_, status)| !status.is_done) { + execute!( + output, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + let msg = + loading_servers + .iter() + .fold(String::new(), |mut acc, (server_name, status)| { + if !status.is_done { + acc.push_str(format!("\n - {server_name}").as_str()); + } + acc + }); + let msg = eyre::eyre!(msg); + let total = loading_servers.len(); + queue_incomplete_load_message(complete, total, &msg, &mut output)?; + output.flush()?; + } + break; }, }, Err(RecvTimeoutError::Timeout) => { spinner_logo_idx = (spinner_logo_idx + 1) % SPINNER_CHARS.len(); execute!( - stdout_lock, + output, cursor::SavePosition, cursor::MoveToColumn(0), cursor::MoveUp(1), @@ -311,10 +391,90 @@ impl ToolManagerBuilder { Ok::<_, eyre::Report>(()) }); let mut clients = HashMap::>::new(); + let mut load_msg_sender = Some(tx.clone()); + let conv_id_clone = conversation_id.clone(); + let regex = Arc::new(Regex::new(VALID_TOOL_NAME)?); + let new_tool_specs = Arc::new(Mutex::new(HashMap::new())); + let new_tool_specs_clone = new_tool_specs.clone(); + let has_new_stuff = Arc::new(AtomicBool::new(false)); + let has_new_stuff_clone = has_new_stuff.clone(); + let pending = Arc::new(RwLock::new(HashSet::::new())); + let pending_clone = pending.clone(); + let (mut msg_rx, messenger_builder) = ServerMessengerBuilder::new(20); + let telemetry_clone = telemetry.clone(); + tokio::spawn(async move { + while let Some(msg) = msg_rx.recv().await { + // For now we will treat every list result as if they contain the + // complete set of tools. This is not necessarily true in the future when + // request method on the mcp client no longer buffers all the pages from + // list calls. + match msg { + UpdateEventMessage::ToolsListResult { server_name, result } => { + pending_clone.write().await.remove(&server_name); + let mut specs = result + .tools + .into_iter() + .filter_map(|v| serde_json::from_value::(v).ok()) + .collect::>(); + let mut sanitized_mapping = HashMap::::new(); + if let Some(load_msg) = process_tool_specs( + conv_id_clone.as_str(), + &server_name, + load_msg_sender.is_some(), + &mut specs, + &mut sanitized_mapping, + ®ex, + &telemetry_clone, + ) { + let mut has_errored = false; + if let Some(sender) = &load_msg_sender { + if let Err(e) = sender.send(load_msg) { + warn!( + "Error sending update message to display task: {:?}\nAssume display task has completed", + e + ); + has_errored = true; + } + } + if has_errored { + load_msg_sender.take(); + } + } + new_tool_specs_clone + .lock() + .await + .insert(server_name, (sanitized_mapping, specs)); + // We only want to set this flag when the display task has ended + if load_msg_sender.is_none() { + has_new_stuff_clone.store(true, Ordering::Release); + } + }, + UpdateEventMessage::PromptsListResult { + server_name: _, + result: _, + } => {}, + UpdateEventMessage::ResourcesListResult { + server_name: _, + result: _, + } => {}, + UpdateEventMessage::ResourceTemplatesListResult { + server_name: _, + result: _, + } => {}, + UpdateEventMessage::InitStart { server_name } => { + pending_clone.write().await.insert(server_name.clone()); + if let Some(sender) = &load_msg_sender { + let _ = sender.send(LoadingMsg::Add(server_name)); + } + }, + } + } + }); for (mut name, init_res) in pre_initialized { - let _ = tx.send(LoadingMsg::Add(name.clone())); match init_res { - Ok(client) => { + Ok(mut client) => { + let messenger = messenger_builder.build_with_name(client.get_server_name().to_owned()); + client.assign_messenger(Box::new(messenger)); let mut client = Arc::new(client); while let Some(collided_client) = clients.insert(name.clone(), client) { // to avoid server name collision we are going to circumvent this by @@ -433,7 +593,11 @@ impl ToolManagerBuilder { clients, prompts, loading_display_task, + pending_clients: pending, loading_status_sender, + new_tool_specs, + has_new_stuff, + is_interactive, ..Default::default() }) } @@ -461,10 +625,12 @@ enum OutOfSpecName { EmptyDescription(String), } +type NewToolSpecs = Arc, Vec)>>>; + +#[derive(Default, Debug)] /// Manages the lifecycle and interactions with tools from various sources, including MCP servers. /// This struct is responsible for initializing tools, handling tool requests, and maintaining /// a cache of available prompts from connected servers. -#[derive(Default)] pub struct ToolManager { /// Unique identifier for the current conversation. /// This ID is used to track and associate tools with a specific chat session. @@ -474,6 +640,20 @@ pub struct ToolManager { /// These clients are used to communicate with MCP servers. pub clients: HashMap>, + /// A list of client names that are still in the process of being initialized + pub pending_clients: Arc>>, + + /// Flag indicating whether new tool specifications have been added since the last update. + /// When set to true, it signals that the tool manager needs to refresh its internal state + /// to incorporate newly available tools from MCP servers. + pub has_new_stuff: Arc, + + /// Storage for newly discovered tool specifications from MCP servers that haven't yet been + /// integrated into the main tool registry. This field holds a thread-safe reference to a map + /// of server names to their tool specifications and name mappings, allowing concurrent updates + /// from server initialization processes. + new_tool_specs: NewToolSpecs, + /// Cache for prompts collected from different servers. /// Key: prompt name /// Value: a list of PromptBundle that has a prompt of this name. @@ -483,7 +663,7 @@ pub struct ToolManager { /// Handle to the thread that displays loading status for tool initialization. /// This thread provides visual feedback to users during the tool loading process. - loading_display_task: Option>>, + loading_display_task: Option>>, /// Channel sender for communicating with the loading display thread. /// Used to send status updates about tool initialization progress. @@ -498,177 +678,96 @@ pub struct ToolManager { /// This is mainly used to show the user what the tools look like from the perspective of the /// model. pub schema: HashMap, + + is_interactive: bool, +} + +impl Clone for ToolManager { + fn clone(&self) -> Self { + Self { + conversation_id: self.conversation_id.clone(), + clients: self.clients.clone(), + has_new_stuff: self.has_new_stuff.clone(), + new_tool_specs: self.new_tool_specs.clone(), + prompts: self.prompts.clone(), + tn_map: self.tn_map.clone(), + schema: self.schema.clone(), + is_interactive: self.is_interactive, + ..Default::default() + } + } } impl ToolManager { - pub async fn load_tools( - &mut self, - database: &Database, - telemetry: &TelemetryThread, - ) -> eyre::Result> { + pub async fn load_tools(&mut self, database: &Database) -> eyre::Result> { let tx = self.loading_status_sender.take(); let display_task = self.loading_display_task.take(); - let tool_specs = { + self.schema = { let mut tool_specs = serde_json::from_str::>(include_str!("tools/tool_index.json"))?; if !crate::cli::chat::tools::thinking::Thinking::is_enabled(database) { tool_specs.remove("thinking"); } - Arc::new(Mutex::new(tool_specs)) + tool_specs }; - let conversation_id = self.conversation_id.clone(); - let regex = Arc::new(regex::Regex::new(VALID_TOOL_NAME)?); - - let load_tool = self + let load_tools = self .clients - .iter() - .map(|(server_name, client)| { - let telemetry = telemetry.clone(); - let client_clone = client.clone(); - let server_name_clone = server_name.clone(); - let tx_clone = tx.clone(); - let regex_clone = regex.clone(); - let tool_specs_clone = tool_specs.clone(); - let conversation_id = conversation_id.clone(); - async move { - let tool_spec = client_clone.init().await; - let mut sanitized_mapping = HashMap::::new(); - match tool_spec { - Ok((server_name, specs)) => { - // Each mcp server might have multiple tools. - // To avoid naming conflicts we are going to namespace it. - // This would also help us locate which mcp server to call the tool from. - let mut out_of_spec_tool_names = Vec::::new(); - let mut hasher = DefaultHasher::new(); - let number_of_tools = specs.len(); - // Sanitize tool names to ensure they comply with the naming requirements: - // 1. If the name already matches the regex pattern and doesn't contain the namespace delimiter, use it as is - // 2. Otherwise, remove invalid characters and handle special cases: - // - Remove namespace delimiters - // - Ensure the name starts with an alphabetic character - // - Generate a hash-based name if the sanitized result is empty - // This ensures all tool names are valid identifiers that can be safely used in the system - // If after all of the aforementioned modification the combined tool - // name we have exceeds a length of 64, we surface it as an error - for mut spec in specs { - let sn = if !regex_clone.is_match(&spec.name) { - let mut sn = sanitize_name(spec.name.clone(), ®ex_clone, &mut hasher); - while sanitized_mapping.contains_key(&sn) { - sn.push('1'); - } - sn - } else { - spec.name.clone() - }; - let full_name = format!("{}{}{}", server_name, NAMESPACE_DELIMITER, sn); - if full_name.len() > 64 { - out_of_spec_tool_names.push(OutOfSpecName::TooLong(spec.name)); - continue; - } else if spec.description.is_empty() { - out_of_spec_tool_names.push(OutOfSpecName::EmptyDescription(spec.name)); - continue; - } - if sn != spec.name { - sanitized_mapping.insert(full_name.clone(), format!("{}{}{}", server_name, NAMESPACE_DELIMITER, spec.name)); - } - spec.name = full_name; - spec.tool_origin = ToolOrigin::McpServer(server_name.clone()); - tool_specs_clone.lock().await.insert(spec.name.clone(), spec); - } - - // Send server load success metric datum - telemetry.send_mcp_server_init(conversation_id, None, number_of_tools).ok(); - - // Tool name translation. This is beyond of the scope of what is - // considered a "server load". Reasoning being: - // - Failures here are not related to server load - // - There is not a whole lot we can do with this data - if let Some(tx_clone) = &tx_clone { - let send_result = if !out_of_spec_tool_names.is_empty() { - let msg = out_of_spec_tool_names.iter().fold( - String::from("The following tools are out of spec. They will be excluded from the list of available tools:\n"), - |mut acc, name| { - let (tool_name, msg) = match name { - OutOfSpecName::TooLong(tool_name) => (tool_name.as_str(), "tool name exceeds max length of 64 when combined with server name"), - OutOfSpecName::IllegalChar(tool_name) => (tool_name.as_str(), "tool name must be compliant with ^[a-zA-Z][a-zA-Z0-9_]*$"), - OutOfSpecName::EmptyDescription(tool_name) => (tool_name.as_str(), "tool schema contains empty description"), - }; - acc.push_str(format!(" - {} ({})\n", tool_name, msg).as_str()); - acc - } - ); - tx_clone.send(LoadingMsg::Error { - name: server_name.clone(), - msg: eyre::eyre!(msg), - }) - // TODO: if no tools are valid, we need to offload the server - // from the fleet (i.e. kill the server) - } else if !sanitized_mapping.is_empty() { - let warn = sanitized_mapping.iter().fold(String::from("The following tool names are changed:\n"), |mut acc, (k, v)| { - acc.push_str(format!(" - {} -> {}\n", v, k).as_str()); - acc - }); - tx_clone.send(LoadingMsg::Warn { - name: server_name.clone(), - msg: eyre::eyre!(warn), - }) - } else { - tx_clone.send(LoadingMsg::Done(server_name.clone())) - }; - if let Err(e) = send_result { - error!("Error while sending status update to display task: {:?}", e); - } - } - }, - Err(e) => { - error!("Error obtaining tool spec for {}: {:?}", server_name_clone, e); - let init_failure_reason = Some(e.to_string()); - telemetry.send_mcp_server_init(conversation_id, init_failure_reason, 0).ok(); - if let Some(tx_clone) = &tx_clone { - if let Err(e) = tx_clone.send(LoadingMsg::Error { - name: server_name_clone, - msg: e, - }) { - error!("Error while sending status update to display task: {:?}", e); - } - } - }, - } - Ok::<_, eyre::Report>(Some(sanitized_mapping)) - } + .values() + .map(|c| { + let clone = Arc::clone(c); + async move { clone.init().await } }) .collect::>(); - // TODO: do we want to introduce a timeout here? - self.tn_map = stream::iter(load_tool) - .map(|async_closure| tokio::task::spawn(async_closure)) - .buffer_unordered(20) - .collect::>() - .await - .into_iter() - .filter_map(|r| r.ok()) - .filter_map(|r| r.ok()) - .flatten() - .flatten() - .collect::>(); - drop(tx); - if let Some(display_task) = display_task { - if let Err(e) = display_task.join() { - error!("Error while joining status display task: {:?}", e); - } - } - let tool_specs = { - let mutex = - Arc::try_unwrap(tool_specs).map_err(|e| eyre::eyre!("Error unwrapping arc for tool specs {:?}", e))?; - mutex.into_inner() + let initial_poll = stream::iter(load_tools) + .map(|async_closure| tokio::spawn(async_closure)) + .buffer_unordered(20); + tokio::spawn(async move { + initial_poll.collect::>().await; + }); + // We need to cast it to erase the type otherwise the compiler will default to static + // dispatch, which would result in an error of inconsistent match arm return type. + let display_fut: Pin>> = match display_task { + Some(display_task) => { + let fut = async move { + if let Err(e) = display_task.await { + error!("Error while joining status display task: {:?}", e); + } + }; + Box::pin(fut) + }, + None => Box::pin(future::pending()), }; - // caching the tool names for skim operations - for tool_name in tool_specs.keys() { - if !self.tn_map.contains_key(tool_name) { - self.tn_map.insert(tool_name.clone(), tool_name.clone()); + let timeout_fut: Pin>> = if self.clients.is_empty() { + // If there is no server loaded, we want to resolve immediately + Box::pin(future::ready(())) + } else if self.is_interactive { + let init_timeout = database + .settings + .get_int(Setting::McpInitTimeout) + .map_or(5000_u64, |s| s as u64); + Box::pin(tokio::time::sleep(std::time::Duration::from_millis(init_timeout))) + } else { + Box::pin(future::pending()) + }; + tokio::select! { + _ = display_fut => {}, + _ = timeout_fut => { + if let Some(tx) = tx { + let _ = tx.send(LoadingMsg::Terminate); + } + }, + _ = ctrl_c() => { + if self.is_interactive { + if let Some(tx) = tx { + let _ = tx.send(LoadingMsg::Terminate); + } + } else { + return Err(eyre::eyre!("User interrupted mcp server loading in non-interactive mode. Ending.")); + } } } - self.schema = tool_specs.clone(); - Ok(tool_specs) + self.update().await; + Ok(self.schema.clone()) } pub fn get_tool_from_tool_use(&self, value: AssistantToolUse) -> Result { @@ -767,6 +866,49 @@ impl ToolManager { }) } + /// Updates tool managers various states with new information + pub async fn update(&mut self) { + // A hashmap of + let mut tool_specs = HashMap::::new(); + let new_tools = { + let mut new_tool_specs = self.new_tool_specs.lock().await; + new_tool_specs.drain().fold(HashMap::new(), |mut acc, (k, v)| { + acc.insert(k, v); + acc + }) + }; + let mut updated_servers = HashSet::::new(); + for (_server_name, (tool_name_map, specs)) in new_tools { + // In a populated tn map (i.e. a partially initialized or outdated fleet of servers) there + // will be incoming tools with names that are already in the tn map, we will be writing + // over them (perhaps with the same information that they already had), and that's okay. + // In an event where a server has removed tools, the tools that are no longer available + // will linger in this map. This is also okay to not clean up as it does not affect the + // look up of tool names that are still active. + for (k, v) in tool_name_map { + self.tn_map.insert(k, v); + } + if let Some(spec) = specs.first() { + updated_servers.insert(spec.tool_origin.clone()); + } + for spec in specs { + tool_specs.insert(spec.name.clone(), spec); + } + } + // Caching the tool names for skim operations + for tool_name in tool_specs.keys() { + if !self.tn_map.contains_key(tool_name) { + self.tn_map.insert(tool_name.clone(), tool_name.clone()); + } + } + // Update schema + // As we are writing over the ensemble of tools in a given server, we will need to first + // remove everything that it has. + self.schema + .retain(|_tool_name, spec| !updated_servers.contains(&spec.tool_origin)); + self.schema.extend(tool_specs); + } + #[allow(clippy::await_holding_lock)] pub async fn get_prompt(&self, get_command: PromptsGetCommand) -> Result { let (server_name, prompt_name) = match get_command.params.name.split_once('/') { @@ -927,6 +1069,131 @@ impl ToolManager { ); Ok(()) } + + pub async fn pending_clients(&self) -> Vec { + self.pending_clients.read().await.iter().cloned().collect::>() + } +} + +#[inline] +fn process_tool_specs( + conversation_id: &str, + server_name: &str, + is_in_display: bool, + specs: &mut Vec, + tn_map: &mut HashMap, + regex: &Arc, + telemetry: &TelemetryThread, +) -> Option { + // Each mcp server might have multiple tools. + // To avoid naming conflicts we are going to namespace it. + // This would also help us locate which mcp server to call the tool from. + let mut out_of_spec_tool_names = Vec::::new(); + let mut hasher = DefaultHasher::new(); + let number_of_tools = specs.len(); + // Sanitize tool names to ensure they comply with the naming requirements: + // 1. If the name already matches the regex pattern and doesn't contain the namespace delimiter, use + // it as is + // 2. Otherwise, remove invalid characters and handle special cases: + // - Remove namespace delimiters + // - Ensure the name starts with an alphabetic character + // - Generate a hash-based name if the sanitized result is empty + // This ensures all tool names are valid identifiers that can be safely used in the system + // If after all of the aforementioned modification the combined tool + // name we have exceeds a length of 64, we surface it as an error + for spec in specs { + let sn = if !regex.is_match(&spec.name) { + let mut sn = sanitize_name(spec.name.clone(), regex, &mut hasher); + while tn_map.contains_key(&sn) { + sn.push('1'); + } + sn + } else { + spec.name.clone() + }; + let full_name = format!("{}{}{}", server_name, NAMESPACE_DELIMITER, sn); + if full_name.len() > 64 { + out_of_spec_tool_names.push(OutOfSpecName::TooLong(spec.name.clone())); + continue; + } else if spec.description.is_empty() { + out_of_spec_tool_names.push(OutOfSpecName::EmptyDescription(spec.name.clone())); + continue; + } + if sn != spec.name { + tn_map.insert( + full_name.clone(), + format!("{}{}{}", server_name, NAMESPACE_DELIMITER, spec.name), + ); + } + spec.name = full_name; + spec.tool_origin = ToolOrigin::McpServer(server_name.to_string()); + } + // Send server load success metric datum + let conversation_id = conversation_id.to_string(); + let _ = telemetry.send_mcp_server_init(conversation_id, None, number_of_tools); + // Tool name translation. This is beyond of the scope of what is + // considered a "server load". Reasoning being: + // - Failures here are not related to server load + // - There is not a whole lot we can do with this data + let loading_msg = if !out_of_spec_tool_names.is_empty() { + let msg = out_of_spec_tool_names.iter().fold( + String::from( + "The following tools are out of spec. They will be excluded from the list of available tools:\n", + ), + |mut acc, name| { + let (tool_name, msg) = match name { + OutOfSpecName::TooLong(tool_name) => ( + tool_name.as_str(), + "tool name exceeds max length of 64 when combined with server name", + ), + OutOfSpecName::IllegalChar(tool_name) => ( + tool_name.as_str(), + "tool name must be compliant with ^[a-zA-Z][a-zA-Z0-9_]*$", + ), + OutOfSpecName::EmptyDescription(tool_name) => { + (tool_name.as_str(), "tool schema contains empty description") + }, + }; + acc.push_str(format!(" - {} ({})\n", tool_name, msg).as_str()); + acc + }, + ); + error!( + "Server {} finished loading with the following error: \n{}", + server_name, msg + ); + if is_in_display { + Some(LoadingMsg::Warn { + name: server_name.to_string(), + msg: eyre::eyre!(msg), + }) + } else { + None + } + // TODO: if no tools are valid, we need to offload the server + // from the fleet (i.e. kill the server) + } else if !tn_map.is_empty() { + let warn = tn_map.iter().fold( + String::from("The following tool names are changed:\n"), + |mut acc, (k, v)| { + acc.push_str(format!(" - {} -> {}\n", v, k).as_str()); + acc + }, + ); + if is_in_display { + Some(LoadingMsg::Warn { + name: server_name.to_string(), + msg: eyre::eyre!(warn), + }) + } else { + None + } + } else if is_in_display { + Some(LoadingMsg::Done(server_name.to_string())) + } else { + None + }; + loading_msg } fn sanitize_name(orig: String, regex: ®ex::Regex, hasher: &mut impl Hasher) -> String { @@ -993,7 +1260,7 @@ fn queue_init_message( } else { queue!(output, style::Print(SPINNER_CHARS[spinner_logo_idx]))?; } - Ok(queue!( + queue!( output, style::SetForegroundColor(style::Color::Blue), style::Print(format!(" {}", complete)), @@ -1002,11 +1269,22 @@ fn queue_init_message( style::SetForegroundColor(style::Color::Blue), style::Print(format!("{} ", total)), style::ResetColor, - style::Print("mcp servers initialized\n"), - )?) + style::Print("mcp servers initialized."), + )?; + if total > complete + failed { + queue!( + output, + style::SetForegroundColor(style::Color::Blue), + style::Print(" ctrl-c "), + style::ResetColor, + style::Print("to start chatting now") + )?; + } + Ok(queue!(output, style::Print("\n"))?) } fn queue_failure_message(name: &str, fail_load_msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { + use crate::util::CHAT_BINARY_NAME; Ok(queue!( output, style::SetForegroundColor(style::Color::Red), @@ -1017,7 +1295,9 @@ fn queue_failure_message(name: &str, fail_load_msg: &eyre::Report, output: &mut style::Print(" has failed to load:\n- "), style::Print(fail_load_msg), style::Print("\n"), - style::Print("- run with Q_LOG_LEVEL=trace and see $TMPDIR/qlog for detail\n"), + style::Print(format!( + "- run with Q_LOG_LEVEL=trace and see $TMPDIR/{CHAT_BINARY_NAME} for detail\n" + )), style::ResetColor, )?) } @@ -1036,6 +1316,32 @@ fn queue_warn_message(name: &str, msg: &eyre::Report, output: &mut impl Write) - )?) } +fn queue_incomplete_load_message( + complete: usize, + total: usize, + msg: &eyre::Report, + output: &mut impl Write, +) -> eyre::Result<()> { + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Yellow), + style::Print("⚠"), + style::SetForegroundColor(style::Color::Blue), + style::Print(format!(" {}", complete)), + style::ResetColor, + style::Print(" of "), + style::SetForegroundColor(style::Color::Blue), + style::Print(format!("{} ", total)), + style::ResetColor, + style::Print("mcp servers initialized."), + style::ResetColor, + // We expect the message start with a newline + style::Print(" Servers still loading:"), + style::Print(msg), + style::ResetColor, + )?) +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index 43580fecea..95fe96d01c 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -15,10 +15,7 @@ use serde::{ use tokio::sync::RwLock; use tracing::warn; -use super::{ - InvokeOutput, - ToolSpec, -}; +use super::InvokeOutput; use crate::cli::chat::CONTINUATION_LINE; use crate::cli::chat::token_counter::TokenCounter; use crate::mcp_client::{ @@ -27,6 +24,7 @@ use crate::mcp_client::{ JsonRpcResponse, JsonRpcStdioTransport, MessageContent, + Messenger, PromptGet, ServerCapabilities, StdioTransport, @@ -87,34 +85,31 @@ impl CustomToolClient { }) } - pub async fn init(&self) -> Result<(String, Vec)> { + pub async fn init(&self) -> Result<()> { match self { CustomToolClient::Stdio { client, - server_name, server_capabilities, + .. } => { + if let Some(messenger) = &client.messenger { + let _ = messenger.send_init_msg().await; + } // We'll need to first initialize. This is the handshake every client and server // needs to do before proceeding to anything else - let init_resp = client.init().await?; + let cap = client.init().await?; // We'll be scrapping this for background server load: https://github.com/aws/amazon-q-developer-cli/issues/1466 // So don't worry about the tidiness for now - let is_tool_supported = init_resp - .get("result") - .is_some_and(|r| r.get("capabilities").is_some_and(|cap| cap.get("tools").is_some())); - server_capabilities.write().await.replace(init_resp); - // Assuming a shape of return as per https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#listing-tools - let tools = if is_tool_supported { - // And now we make the server tell us what tools they have - let resp = client.request("tools/list", None).await?; - match resp.result.and_then(|r| r.get("tools").cloned()) { - Some(value) => serde_json::from_value::>(value)?, - None => Default::default(), - } - } else { - Default::default() - }; - Ok((server_name.clone(), tools)) + server_capabilities.write().await.replace(cap); + Ok(()) + }, + } + } + + pub fn assign_messenger(&mut self, messenger: Box) { + match self { + CustomToolClient::Stdio { client, .. } => { + client.messenger = Some(messenger); }, } } @@ -177,9 +172,15 @@ impl CustomTool { pub async fn invoke(&self, _ctx: &Context, _updates: &mut impl Write) -> Result { // Assuming a response shape as per https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#calling-tools let resp = self.client.request(self.method.as_str(), self.params.clone()).await?; - let result = resp - .result - .ok_or(eyre::eyre!("{} invocation failed to produce a result", self.name))?; + let result = match resp.result { + Some(result) => result, + None => { + let failure = resp.error.map_or("Unknown error encountered".to_string(), |err| { + serde_json::to_string(&err).unwrap_or_default() + }); + return Err(eyre::eyre!(failure)); + }, + }; match serde_json::from_value::(result.clone()) { Ok(mut de_result) => { diff --git a/crates/chat-cli/src/cli/chat/tools/mod.rs b/crates/chat-cli/src/cli/chat/tools/mod.rs index 3174de13ff..469ba90855 100644 --- a/crates/chat-cli/src/cli/chat/tools/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/mod.rs @@ -122,30 +122,33 @@ pub struct ToolPermission { /// Tools that do not have an associated ToolPermission should use /// their default logic to determine to permission. pub struct ToolPermissions { + // We need this field for any stragglers + pub trust_all: bool, pub permissions: HashMap, } impl ToolPermissions { pub fn new(capacity: usize) -> Self { Self { + trust_all: false, permissions: HashMap::with_capacity(capacity), } } pub fn is_trusted(&self, tool_name: &str) -> bool { - self.permissions.get(tool_name).is_some_and(|perm| perm.trusted) + self.trust_all || self.permissions.get(tool_name).is_some_and(|perm| perm.trusted) } /// Returns a label to describe the permission status for a given tool. pub fn display_label(&self, tool_name: &str) -> String { - if self.has(tool_name) { + if self.has(tool_name) || self.trust_all { if self.is_trusted(tool_name) { format!(" {}", "trusted".dark_green().bold()) } else { format!(" {}", "not trusted".dark_grey()) } } else { - Self::default_permission_label(tool_name) + self.default_permission_label(tool_name) } } @@ -155,15 +158,18 @@ impl ToolPermissions { } pub fn untrust_tool(&mut self, tool_name: &str) { + self.trust_all = false; self.permissions .insert(tool_name.to_string(), ToolPermission { trusted: false }); } pub fn reset(&mut self) { + self.trust_all = false; self.permissions.clear(); } pub fn reset_tool(&mut self, tool_name: &str) { + self.trust_all = false; self.permissions.remove(tool_name); } @@ -174,7 +180,7 @@ impl ToolPermissions { /// Provide default permission labels for the built-in set of tools. /// Unknown tools are assumed to be "Per-request" // This "static" way avoids needing to construct a tool instance. - fn default_permission_label(tool_name: &str) -> String { + fn default_permission_label(&self, tool_name: &str) -> String { let label = match tool_name { "fs_read" => "trusted".dark_green().bold(), "fs_write" => "not trusted".dark_grey(), @@ -182,6 +188,7 @@ impl ToolPermissions { "use_aws" => "trust read-only commands".dark_grey(), "report_issue" => "trusted".dark_green().bold(), "thinking" => "trusted (prerelease)".dark_green().bold(), + _ if self.trust_all => "trusted".dark_grey().bold(), _ => "not trusted".dark_grey(), }; diff --git a/crates/chat-cli/src/database/settings.rs b/crates/chat-cli/src/database/settings.rs index eb14c4c684..b2beaf0887 100644 --- a/crates/chat-cli/src/database/settings.rs +++ b/crates/chat-cli/src/database/settings.rs @@ -28,6 +28,7 @@ pub enum Setting { ChatEnableNotifications, ApiCodeWhispererService, ApiQService, + McpInitTimeout, } impl AsRef for Setting { @@ -44,6 +45,7 @@ impl AsRef for Setting { Self::ChatEnableNotifications => "chat.enableNotifications", Self::ApiCodeWhispererService => "api.codewhisperer.service", Self::ApiQService => "api.q.service", + Self::McpInitTimeout => "mcp.initTimeout", } } } @@ -70,6 +72,7 @@ impl TryFrom<&str> for Setting { "chat.enableNotifications" => Ok(Self::ChatEnableNotifications), "api.codewhisperer.service" => Ok(Self::ApiCodeWhispererService), "api.q.service" => Ok(Self::ApiQService), + "mcp.initTimeout" => Ok(Self::McpInitTimeout), _ => Err(DatabaseError::InvalidSetting(value.to_string())), } } diff --git a/crates/chat-cli/src/lib.rs b/crates/chat-cli/src/lib.rs index 2b584b4c47..d8bfa3209c 100644 --- a/crates/chat-cli/src/lib.rs +++ b/crates/chat-cli/src/lib.rs @@ -1,3 +1,4 @@ +#![cfg(not(test))] //! This lib.rs is only here for testing purposes. //! `test_mcp_server/test_server.rs` is declared as a separate binary and would need a way to //! reference types defined inside of this crate, hence the export. diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs index 9dc9f6bc98..9012f32e1a 100644 --- a/crates/chat-cli/src/mcp_client/client.rs +++ b/crates/chat-cli/src/mcp_client/client.rs @@ -11,9 +11,7 @@ use std::sync::{ }; use std::time::Duration; -#[cfg(unix)] use nix::sys::signal::Signal; -#[cfg(unix)] use nix::unistd::Pid; use serde::{ Deserialize, @@ -23,31 +21,32 @@ use thiserror::Error; use tokio::time; use tokio::time::error::Elapsed; -use crate::mcp_client::transport::base_protocol::{ +use super::transport::base_protocol::{ JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcVersion, }; -use crate::mcp_client::transport::stdio::JsonRpcStdioTransport; -use crate::mcp_client::transport::{ +use super::transport::stdio::JsonRpcStdioTransport; +use super::transport::{ self, Transport, TransportError, }; -use crate::mcp_client::{ +use super::{ JsonRpcResponse, Listener as _, LogListener, + Messenger, PaginationSupportedOps, PromptGet, PromptsListResult, ResourceTemplatesListResult, ResourcesListResult, + ServerCapabilities, ToolsListResult, }; -pub type ServerCapabilities = serde_json::Value; pub type ClientInfo = serde_json::Value; pub type StdioTransport = JsonRpcStdioTransport; @@ -83,6 +82,7 @@ pub struct ClientConfig { pub env: Option>, } +#[allow(dead_code)] #[derive(Debug, Error)] pub enum ClientError { #[error(transparent)] @@ -97,10 +97,18 @@ pub enum ClientError { source: tokio::time::error::Elapsed, context: String, }, + #[error("Unexpected msg type encountered")] + UnexpectedMsgType, #[error("{0}")] NegotiationError(String), #[error("Failed to obtain process id")] MissingProcessId, + #[error("Invalid path received")] + InvalidPath, + #[error("{0}")] + ProcessKillError(String), + #[error("{0}")] + PoisonError(String), } impl From<(tokio::time::error::Elapsed, String)> for ClientError { @@ -114,10 +122,11 @@ pub struct Client { server_name: String, transport: Arc, timeout: u64, - #[cfg(unix)] server_process_id: Option, client_info: serde_json::Value, current_id: Arc, + pub messenger: Option>, + // TODO: move this to tool manager that way all the assets are treated equally pub prompt_gets: Arc>>, pub is_prompts_out_of_date: Arc, } @@ -130,10 +139,10 @@ impl Clone for Client { timeout: self.timeout, // Note that we cannot have an id for the clone because we would kill the original // process when we drop the clone - #[cfg(unix)] server_process_id: None, client_info: self.client_info.clone(), current_id: self.current_id.clone(), + messenger: None, prompt_gets: self.prompt_gets.clone(), is_prompts_out_of_date: self.is_prompts_out_of_date.clone(), } @@ -156,11 +165,8 @@ impl Client { .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) + .process_group(0) .envs(std::env::vars()); - - #[cfg(unix)] - command.process_group(0); - if let Some(env) = env { for (env_name, env_value) in env { command.env(env_name, env_value); @@ -168,32 +174,29 @@ impl Client { } command.args(args).spawn()? }; - - #[cfg(unix)] - let server_process_id = Some(Pid::from_raw( - child - .id() - .ok_or(ClientError::MissingProcessId)? + let server_process_id = child.id().ok_or(ClientError::MissingProcessId)?; + #[allow(clippy::map_err_ignore)] + let server_process_id = Pid::from_raw( + server_process_id .try_into() - .map_err(|_err| ClientError::MissingProcessId)?, - )); - + .map_err(|_| ClientError::MissingProcessId)?, + ); + let server_process_id = Some(server_process_id); let transport = Arc::new(transport::stdio::JsonRpcStdioTransport::client(child)?); Ok(Self { server_name, transport, timeout, - #[cfg(unix)] server_process_id, client_info, current_id: Arc::new(AtomicU64::new(0)), + messenger: None, prompt_gets: Arc::new(SyncRwLock::new(HashMap::new())), is_prompts_out_of_date: Arc::new(AtomicBool::new(false)), }) } } -#[cfg(unix)] impl Drop for Client where T: Transport, @@ -213,66 +216,14 @@ where { /// Exchange of information specified as per https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization /// - /// Also done is the spawn of a background task that constantly listens for incoming messages - /// from the server. + /// Also done are the following: + /// - Spawns task for listening to server driven workflows + /// - Spawns tasks to ask for relevant info such as tools and prompts in accordance to server + /// capabilities received pub async fn init(&self) -> Result { let transport_ref = self.transport.clone(); let server_name = self.server_name.clone(); - tokio::spawn(async move { - let mut listener = transport_ref.get_listener(); - loop { - match listener.recv().await { - Ok(msg) => { - match msg { - JsonRpcMessage::Request(_req) => {}, - JsonRpcMessage::Notification(notif) => { - let JsonRpcNotification { method, params, .. } = notif; - if method.as_str() == "notifications/message" || method.as_str() == "message" { - let level = params - .as_ref() - .and_then(|p| p.get("level")) - .and_then(|v| serde_json::to_string(v).ok()); - let data = params - .as_ref() - .and_then(|p| p.get("data")) - .and_then(|v| serde_json::to_string(v).ok()); - if let (Some(level), Some(data)) = (level, data) { - match level.to_lowercase().as_str() { - "error" => { - tracing::error!(target: "mcp", "{}: {}", server_name, data); - }, - "warn" => { - tracing::warn!(target: "mcp", "{}: {}", server_name, data); - }, - "info" => { - tracing::info!(target: "mcp", "{}: {}", server_name, data); - }, - "debug" => { - tracing::debug!(target: "mcp", "{}: {}", server_name, data); - }, - "trace" => { - tracing::trace!(target: "mcp", "{}: {}", server_name, data); - }, - _ => {}, - } - } - } - }, - JsonRpcMessage::Response(_resp) => { /* noop since direct response is handled inside the request api */ - }, - } - }, - Err(e) => { - tracing::error!("Background listening thread for client {}: {:?}", server_name, e); - }, - } - } - }); - - let transport_ref = self.transport.clone(); - let server_name = self.server_name.clone(); - // Spawning a task to listen and log stderr output tokio::spawn(async move { let mut log_listener = transport_ref.get_log_listener(); @@ -296,63 +247,126 @@ where let client_cap = ClientCapabilities::from(self.client_info.clone()); serde_json::json!(client_cap) }); - let server_capabilities = self.request("initialize", init_params).await?; - if let Err(e) = examine_server_capabilities(&server_capabilities) { + let init_resp = self.request("initialize", init_params).await?; + if let Err(e) = examine_server_capabilities(&init_resp) { return Err(ClientError::NegotiationError(format!( "Client {} has failed to negotiate server capabilities with server: {:?}", self.server_name, e ))); } + let cap = { + let result = init_resp.result.ok_or(ClientError::NegotiationError(format!( + "Server {} init resp is missing result", + self.server_name + )))?; + let cap = result + .get("capabilities") + .ok_or(ClientError::NegotiationError(format!( + "Server {} init resp result is missing capabilities", + self.server_name + )))? + .clone(); + serde_json::from_value::(cap)? + }; self.notify("initialized", None).await?; // TODO: group this into examine_server_capabilities // Prefetch prompts in the background. We should only do this after the server has been // initialized - if let Some(res) = &server_capabilities.result { - if let Some(cap) = res.get("capabilities") { - if cap.get("prompts").is_some() { - self.is_prompts_out_of_date.store(true, Ordering::Relaxed); - let client_ref = (*self).clone(); - tokio::spawn(async move { - let Ok(resp) = client_ref.request("prompts/list", None).await else { - tracing::error!("Prompt list query failed for {0}", client_ref.server_name); - return; - }; - let Some(result) = resp.result else { - tracing::warn!("Prompt list query returned no result for {0}", client_ref.server_name); - return; - }; - let Some(prompts) = result.get("prompts") else { - tracing::warn!( - "Prompt list query result contained no field named prompts for {0}", - client_ref.server_name - ); - return; - }; - let Ok(prompts) = serde_json::from_value::>(prompts.clone()) else { - tracing::error!( - "Prompt list query deserialization failed for {0}", - client_ref.server_name - ); - return; - }; - let Ok(mut lock) = client_ref.prompt_gets.write() else { - tracing::error!( - "Failed to obtain write lock for prompt list query for {0}", - client_ref.server_name - ); - return; - }; - for prompt in prompts { - let name = prompt.name.clone(); - lock.insert(name, prompt); + if cap.prompts.is_some() { + self.is_prompts_out_of_date.store(true, Ordering::Relaxed); + let client_ref = (*self).clone(); + let messenger_ref = self.messenger.as_ref().map(|m| m.duplicate()); + tokio::spawn(async move { + fetch_prompts_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()).await; + }); + } + if cap.tools.is_some() { + let client_ref = (*self).clone(); + let messenger_ref = self.messenger.as_ref().map(|m| m.duplicate()); + tokio::spawn(async move { + fetch_tools_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()).await; + }); + } + + let transport_ref = self.transport.clone(); + let server_name = self.server_name.clone(); + let messenger_ref = self.messenger.as_ref().map(|m| m.duplicate()); + let client_ref = (*self).clone(); + + let prompts_list_changed_supported = cap.prompts.as_ref().is_some_and(|p| p.get("listChanged").is_some()); + let tools_list_changed_supported = cap.tools.as_ref().is_some_and(|t| t.get("listChanged").is_some()); + tokio::spawn(async move { + let mut listener = transport_ref.get_listener(); + loop { + match listener.recv().await { + Ok(msg) => { + match msg { + JsonRpcMessage::Request(_req) => {}, + JsonRpcMessage::Notification(notif) => { + let JsonRpcNotification { method, params, .. } = notif; + match method.as_str() { + "notifications/message" | "message" => { + let level = params + .as_ref() + .and_then(|p| p.get("level")) + .and_then(|v| serde_json::to_string(v).ok()); + let data = params + .as_ref() + .and_then(|p| p.get("data")) + .and_then(|v| serde_json::to_string(v).ok()); + if let (Some(level), Some(data)) = (level, data) { + match level.to_lowercase().as_str() { + "error" => { + tracing::error!(target: "mcp", "{}: {}", server_name, data); + }, + "warn" => { + tracing::warn!(target: "mcp", "{}: {}", server_name, data); + }, + "info" => { + tracing::info!(target: "mcp", "{}: {}", server_name, data); + }, + "debug" => { + tracing::debug!(target: "mcp", "{}: {}", server_name, data); + }, + "trace" => { + tracing::trace!(target: "mcp", "{}: {}", server_name, data); + }, + _ => {}, + } + } + }, + "notifications/prompts/list_changed" | "prompts/list_changed" + if prompts_list_changed_supported => + { + // TODO: after we have moved the prompts to the tool + // manager we follow the same workflow as the list changed + // for tools + fetch_prompts_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()) + .await; + client_ref.is_prompts_out_of_date.store(true, Ordering::Release); + }, + "notifications/tools/list_changed" | "tools/list_changed" + if tools_list_changed_supported => + { + fetch_tools_and_notify_with_messenger(&client_ref, messenger_ref.as_ref()) + .await; + }, + _ => {}, + } + }, + JsonRpcMessage::Response(_resp) => { /* noop since direct response is handled inside the request api */ + }, } - }); + }, + Err(e) => { + tracing::error!("Background listening thread for client {}: {:?}", server_name, e); + }, } } - } + }); - Ok(serde_json::to_value(server_capabilities)?) + Ok(cap) } /// Sends a request to the server associated. @@ -403,13 +417,13 @@ where loop { let result = current_resp.result.as_ref().cloned().unwrap(); let mut list: Vec = match ops { - PaginationSupportedOps::Resources => { + PaginationSupportedOps::ResourcesList => { let ResourcesListResult { resources: list, .. } = serde_json::from_value::(result) .map_err(ClientError::Serialization)?; list }, - PaginationSupportedOps::ResourceTemplates => { + PaginationSupportedOps::ResourceTemplatesList => { let ResourceTemplatesListResult { resource_templates: list, .. @@ -417,13 +431,13 @@ where .map_err(ClientError::Serialization)?; list }, - PaginationSupportedOps::Prompts => { + PaginationSupportedOps::PromptsList => { let PromptsListResult { prompts: list, .. } = serde_json::from_value::(result) .map_err(ClientError::Serialization)?; list }, - PaginationSupportedOps::Tools => { + PaginationSupportedOps::ToolsList => { let ToolsListResult { tools: list, .. } = serde_json::from_value::(result) .map_err(ClientError::Serialization)?; list @@ -490,10 +504,6 @@ where ) } - pub async fn shutdown(&self) -> Result<(), ClientError> { - Ok(self.transport.shutdown().await?) - } - fn get_id(&self) -> u64 { self.current_id.fetch_add(1, Ordering::SeqCst) } @@ -514,6 +524,85 @@ fn examine_server_capabilities(ser_cap: &JsonRpcResponse) -> Result<(), ClientEr Ok(()) } +// TODO: after we move prompts to tool manager, use the messenger to notify the listener spawned by +// tool manager to update its own field. Currently this function does not make use of the +// messesnger. +#[allow(clippy::borrowed_box)] +async fn fetch_prompts_and_notify_with_messenger(client: &Client, _messenger: Option<&Box>) +where + T: Transport, +{ + let Ok(resp) = client.request("prompts/list", None).await else { + tracing::error!("Prompt list query failed for {0}", client.server_name); + return; + }; + let Some(result) = resp.result else { + tracing::warn!("Prompt list query returned no result for {0}", client.server_name); + return; + }; + let Some(prompts) = result.get("prompts") else { + tracing::warn!( + "Prompt list query result contained no field named prompts for {0}", + client.server_name + ); + return; + }; + let Ok(prompts) = serde_json::from_value::>(prompts.clone()) else { + tracing::error!("Prompt list query deserialization failed for {0}", client.server_name); + return; + }; + let Ok(mut lock) = client.prompt_gets.write() else { + tracing::error!( + "Failed to obtain write lock for prompt list query for {0}", + client.server_name + ); + return; + }; + lock.clear(); + for prompt in prompts { + let name = prompt.name.clone(); + lock.insert(name, prompt); + } +} + +#[allow(clippy::borrowed_box)] +async fn fetch_tools_and_notify_with_messenger(client: &Client, messenger: Option<&Box>) +where + T: Transport, +{ + // TODO: decouple pagination logic from request and have page fetching logic here + // instead + let resp = match client.request("tools/list", None).await { + Ok(resp) => resp, + Err(e) => { + tracing::error!("Failed to retrieve tool list from {}: {:?}", client.server_name, e); + return; + }, + }; + if let Some(error) = resp.error { + let msg = format!("Failed to retrieve tool list for {}: {:?}", client.server_name, error); + tracing::error!("{}", &msg); + return; + } + let Some(result) = resp.result else { + tracing::error!("Tool list response from {} is missing result", client.server_name); + return; + }; + let tool_list_result = match serde_json::from_value::(result) { + Ok(result) => result, + Err(e) => { + tracing::error!("Failed to deserialize tool result from {}: {:?}", client.server_name, e); + return; + }, + }; + if let Some(messenger) = messenger { + let _ = messenger + .send_tools_list_result(tool_list_result) + .await + .map_err(|e| tracing::error!("Failed to send tool result through messenger {:?}", e)); + } +} + #[cfg(test)] mod tests { use std::path::PathBuf; @@ -592,11 +681,11 @@ mod tests { let (res_one, res_two) = tokio::join!( time::timeout( - time::Duration::from_secs(5), + time::Duration::from_secs(10), test_client_routine(&mut client_one, serde_json::json!(client_one_cap)) ), time::timeout( - time::Duration::from_secs(5), + time::Duration::from_secs(10), test_client_routine(&mut client_two, serde_json::json!(client_two_cap)) ) ); @@ -606,6 +695,7 @@ mod tests { assert!(res_two.is_ok()); } + #[allow(clippy::await_holding_lock)] async fn test_client_routine( client: &mut Client, cap_sent: serde_json::Value, @@ -681,6 +771,7 @@ mod tests { .await .expect("Mock prompt prep failed"); let prompts_recvd = client.request("prompts/list", None).await.expect("List prompts failed"); + client.is_prompts_out_of_date.store(false, Ordering::Release); assert!(are_json_values_equal( prompts_recvd .result @@ -690,6 +781,41 @@ mod tests { &mock_prompts_for_verify )); + // Test prompts list changed + let fake_prompt_names = ["code_review_four", "code_review_five", "code_review_six"]; + let mock_result_prompts = fake_prompt_names.map(create_fake_prompts); + let mock_prompts_prep_param = mock_result_prompts + .iter() + .zip(fake_prompt_names.iter()) + .map(|(v, n)| { + serde_json::json!({ + "key": (*n).to_string(), + "value": v + }) + }) + .collect::>(); + let mock_prompts_prep_param = + serde_json::to_value(mock_prompts_prep_param).expect("Failed to create mock prompts prep param"); + let _ = client + .request("store_mock_prompts", Some(mock_prompts_prep_param)) + .await + .expect("Mock new prompt request failed"); + // After we send the signal for the server to clear prompts, we should be receiving signal + // to fetch for new prompts, after which we should be getting no prompts. + let is_prompts_out_of_date = client.is_prompts_out_of_date.clone(); + let wait_for_new_prompts = async move { + while !is_prompts_out_of_date.load(Ordering::Acquire) { + tokio::time::sleep(time::Duration::from_millis(100)).await; + } + }; + time::timeout(time::Duration::from_secs(5), wait_for_new_prompts) + .await + .expect("Timed out while waiting for new prompts"); + let new_prompts = client.prompt_gets.read().expect("Failed to read new prompts"); + for k in new_prompts.keys() { + assert!(fake_prompt_names.contains(&k.as_str())); + } + // Test env var inclusion let env_vars = client.request("get_env_vars", None).await.expect("Get env vars failed"); let env_one = env_vars @@ -709,8 +835,6 @@ mod tests { assert_eq!(env_one_as_str, "\"1\"".to_string()); assert_eq!(env_two_as_str, "\"2\"".to_string()); - let shutdown_result = client.shutdown().await; - assert!(shutdown_result.is_ok()); Ok(()) } diff --git a/crates/chat-cli/src/mcp_client/error.rs b/crates/chat-cli/src/mcp_client/error.rs new file mode 100644 index 0000000000..01f77cfa8b --- /dev/null +++ b/crates/chat-cli/src/mcp_client/error.rs @@ -0,0 +1,66 @@ +/// Error codes as defined in the MCP protocol. +/// +/// These error codes are based on the JSON-RPC 2.0 specification with additional +/// MCP-specific error codes in the -32000 to -32099 range. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(i32)] +pub enum ErrorCode { + /// Invalid JSON was received by the server. + /// An error occurred on the server while parsing the JSON text. + ParseError = -32700, + + /// The JSON sent is not a valid Request object. + InvalidRequest = -32600, + + /// The method does not exist / is not available. + MethodNotFound = -32601, + + /// Invalid method parameter(s). + InvalidParams = -32602, + + /// Internal JSON-RPC error. + InternalError = -32603, + + /// Server has not been initialized. + /// This error is returned when a request is made before the server + /// has been properly initialized. + ServerNotInitialized = -32002, + + /// Unknown error code. + /// This error is returned when an error code is received that is not + /// recognized by the implementation. + Unknown = -32001, + + /// Request failed. + /// This error is returned when a request fails for a reason not covered + /// by other error codes. + RequestFailed = -32000, +} + +impl From for ErrorCode { + fn from(code: i32) -> Self { + match code { + -32700 => ErrorCode::ParseError, + -32600 => ErrorCode::InvalidRequest, + -32601 => ErrorCode::MethodNotFound, + -32602 => ErrorCode::InvalidParams, + -32603 => ErrorCode::InternalError, + -32002 => ErrorCode::ServerNotInitialized, + -32001 => ErrorCode::Unknown, + -32000 => ErrorCode::RequestFailed, + _ => ErrorCode::Unknown, + } + } +} + +impl From for i32 { + fn from(code: ErrorCode) -> Self { + code as i32 + } +} + +impl std::fmt::Display for ErrorCode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} diff --git a/crates/chat-cli/src/mcp_client/facilitator_types.rs b/crates/chat-cli/src/mcp_client/facilitator_types.rs index 38d4aca280..87fbd79b27 100644 --- a/crates/chat-cli/src/mcp_client/facilitator_types.rs +++ b/crates/chat-cli/src/mcp_client/facilitator_types.rs @@ -5,21 +5,22 @@ use serde::{ use thiserror::Error; /// https://spec.modelcontextprotocol.io/specification/2024-11-05/server/utilities/pagination/#operations-supporting-pagination +#[allow(clippy::enum_variant_names)] #[derive(Debug, Clone, PartialEq, Eq)] pub enum PaginationSupportedOps { - Resources, - ResourceTemplates, - Prompts, - Tools, + ResourcesList, + ResourceTemplatesList, + PromptsList, + ToolsList, } impl PaginationSupportedOps { pub fn as_key(&self) -> &str { match self { - PaginationSupportedOps::Resources => "resources", - PaginationSupportedOps::ResourceTemplates => "resourceTemplates", - PaginationSupportedOps::Prompts => "prompts", - PaginationSupportedOps::Tools => "tools", + PaginationSupportedOps::ResourcesList => "resources", + PaginationSupportedOps::ResourceTemplatesList => "resourceTemplates", + PaginationSupportedOps::PromptsList => "prompts", + PaginationSupportedOps::ToolsList => "tools", } } } @@ -29,10 +30,10 @@ impl TryFrom<&str> for PaginationSupportedOps { fn try_from(value: &str) -> Result { match value { - "resources/list" => Ok(PaginationSupportedOps::Resources), - "resources/templates/list" => Ok(PaginationSupportedOps::ResourceTemplates), - "prompts/list" => Ok(PaginationSupportedOps::Prompts), - "tools/list" => Ok(PaginationSupportedOps::Tools), + "resources/list" => Ok(PaginationSupportedOps::ResourcesList), + "resources/templates/list" => Ok(PaginationSupportedOps::ResourceTemplatesList), + "prompts/list" => Ok(PaginationSupportedOps::PromptsList), + "tools/list" => Ok(PaginationSupportedOps::ToolsList), _ => Err(OpsConversionError::InvalidMethod), } } @@ -227,3 +228,21 @@ pub struct Resource { /// Resource contents pub contents: ResourceContents, } + +/// Represents the capabilities supported by a Model Context Protocol server +/// This is the "capabilities" field in the result of a response for init +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServerCapabilities { + /// Configuration for server logging capabilities + #[serde(skip_serializing_if = "Option::is_none")] + pub logging: Option, + /// Configuration for prompt-related capabilities + #[serde(skip_serializing_if = "Option::is_none")] + pub prompts: Option, + /// Configuration for resource management capabilities + #[serde(skip_serializing_if = "Option::is_none")] + pub resources: Option, + /// Configuration for tool integration capabilities + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option, +} diff --git a/crates/chat-cli/src/mcp_client/messenger.rs b/crates/chat-cli/src/mcp_client/messenger.rs new file mode 100644 index 0000000000..efd49617ab --- /dev/null +++ b/crates/chat-cli/src/mcp_client/messenger.rs @@ -0,0 +1,81 @@ +use thiserror::Error; + +use super::{ + PromptsListResult, + ResourceTemplatesListResult, + ResourcesListResult, + ToolsListResult, +}; + +/// An interface that abstracts the implementation for information delivery from client and its +/// consumer. It is through this interface secondary information (i.e. information that are needed +/// to make requests to mcp servers) are obtained passively. Consumers of client can of course +/// choose to "actively" retrieve these information via explicitly making these requests. +#[allow(dead_code)] +#[async_trait::async_trait] +pub trait Messenger: std::fmt::Debug + Send + Sync + 'static { + /// Sends the result of a tools list operation to the consumer + /// This function is used to deliver information about available tools + async fn send_tools_list_result(&self, result: ToolsListResult) -> Result<(), MessengerError>; + + /// Sends the result of a prompts list operation to the consumer + /// This function is used to deliver information about available prompts + async fn send_prompts_list_result(&self, result: PromptsListResult) -> Result<(), MessengerError>; + + /// Sends the result of a resources list operation to the consumer + /// This function is used to deliver information about available resources + async fn send_resources_list_result(&self, result: ResourcesListResult) -> Result<(), MessengerError>; + + /// Sends the result of a resource templates list operation to the consumer + /// This function is used to deliver information about available resource templates + async fn send_resource_templates_list_result( + &self, + result: ResourceTemplatesListResult, + ) -> Result<(), MessengerError>; + + /// Signals to the orchestrator that a server has started initializing + async fn send_init_msg(&self) -> Result<(), MessengerError>; + + /// Creates a duplicate of the messenger object + /// This function is used to create a new instance of the messenger with the same configuration + fn duplicate(&self) -> Box; +} + +#[derive(Clone, Debug, Error)] +pub enum MessengerError { + #[error("{0}")] + Custom(String), +} + +#[derive(Clone, Debug)] +pub struct NullMessenger; + +#[async_trait::async_trait] +impl Messenger for NullMessenger { + async fn send_tools_list_result(&self, _result: ToolsListResult) -> Result<(), MessengerError> { + Ok(()) + } + + async fn send_prompts_list_result(&self, _result: PromptsListResult) -> Result<(), MessengerError> { + Ok(()) + } + + async fn send_resources_list_result(&self, _result: ResourcesListResult) -> Result<(), MessengerError> { + Ok(()) + } + + async fn send_resource_templates_list_result( + &self, + _result: ResourceTemplatesListResult, + ) -> Result<(), MessengerError> { + Ok(()) + } + + async fn send_init_msg(&self) -> Result<(), MessengerError> { + Ok(()) + } + + fn duplicate(&self) -> Box { + Box::new(NullMessenger) + } +} diff --git a/crates/chat-cli/src/mcp_client/mod.rs b/crates/chat-cli/src/mcp_client/mod.rs index 9c10970a30..51f8b178fd 100644 --- a/crates/chat-cli/src/mcp_client/mod.rs +++ b/crates/chat-cli/src/mcp_client/mod.rs @@ -1,79 +1,13 @@ -#![allow(dead_code)] - -mod client; -mod facilitator_types; -mod server; -mod transport; +pub mod client; +pub mod error; +pub mod facilitator_types; +pub mod messenger; +pub mod server; +pub mod transport; pub use client::*; pub use facilitator_types::*; +pub use messenger::*; #[allow(unused_imports)] pub use server::*; pub use transport::*; - -/// Error codes as defined in the MCP protocol. -/// -/// These error codes are based on the JSON-RPC 2.0 specification with additional -/// MCP-specific error codes in the -32000 to -32099 range. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(i32)] -pub enum McpError { - /// Invalid JSON was received by the server. - /// An error occurred on the server while parsing the JSON text. - ParseError = -32700, - - /// The JSON sent is not a valid Request object. - InvalidRequest = -32600, - - /// The method does not exist / is not available. - MethodNotFound = -32601, - - /// Invalid method parameter(s). - InvalidParams = -32602, - - /// Internal JSON-RPC error. - InternalError = -32603, - - /// Server has not been initialized. - /// This error is returned when a request is made before the server - /// has been properly initialized. - ServerNotInitialized = -32002, - - /// Unknown error code. - /// This error is returned when an error code is received that is not - /// recognized by the implementation. - UnknownErrorCode = -32001, - - /// Request failed. - /// This error is returned when a request fails for a reason not covered - /// by other error codes. - RequestFailed = -32000, -} - -impl From for McpError { - fn from(code: i32) -> Self { - match code { - -32700 => McpError::ParseError, - -32600 => McpError::InvalidRequest, - -32601 => McpError::MethodNotFound, - -32602 => McpError::InvalidParams, - -32603 => McpError::InternalError, - -32002 => McpError::ServerNotInitialized, - -32001 => McpError::UnknownErrorCode, - -32000 => McpError::RequestFailed, - _ => McpError::UnknownErrorCode, - } - } -} - -impl From for i32 { - fn from(code: McpError) -> Self { - code as i32 - } -} - -impl std::fmt::Display for McpError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) - } -} diff --git a/crates/chat-cli/src/mcp_client/server.rs b/crates/chat-cli/src/mcp_client/server.rs index d51d19125e..7b320a2c6e 100644 --- a/crates/chat-cli/src/mcp_client/server.rs +++ b/crates/chat-cli/src/mcp_client/server.rs @@ -1,3 +1,4 @@ +#![allow(dead_code)] use std::collections::HashMap; use std::sync::atomic::{ AtomicBool, @@ -15,24 +16,22 @@ use tokio::io::{ }; use tokio::task::JoinHandle; -use crate::mcp_client::client::StdioTransport; -use crate::mcp_client::transport::base_protocol::{ +use super::Listener as _; +use super::client::StdioTransport; +use super::error::ErrorCode; +use super::transport::base_protocol::{ JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, }; -use crate::mcp_client::transport::stdio::JsonRpcStdioTransport; -use crate::mcp_client::transport::{ +use super::transport::stdio::JsonRpcStdioTransport; +use super::transport::{ JsonRpcVersion, Transport, TransportError, }; -use crate::mcp_client::{ - Listener as _, - McpError, -}; pub type Request = serde_json::Value; pub type Response = Option; @@ -110,20 +109,37 @@ where let current_id_clone = current_id.clone(); let request_sender = move |method: &str, params: Option| -> Result<(), ServerError> { let id = current_id_clone.fetch_add(1, Ordering::SeqCst); - let request = JsonRpcRequest { - jsonrpc: JsonRpcVersion::default(), - id, - method: method.to_owned(), - params, + let msg = match method.split_once("/") { + Some(("request", _)) => { + let request = JsonRpcRequest { + jsonrpc: JsonRpcVersion::default(), + id, + method: method.to_owned(), + params, + }; + let msg = JsonRpcMessage::Request(request.clone()); + #[allow(clippy::map_err_ignore)] + let mut pending_request = pending_request_clone_two.lock().map_err(|_| ServerError::MutexError)?; + pending_request.insert(id, request); + Some(msg) + }, + Some(("notifications", _)) => { + let notif = JsonRpcNotification { + jsonrpc: JsonRpcVersion::default(), + method: method.to_owned(), + params, + }; + let msg = JsonRpcMessage::Notification(notif); + Some(msg) + }, + _ => None, }; - let msg = JsonRpcMessage::Request(request.clone()); - let transport = transport_clone.clone(); - tokio::task::spawn(async move { - let _ = transport.send(&msg).await; - }); - #[allow(clippy::map_err_ignore)] - let mut pending_request = pending_request_clone_two.lock().map_err(|_| ServerError::MutexError)?; - pending_request.insert(id, request); + if let Some(msg) = msg { + let transport = transport_clone.clone(); + tokio::task::spawn(async move { + let _ = transport.send(&msg).await; + }); + } Ok(()) }; handler.register_send_request_callback(request_sender); @@ -179,7 +195,7 @@ async fn process_request( jsonrpc: JsonRpcVersion::default(), id, error: Some(JsonRpcError { - code: McpError::InvalidRequest.into(), + code: ErrorCode::InvalidRequest.into(), message: "Server has already been initialized".to_owned(), data: None, }), @@ -193,7 +209,7 @@ async fn process_request( jsonrpc: JsonRpcVersion::default(), id, error: Some(JsonRpcError { - code: McpError::InvalidRequest.into(), + code: ErrorCode::InvalidRequest.into(), message: "Invalid method for initialization (use request)".to_owned(), data: None, }), @@ -218,7 +234,7 @@ async fn process_request( jsonrpc: JsonRpcVersion::default(), id, error: Some(JsonRpcError { - code: McpError::InternalError.into(), + code: ErrorCode::InternalError.into(), message: "Error producing initialization response".to_owned(), data: None, }), @@ -242,7 +258,7 @@ async fn process_request( let resp = handler.handle_incoming(method, params).await.map_or_else( |error| { let err = JsonRpcError { - code: McpError::InternalError.into(), + code: ErrorCode::InternalError.into(), message: error.to_string(), data: None, }; @@ -280,7 +296,7 @@ async fn process_request( jsonrpc: JsonRpcVersion::default(), id, error: Some(JsonRpcError { - code: McpError::ServerNotInitialized.into(), + code: ErrorCode::ServerNotInitialized.into(), message: "Server has not been initialized".to_owned(), data: None, }), diff --git a/crates/chat-cli/src/mcp_client/transport/mod.rs b/crates/chat-cli/src/mcp_client/transport/mod.rs index f86fc498f3..f752b1675a 100644 --- a/crates/chat-cli/src/mcp_client/transport/mod.rs +++ b/crates/chat-cli/src/mcp_client/transport/mod.rs @@ -4,7 +4,7 @@ pub mod stdio; use std::fmt::Debug; pub use base_protocol::*; -pub use stdio::JsonRpcStdioTransport; +pub use stdio::*; use thiserror::Error; #[derive(Clone, Debug, Error)] @@ -31,6 +31,7 @@ impl From for TransportError { } } +#[allow(dead_code)] #[async_trait::async_trait] pub trait Transport: Send + Sync + Debug + 'static { /// Sends a message over the transport layer. diff --git a/crates/chat-cli/src/mcp_client/transport/stdio.rs b/crates/chat-cli/src/mcp_client/transport/stdio.rs index 270756f2d9..2505c39ddd 100644 --- a/crates/chat-cli/src/mcp_client/transport/stdio.rs +++ b/crates/chat-cli/src/mcp_client/transport/stdio.rs @@ -204,7 +204,12 @@ mod tests { }; use tokio::process::Command; - use super::*; + use super::{ + JsonRpcMessage, + JsonRpcStdioTransport, + Listener, + Transport, + }; // Helpers for testing fn create_test_message() -> JsonRpcMessage { diff --git a/crates/chat-cli/src/telemetry/mod.rs b/crates/chat-cli/src/telemetry/mod.rs index 430a327204..b49f4da592 100644 --- a/crates/chat-cli/src/telemetry/mod.rs +++ b/crates/chat-cli/src/telemetry/mod.rs @@ -36,6 +36,7 @@ pub use install_method::{ }; use tokio::sync::mpsc; use tokio::task::JoinHandle; +use tokio::time::error::Elapsed; use tracing::{ debug, error, @@ -75,6 +76,8 @@ pub enum TelemetryError { Join(#[from] tokio::task::JoinError), #[error(transparent)] Database(#[from] DatabaseError), + #[error(transparent)] + Timeout(#[from] Elapsed), } impl From for TelemetryError { @@ -159,7 +162,16 @@ impl TelemetryThread { pub async fn finish(self) -> Result<(), TelemetryError> { drop(self.tx); if let Some(handle) = self.handle { - handle.await?; + match tokio::time::timeout(std::time::Duration::from_millis(1000), handle).await { + Ok(result) => { + if let Err(e) = result { + return Err(TelemetryError::Join(e)); + } + }, + Err(_) => { + // Ignore timeout errors + }, + } } Ok(()) diff --git a/crates/chat-cli/test_mcp_server/test_server.rs b/crates/chat-cli/test_mcp_server/test_server.rs index 46c2f235ea..970157f96b 100644 --- a/crates/chat-cli/test_mcp_server/test_server.rs +++ b/crates/chat-cli/test_mcp_server/test_server.rs @@ -167,24 +167,17 @@ impl ServerRequestHandler for Handler { return Ok(None); } } else { - let first_key = self - .tool_spec_key_list - .lock() - .await + let tool_spec_key_list = self.tool_spec_key_list.lock().await; + let tool_spec = self.tool_spec.lock().await; + let first_key = tool_spec_key_list .first() .expect("First key missing from tool specs") .clone(); - let first_value = self - .tool_spec - .lock() - .await + let first_value = tool_spec .get(&first_key) .expect("First value missing from tool specs") .clone(); - let second_key = self - .tool_spec_key_list - .lock() - .await + let second_key = tool_spec_key_list .get(1) .expect("Second key missing from tool specs") .clone(); @@ -241,8 +234,11 @@ impl ServerRequestHandler for Handler { eprintln!("Failed to convert to mock specs from value"); return Ok(None); }; - let self_prompts = self.prompts.lock().await; + let mut self_prompts = self.prompts.lock().await; let mut self_prompt_key_list = self.prompt_key_list.lock().await; + let is_first_mock = self_prompts.is_empty(); + self_prompts.clear(); + self_prompt_key_list.clear(); let _ = mock_prompts.iter().fold(self_prompts, |mut acc, spec| { let Some(key) = spec.get("key").cloned() else { return acc; @@ -255,9 +251,16 @@ impl ServerRequestHandler for Handler { acc.insert(key, spec.get("value").cloned()); acc }); + if !is_first_mock { + if let Some(sender) = &self.send_request { + let _ = sender("notifications/prompts/list_changed", None); + } + } Ok(None) }, "prompts/list" => { + // We expect this method to be called after the mock prompts have already been + // stored. self.prompt_list_call_no.fetch_add(1, Ordering::Relaxed); if let Some(params) = params { if let Some(cursor) = params.get("cursor").cloned() { @@ -295,27 +298,12 @@ impl ServerRequestHandler for Handler { return Ok(None); } } else { - let first_key = self - .prompt_key_list - .lock() - .await - .first() - .expect("First key missing from prompts") - .clone(); - let first_value = self - .prompts - .lock() - .await - .get(&first_key) - .expect("First value missing from prompts") - .clone(); - let second_key = self - .prompt_key_list - .lock() - .await - .get(1) - .expect("Second key missing from prompts") - .clone(); + // If there is no parameter, this is the request to retrieve the first page + let prompt_key_list = self.prompt_key_list.lock().await; + let prompts = self.prompts.lock().await; + let first_key = prompt_key_list.first().expect("first key missing"); + let first_value = prompts.get(first_key).cloned().unwrap().unwrap(); + let second_key = prompt_key_list.get(1).expect("second key missing"); return Ok(Some(serde_json::json!({ "prompts": [first_value], "nextCursor": second_key