diff --git a/crates/chat-cli/src/cli/chat/command.rs b/crates/chat-cli/src/cli/chat/command.rs index f2a4262b60..5252e63aa1 100644 --- a/crates/chat-cli/src/cli/chat/command.rs +++ b/crates/chat-cli/src/cli/chat/command.rs @@ -58,6 +58,7 @@ pub enum Command { path: String, force: bool, }, + Mcp, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -837,6 +838,7 @@ impl Command { } Self::Save { path, force } }, + "mcp" => Self::Mcp, unknown_command => { let looks_like_path = { let after_slash_command_str = parts[1..].join(" "); diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 4e01ea8192..ebe098ce0a 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -107,6 +107,7 @@ use token_counter::{ use tokio::signal::ctrl_c; use tool_manager::{ GetPromptError, + LoadingRecord, McpServerConfig, PromptBundle, ToolManager, @@ -204,7 +205,7 @@ const WELCOME_TEXT: &str = color_print::cstr! {" const SMALL_SCREEN_WELCOME_TEXT: &str = color_print::cstr! {"Welcome to Amazon Q!"}; const RESUME_TEXT: &str = color_print::cstr! {"Picking up where we left off..."}; -const ROTATING_TIPS: [&str; 12] = [ +const ROTATING_TIPS: [&str; 13] = [ color_print::cstr! {"You can resume the last conversation from your current directory by launching with q chat --resume"}, color_print::cstr! {"Get notified whenever Q CLI finishes responding. Just run q settings chat.enableNotifications true"}, color_print::cstr! {"You can use /editor to edit your prompt with a vim-like experience"}, @@ -217,6 +218,7 @@ const ROTATING_TIPS: [&str; 12] = [ color_print::cstr! {"If you want to file an issue to the Q CLI team, just tell me, or run q issue"}, color_print::cstr! {"You can enable custom tools with MCP servers. Learn more with /help"}, color_print::cstr! {"You can specify wait time (in ms) for mcp server loading with q settings mcp.initTimeout {timeout in int}. Servers that takes longer than the specified time will continue to load in the background. Use /tools to see pending servers."}, + color_print::cstr! {"You can see the server load status as well as any warnings or errors associated with /mcp"}, ]; const GREETING_BREAK_POINT: usize = 80; @@ -246,6 +248,7 @@ const HELP_TEXT: &str = color_print::cstr! {" untrust Revert a tool or tools to per-request confirmation trustall Trust all tools (equivalent to deprecated /acceptall) reset Reset all tools to default permission levels +/mcp See mcp server loaded /profile Manage profiles help Show profile help list List profiles @@ -2948,6 +2951,53 @@ impl ChatContext { skip_printing_tools: true, } }, + Command::Mcp => { + let terminal_width = self.terminal_width(); + let loaded_servers = self.conversation_state.tool_manager.mcp_load_record.lock().await; + let still_loading = self + .conversation_state + .tool_manager + .pending_clients() + .await + .into_iter() + .map(|name| format!(" - {name}\n")) + .collect::>() + .join(""); + for (server_name, msg) in loaded_servers.iter() { + let msg = msg + .iter() + .map(|record| match record { + LoadingRecord::Err(content) + | LoadingRecord::Warn(content) + | LoadingRecord::Success(content) => content.clone(), + }) + .collect::>() + .join("\n--- tools refreshed ---\n"); + queue!( + self.output, + style::Print(server_name), + style::Print("\n"), + style::Print(format!("{}\n", "▔".repeat(terminal_width))), + style::Print(msg), + style::Print("\n") + )?; + } + if !still_loading.is_empty() { + queue!( + self.output, + style::Print("Still loading:\n"), + style::Print(format!("{}\n", "▔".repeat(terminal_width))), + style::Print(still_loading), + style::Print("\n") + )?; + } + self.output.flush()?; + ChatState::PromptUser { + tool_uses: None, + pending_tool_index: None, + skip_printing_tools: true, + } + }, }) } diff --git a/crates/chat-cli/src/cli/chat/server_messenger.rs b/crates/chat-cli/src/cli/chat/server_messenger.rs index 3adc665d15..966600fc44 100644 --- a/crates/chat-cli/src/cli/chat/server_messenger.rs +++ b/crates/chat-cli/src/cli/chat/server_messenger.rs @@ -14,23 +14,23 @@ use crate::mcp_client::{ }; #[allow(dead_code)] -#[derive(Clone, Debug)] +#[derive(Debug)] pub enum UpdateEventMessage { ToolsListResult { server_name: String, - result: ToolsListResult, + result: eyre::Result, }, PromptsListResult { server_name: String, - result: PromptsListResult, + result: eyre::Result, }, ResourcesListResult { server_name: String, - result: ResourcesListResult, + result: eyre::Result, }, ResourceTemplatesListResult { server_name: String, - result: ResourceTemplatesListResult, + result: eyre::Result, }, InitStart { server_name: String, @@ -67,7 +67,7 @@ pub struct ServerMessenger { #[async_trait::async_trait] impl Messenger for ServerMessenger { - async fn send_tools_list_result(&self, result: ToolsListResult) -> Result<(), MessengerError> { + async fn send_tools_list_result(&self, result: eyre::Result) -> Result<(), MessengerError> { Ok(self .update_event_sender .send(UpdateEventMessage::ToolsListResult { @@ -78,7 +78,7 @@ impl Messenger for ServerMessenger { .map_err(|e| MessengerError::Custom(e.to_string()))?) } - async fn send_prompts_list_result(&self, result: PromptsListResult) -> Result<(), MessengerError> { + async fn send_prompts_list_result(&self, result: eyre::Result) -> Result<(), MessengerError> { Ok(self .update_event_sender .send(UpdateEventMessage::PromptsListResult { @@ -89,7 +89,10 @@ impl Messenger for ServerMessenger { .map_err(|e| MessengerError::Custom(e.to_string()))?) } - async fn send_resources_list_result(&self, result: ResourcesListResult) -> Result<(), MessengerError> { + async fn send_resources_list_result( + &self, + result: eyre::Result, + ) -> Result<(), MessengerError> { Ok(self .update_event_sender .send(UpdateEventMessage::ResourcesListResult { @@ -102,7 +105,7 @@ impl Messenger for ServerMessenger { async fn send_resource_templates_list_result( &self, - result: ResourceTemplatesListResult, + result: eyre::Result, ) -> Result<(), MessengerError> { Ok(self .update_event_sender diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index dee5d6c96c..c38d441c76 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -7,7 +7,10 @@ use std::hash::{ DefaultHasher, Hasher, }; -use std::io::Write; +use std::io::{ + BufWriter, + Write, +}; use std::path::{ Path, PathBuf, @@ -17,11 +20,14 @@ use std::sync::atomic::{ AtomicBool, Ordering, }; -use std::sync::mpsc::RecvTimeoutError; use std::sync::{ Arc, RwLock as SyncRwLock, }; +use std::time::{ + Duration, + Instant, +}; use convert_case::Casing; use crossterm::{ @@ -45,6 +51,7 @@ use thiserror::Error; use tokio::signal::ctrl_c; use tokio::sync::{ Mutex, + Notify, RwLock, }; use tracing::{ @@ -84,6 +91,7 @@ use crate::database::Database; use crate::database::settings::Setting; use crate::mcp_client::{ JsonRpcResponse, + Messenger, PromptGet, }; use crate::platform::Context; @@ -126,34 +134,38 @@ pub enum GetPromptError { /// display thread. These messages control the visual loading indicators shown to /// the user during tool initialization. enum LoadingMsg { - /// Indicates a new tool is being initialized and should be added to the loading - /// display. The String parameter is the name of the tool being initialized. - Add(String), /// Indicates a tool has finished initializing successfully and should be removed from /// the loading display. The String parameter is the name of the tool that /// completed initialization. - Done(String), + Done { name: String, time: String }, /// Represents an error that occurred during tool initialization. /// Contains the name of the server that failed to initialize and the error message. - Error { name: String, msg: eyre::Report }, + Error { + name: String, + msg: eyre::Report, + time: String, + }, /// Represents a warning that occurred during tool initialization. /// Contains the name of the server that generated the warning and the warning message. - Warn { name: String, msg: eyre::Report }, + Warn { + name: String, + msg: eyre::Report, + time: String, + }, /// Signals that the loading display thread should terminate. /// This is sent when all tool initialization is complete or when the application is shutting /// down. - Terminate, + Terminate { still_loading: Vec }, } -/// Represents the state of a loading indicator for a tool being initialized. -/// -/// This struct tracks timing information for each tool's loading status display in the terminal. -/// -/// # Fields -/// * `init_time` - When initialization for this tool began, used to calculate load time -struct StatusLine { - init_time: std::time::Instant, - is_done: bool, +/// Used to denote the loading outcome associated with a server. +/// This is mainly used in the non-interactive mode to determine if there is any fatal errors to +/// surface (since we would only want to surface fatal errors in non-interactive mode). +#[derive(Clone, Debug)] +pub enum LoadingRecord { + Success(String), + Warn(String), + Err(String), } // This is to mirror claude's config set up @@ -284,141 +296,106 @@ impl ToolManagerBuilder { (sanitized_server_name, custom_tool_client) }) .collect::>(); + let mut loading_servers = HashMap::::new(); + for (server_name, _) in &pre_initialized { + let init_time = std::time::Instant::now(); + loading_servers.insert(server_name.clone(), init_time); + } + let total = loading_servers.len(); - // Send up task to update user on server loading status - let (tx, rx) = std::sync::mpsc::channel::(); - // TODO: rather than using it as an "anchor" to determine the progress of server loads, we - // should make this task optional (and it is defined as an optional right now. There is - // just no code path with it being None). When ran with no-interactive mode, we really do - // not have a need to run this task. - let loading_display_task = tokio::task::spawn_blocking(move || { - let mut loading_servers = HashMap::::new(); - let mut spinner_logo_idx: usize = 0; - let mut complete: usize = 0; - let mut failed: usize = 0; - loop { - match rx.recv_timeout(std::time::Duration::from_millis(50)) { - Ok(recv_result) => match recv_result { - LoadingMsg::Add(name) => { - let init_time = std::time::Instant::now(); - let is_done = false; - let status_line = StatusLine { init_time, is_done }; - execute!(output, cursor::MoveToColumn(0))?; - if !loading_servers.is_empty() { - // TODO: account for terminal width - execute!(output, cursor::MoveUp(1))?; - } - loading_servers.insert(name.clone(), status_line); - let total = loading_servers.len(); - execute!(output, terminal::Clear(terminal::ClearType::CurrentLine))?; - queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; - output.flush()?; - }, - LoadingMsg::Done(name) => { - if let Some(status_line) = loading_servers.get_mut(&name) { - status_line.is_done = true; - complete += 1; - let time_taken = - (std::time::Instant::now() - status_line.init_time).as_secs_f64().abs(); - let time_taken = format!("{:.2}", time_taken); - execute!( - output, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - queue_success_message(&name, &time_taken, &mut output)?; - let total = loading_servers.len(); - queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; - output.flush()?; - } - if loading_servers.iter().all(|(_, status)| status.is_done) { - break; - } - }, - LoadingMsg::Error { name, msg } => { - if let Some(status_line) = loading_servers.get_mut(&name) { - status_line.is_done = true; - failed += 1; - execute!( - output, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - queue_failure_message(&name, &msg, &mut output)?; - let total = loading_servers.len(); - queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; - } - if loading_servers.iter().all(|(_, status)| status.is_done) { - break; - } - }, - LoadingMsg::Warn { name, msg } => { - if let Some(status_line) = loading_servers.get_mut(&name) { - status_line.is_done = true; - complete += 1; - execute!( - output, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - let msg = eyre::eyre!(msg.to_string()); - queue_warn_message(&name, &msg, &mut output)?; - let total = loading_servers.len(); - queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; - output.flush()?; - } - if loading_servers.iter().all(|(_, status)| status.is_done) { - break; - } - }, - LoadingMsg::Terminate => { - if loading_servers.iter().any(|(_, status)| !status.is_done) { + // Spawn a task for displaying the mcp loading statuses. + // This is only necessary when we are in interactive mode AND there are servers to load. + // Otherwise we do not need to be spawning this. + let (_loading_display_task, loading_status_sender) = if is_interactive && total > 0 { + let (tx, mut rx) = tokio::sync::mpsc::channel::(50); + ( + Some(tokio::task::spawn(async move { + let mut spinner_logo_idx: usize = 0; + let mut complete: usize = 0; + let mut failed: usize = 0; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; + loop { + match tokio::time::timeout(Duration::from_millis(50), rx.recv()).await { + Ok(Some(recv_result)) => match recv_result { + LoadingMsg::Done { name, time } => { + complete += 1; + execute!( + output, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + queue_success_message(&name, &time, &mut output)?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; + }, + LoadingMsg::Error { name, msg, time } => { + failed += 1; + execute!( + output, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + queue_failure_message(&name, &msg, time.as_str(), &mut output)?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; + }, + LoadingMsg::Warn { name, msg, time } => { + complete += 1; + execute!( + output, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + let msg = eyre::eyre!(msg.to_string()); + queue_warn_message(&name, &msg, time.as_str(), &mut output)?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; + }, + LoadingMsg::Terminate { still_loading } => { + if !still_loading.is_empty() { + execute!( + output, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + let msg = still_loading.iter().fold(String::new(), |mut acc, server_name| { + acc.push_str(format!("\n - {server_name}").as_str()); + acc + }); + let msg = eyre::eyre!(msg); + queue_incomplete_load_message(complete, total, &msg, &mut output)?; + } + execute!(output, style::Print("\n"),)?; + break; + }, + }, + Err(_e) => { + spinner_logo_idx = (spinner_logo_idx + 1) % SPINNER_CHARS.len(); execute!( output, + cursor::SavePosition, cursor::MoveToColumn(0), cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), + style::Print(SPINNER_CHARS[spinner_logo_idx]), + cursor::RestorePosition )?; - let msg = - loading_servers - .iter() - .fold(String::new(), |mut acc, (server_name, status)| { - if !status.is_done { - acc.push_str(format!("\n - {server_name}").as_str()); - } - acc - }); - let msg = eyre::eyre!(msg); - let total = loading_servers.len(); - queue_incomplete_load_message(complete, total, &msg, &mut output)?; - } - execute!(output, style::Print("\n"),)?; - break; - }, - }, - Err(RecvTimeoutError::Timeout) => { - spinner_logo_idx = (spinner_logo_idx + 1) % SPINNER_CHARS.len(); - execute!( - output, - cursor::SavePosition, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - style::Print(SPINNER_CHARS[spinner_logo_idx]), - cursor::RestorePosition - )?; - }, - _ => break, - } - } - Ok::<_, eyre::Report>(()) - }); + }, + _ => break, + } + output.flush()?; + } + Ok::<_, eyre::Report>(()) + })), + Some(tx), + ) + } else { + (None, None) + }; let mut clients = HashMap::>::new(); - let mut load_msg_sender = Some(tx.clone()); + let mut loading_status_sender_clone = loading_status_sender.clone(); let conv_id_clone = conversation_id.clone(); - let regex = Arc::new(Regex::new(VALID_TOOL_NAME)?); + let regex = Regex::new(VALID_TOOL_NAME)?; let new_tool_specs = Arc::new(Mutex::new(HashMap::new())); let new_tool_specs_clone = new_tool_specs.clone(); let has_new_stuff = Arc::new(AtomicBool::new(false)); @@ -427,51 +404,145 @@ impl ToolManagerBuilder { let pending_clone = pending.clone(); let (mut msg_rx, messenger_builder) = ServerMessengerBuilder::new(20); let telemetry_clone = telemetry.clone(); + let notify = Arc::new(Notify::new()); + let notify_weak = Arc::downgrade(¬ify); + let load_record = Arc::new(Mutex::new(HashMap::>::new())); + let load_record_clone = load_record.clone(); tokio::spawn(async move { + let mut record_temp_buf = Vec::::new(); + let mut initialized = HashSet::::new(); while let Some(msg) = msg_rx.recv().await { + record_temp_buf.clear(); // For now we will treat every list result as if they contain the // complete set of tools. This is not necessarily true in the future when // request method on the mcp client no longer buffers all the pages from // list calls. match msg { UpdateEventMessage::ToolsListResult { server_name, result } => { + let time_taken = loading_servers + .remove(&server_name) + .map_or("0.0".to_owned(), |init_time| { + let time_taken = (std::time::Instant::now() - init_time).as_secs_f64().abs(); + format!("{:.2}", time_taken) + }); pending_clone.write().await.remove(&server_name); - let mut specs = result - .tools - .into_iter() - .filter_map(|v| serde_json::from_value::(v).ok()) - .collect::>(); - let mut sanitized_mapping = HashMap::::new(); - if let Some(load_msg) = process_tool_specs( - conv_id_clone.as_str(), - &server_name, - load_msg_sender.is_some(), - &mut specs, - &mut sanitized_mapping, - ®ex, - &telemetry_clone, - ) { - let mut has_errored = false; - if let Some(sender) = &load_msg_sender { - if let Err(e) = sender.send(load_msg) { - warn!( - "Error sending update message to display task: {:?}\nAssume display task has completed", - e + match result { + Ok(result) => { + let mut specs = result + .tools + .into_iter() + .filter_map(|v| serde_json::from_value::(v).ok()) + .collect::>(); + let mut sanitized_mapping = HashMap::::new(); + let process_result = process_tool_specs( + conv_id_clone.as_str(), + &server_name, + &mut specs, + &mut sanitized_mapping, + ®ex, + &telemetry_clone, + ); + if let Some(sender) = &loading_status_sender_clone { + // Anomalies here are not considered fatal, thus we shall give + // warnings. + let msg = match process_result { + Ok(_) => LoadingMsg::Done { + name: server_name.clone(), + time: time_taken.clone(), + }, + Err(ref e) => LoadingMsg::Warn { + name: server_name.clone(), + msg: eyre::eyre!(e.to_string()), + time: time_taken.clone(), + }, + }; + if let Err(e) = sender.send(msg).await { + warn!( + "Error sending update message to display task: {:?}\nAssume display task has completed", + e + ); + loading_status_sender_clone.take(); + } + } + new_tool_specs_clone + .lock() + .await + .insert(server_name.clone(), (sanitized_mapping, specs)); + has_new_stuff_clone.store(true, Ordering::Release); + // Maintain a record of the server load: + let mut buf_writer = BufWriter::new(&mut record_temp_buf); + if let Err(e) = &process_result { + let _ = queue_warn_message( + server_name.as_str(), + e, + time_taken.as_str(), + &mut buf_writer, + ); + } else { + let _ = queue_success_message( + server_name.as_str(), + time_taken.as_str(), + &mut buf_writer, ); - has_errored = true; } - } - if has_errored { - load_msg_sender.take(); - } + let _ = buf_writer.flush(); + drop(buf_writer); + let record = String::from_utf8_lossy(&record_temp_buf).to_string(); + let record = if process_result.is_err() { + LoadingRecord::Warn(record) + } else { + LoadingRecord::Success(record) + }; + load_record_clone + .lock() + .await + .entry(server_name.clone()) + .and_modify(|load_record| { + load_record.push(record.clone()); + }) + .or_insert(vec![record]); + }, + Err(e) => { + // Log error to chat Log + error!("Error loading server {server_name}: {:?}", e); + // Maintain a record of the server load: + let mut buf_writer = BufWriter::new(&mut record_temp_buf); + let _ = queue_failure_message(server_name.as_str(), &e, &time_taken, &mut buf_writer); + let _ = buf_writer.flush(); + drop(buf_writer); + let record = String::from_utf8_lossy(&record_temp_buf).to_string(); + let record = LoadingRecord::Err(record); + load_record_clone + .lock() + .await + .entry(server_name.clone()) + .and_modify(|load_record| { + load_record.push(record.clone()); + }) + .or_insert(vec![record]); + // Errors surfaced at this point (i.e. before [process_tool_specs] + // is called) are fatals and should be considered errors + if let Some(sender) = &loading_status_sender_clone { + let msg = LoadingMsg::Error { + name: server_name.clone(), + msg: e, + time: time_taken, + }; + if let Err(e) = sender.send(msg).await { + warn!( + "Error sending update message to display task: {:?}\nAssume display task has completed", + e + ); + loading_status_sender_clone.take(); + } + } + }, } - new_tool_specs_clone - .lock() - .await - .insert(server_name, (sanitized_mapping, specs)); - // We only want to set this flag when the display task has ended - if load_msg_sender.is_none() { - has_new_stuff_clone.store(true, Ordering::Release); + if let Some(notify) = notify_weak.upgrade() { + initialized.insert(server_name); + if initialized.len() >= total { + notify.notify_one(); + } } }, UpdateEventMessage::PromptsListResult { @@ -488,17 +559,15 @@ impl ToolManagerBuilder { } => {}, UpdateEventMessage::InitStart { server_name } => { pending_clone.write().await.insert(server_name.clone()); - if let Some(sender) = &load_msg_sender { - let _ = sender.send(LoadingMsg::Add(server_name)); - } + loading_servers.insert(server_name, std::time::Instant::now()); }, } } }); for (mut name, init_res) in pre_initialized { + let messenger = messenger_builder.build_with_name(name.clone()); match init_res { Ok(mut client) => { - let messenger = messenger_builder.build_with_name(client.get_server_name().to_owned()); client.assign_messenger(Box::new(messenger)); let mut client = Arc::new(client); while let Some(collided_client) = clients.insert(name.clone(), client) { @@ -513,16 +582,10 @@ impl ToolManagerBuilder { telemetry .send_mcp_server_init(conversation_id.clone(), Some(e.to_string()), 0) .ok(); - - let _ = tx.send(LoadingMsg::Error { - name: name.clone(), - msg: e, - }); + let _ = messenger.send_tools_list_result(Err(e)).await; }, } } - let loading_display_task = Some(loading_display_task); - let loading_status_sender = Some(tx); // Set up task to handle prompt requests let sender = self.prompt_list_sender.take(); @@ -617,12 +680,13 @@ impl ToolManagerBuilder { conversation_id, clients, prompts, - loading_display_task, pending_clients: pending, + notify: Some(notify), loading_status_sender, new_tool_specs, has_new_stuff, is_interactive, + mcp_load_record: load_record, ..Default::default() }) } @@ -686,13 +750,13 @@ pub struct ToolManager { /// cases where multiple servers offer prompts with the same name. pub prompts: Arc>>>, - /// Handle to the thread that displays loading status for tool initialization. - /// This thread provides visual feedback to users during the tool loading process. - loading_display_task: Option>>, + /// A notifier to understand if the initial loading has completed. + /// This is only used for initial loading and is discarded after. + notify: Option>, /// Channel sender for communicating with the loading display thread. /// Used to send status updates about tool initialization progress. - loading_status_sender: Option>, + loading_status_sender: Option>, /// Mapping from sanitized tool names to original tool names. /// This is used to handle tool name transformations that may occur during initialization @@ -705,6 +769,13 @@ pub struct ToolManager { pub schema: HashMap, is_interactive: bool, + + /// This serves as a record of the loading of mcp servers. + /// The key of which is the server name as they are recognized by the current instance of chat + /// (which may be different than how it is written in the config, depending of the presence of + /// invalid characters). + /// The value is the load message (i.e. load time, warnings, and errors) + pub mcp_load_record: Arc>>>, } impl Clone for ToolManager { @@ -718,6 +789,7 @@ impl Clone for ToolManager { tn_map: self.tn_map.clone(), schema: self.schema.clone(), is_interactive: self.is_interactive, + mcp_load_record: self.mcp_load_record.clone(), ..Default::default() } } @@ -730,7 +802,7 @@ impl ToolManager { output: &mut SharedWriter, ) -> eyre::Result> { let tx = self.loading_status_sender.take(); - let display_task = self.loading_display_task.take(); + let notify = self.notify.take(); self.schema = { let mut tool_specs = serde_json::from_str::>(include_str!("tools/tool_index.json"))?; @@ -755,17 +827,6 @@ impl ToolManager { }); // We need to cast it to erase the type otherwise the compiler will default to static // dispatch, which would result in an error of inconsistent match arm return type. - let display_fut: Pin>> = match display_task { - Some(display_task) => { - let fut = async move { - if let Err(e) = display_task.await { - error!("Error while joining status display task: {:?}", e); - } - }; - Box::pin(fut) - }, - None => Box::pin(future::pending()), - }; let timeout_fut: Pin>> = if self.clients.is_empty() { // If there is no server loaded, we want to resolve immediately Box::pin(future::ready(())) @@ -781,34 +842,62 @@ impl ToolManager { .settings .get_int(Setting::McpNoInteractiveTimeout) .map_or(30_000_u64, |s| s as u64); - Box::pin(async move { - tokio::time::sleep(std::time::Duration::from_millis(init_timeout)).await; - let _ = queue!( - output, - style::Print( - "Not all mcp servers loaded. Configure no-interactive timeout with q settings mcp.noInteractiveTimeout" - ), - style::Print("\n") - ); - }) + Box::pin(tokio::time::sleep(std::time::Duration::from_millis(init_timeout))) + }; + let server_loading_fut: Pin>> = if let Some(notify) = notify { + Box::pin(async move { notify.notified().await }) + } else { + Box::pin(future::ready(())) }; tokio::select! { - _ = display_fut => {}, _ = timeout_fut => { if let Some(tx) = tx { - let _ = tx.send(LoadingMsg::Terminate); + let still_loading = self.pending_clients.read().await.iter().cloned().collect::>(); + let _ = tx.send(LoadingMsg::Terminate { still_loading }).await; + } + if !self.clients.is_empty() && !self.is_interactive { + let _ = queue!( + output, + style::Print( + "Not all mcp servers loaded. Configure no-interactive timeout with q settings mcp.noInteractiveTimeout" + ), + style::Print("\n------\n") + ); } }, + _ = server_loading_fut => { + if let Some(tx) = tx { + let still_loading = self.pending_clients.read().await.iter().cloned().collect::>(); + let _ = tx.send(LoadingMsg::Terminate { still_loading }).await; + } + } _ = ctrl_c() => { if self.is_interactive { if let Some(tx) = tx { - let _ = tx.send(LoadingMsg::Terminate); + let still_loading = self.pending_clients.read().await.iter().cloned().collect::>(); + let _ = tx.send(LoadingMsg::Terminate { still_loading }).await; } } else { return Err(eyre::eyre!("User interrupted mcp server loading in non-interactive mode. Ending.")); } } } + if !self.is_interactive + && self + .mcp_load_record + .lock() + .await + .iter() + .any(|(_, records)| records.iter().any(|record| matches!(record, LoadingRecord::Err(_)))) + { + queue!( + output, + style::Print( + "One or more mcp server did not load correctly. See $TMPDIR/qlog/chat.log for more details." + ), + style::Print("\n------\n") + )?; + } self.update().await; Ok(self.schema.clone()) } @@ -1118,12 +1207,11 @@ impl ToolManager { fn process_tool_specs( conversation_id: &str, server_name: &str, - is_in_display: bool, specs: &mut Vec, tn_map: &mut HashMap, - regex: &Arc, + regex: &Regex, telemetry: &TelemetryThread, -) -> Option { +) -> eyre::Result<()> { // Each mcp server might have multiple tools. // To avoid naming conflicts we are going to namespace it. // This would also help us locate which mcp server to call the tool from. @@ -1174,8 +1262,8 @@ fn process_tool_specs( // considered a "server load". Reasoning being: // - Failures here are not related to server load // - There is not a whole lot we can do with this data - let loading_msg = if !out_of_spec_tool_names.is_empty() { - let msg = out_of_spec_tool_names.iter().fold( + if !out_of_spec_tool_names.is_empty() { + Err(eyre::eyre!(out_of_spec_tool_names.iter().fold( String::from( "The following tools are out of spec. They will be excluded from the list of available tools:\n", ), @@ -1193,46 +1281,23 @@ fn process_tool_specs( (tool_name.as_str(), "tool schema contains empty description") }, }; - acc.push_str(format!(" - {} ({})\n", tool_name, msg).as_str()); + acc.push_str(format!(" - {} ({})\n", tool_name, msg).as_str()); acc }, - ); - error!( - "Server {} finished loading with the following error: \n{}", - server_name, msg - ); - if is_in_display { - Some(LoadingMsg::Warn { - name: server_name.to_string(), - msg: eyre::eyre!(msg), - }) - } else { - None - } + ))) // TODO: if no tools are valid, we need to offload the server // from the fleet (i.e. kill the server) } else if !tn_map.is_empty() { - let warn = tn_map.iter().fold( + Err(eyre::eyre!(tn_map.iter().fold( String::from("The following tool names are changed:\n"), |mut acc, (k, v)| { acc.push_str(format!(" - {} -> {}\n", v, k).as_str()); acc }, - ); - if is_in_display { - Some(LoadingMsg::Warn { - name: server_name.to_string(), - msg: eyre::eyre!(warn), - }) - } else { - None - } - } else if is_in_display { - Some(LoadingMsg::Done(server_name.to_string())) + ))) } else { - None - }; - loading_msg + Ok(()) + } } fn sanitize_name(orig: String, regex: ®ex::Regex, hasher: &mut impl Hasher) -> String { @@ -1272,6 +1337,7 @@ fn queue_success_message(name: &str, time_taken: &str, output: &mut impl Write) style::Print(" loaded in "), style::SetForegroundColor(style::Color::Yellow), style::Print(format!("{time_taken} s\n")), + style::ResetColor, )?) } @@ -1322,7 +1388,12 @@ fn queue_init_message( Ok(queue!(output, style::Print("\n"))?) } -fn queue_failure_message(name: &str, fail_load_msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { +fn queue_failure_message( + name: &str, + fail_load_msg: &eyre::Report, + time: &str, + output: &mut impl Write, +) -> eyre::Result<()> { use crate::util::CHAT_BINARY_NAME; Ok(queue!( output, @@ -1331,17 +1402,21 @@ fn queue_failure_message(name: &str, fail_load_msg: &eyre::Report, output: &mut style::SetForegroundColor(style::Color::Blue), style::Print(name), style::ResetColor, - style::Print(" has failed to load:\n- "), + style::Print(" has failed to load after"), + style::SetForegroundColor(style::Color::Yellow), + style::Print(format!(" {time} s")), + style::ResetColor, + style::Print("\n - "), style::Print(fail_load_msg), style::Print("\n"), style::Print(format!( - "- run with Q_LOG_LEVEL=trace and see $TMPDIR/{CHAT_BINARY_NAME} for detail\n" + " - run with Q_LOG_LEVEL=trace and see $TMPDIR/{CHAT_BINARY_NAME} for detail\n" )), style::ResetColor, )?) } -fn queue_warn_message(name: &str, msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { +fn queue_warn_message(name: &str, msg: &eyre::Report, time: &str, output: &mut impl Write) -> eyre::Result<()> { Ok(queue!( output, style::SetForegroundColor(style::Color::Yellow), @@ -1349,7 +1424,11 @@ fn queue_warn_message(name: &str, msg: &eyre::Report, output: &mut impl Write) - style::SetForegroundColor(style::Color::Blue), style::Print(name), style::ResetColor, - style::Print(" has the following warning:\n"), + style::Print(" has loaded in"), + style::SetForegroundColor(style::Color::Yellow), + style::Print(format!(" {time} s")), + style::ResetColor, + style::Print(" with the following warning:\n"), style::Print(msg), style::ResetColor, )?) diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs index b0d8eefe00..831427d3b1 100644 --- a/crates/chat-cli/src/mcp_client/client.rs +++ b/crates/chat-cli/src/mcp_client/client.rs @@ -573,28 +573,27 @@ where { // TODO: decouple pagination logic from request and have page fetching logic here // instead - let resp = match client.request("tools/list", None).await { - Ok(resp) => resp, - Err(e) => { - tracing::error!("Failed to retrieve tool list from {}: {:?}", client.server_name, e); - return; - }, - }; - if let Some(error) = resp.error { - let msg = format!("Failed to retrieve tool list for {}: {:?}", client.server_name, error); - tracing::error!("{}", &msg); - return; - } - let Some(result) = resp.result else { - tracing::error!("Tool list response from {} is missing result", client.server_name); - return; - }; - let tool_list_result = match serde_json::from_value::(result) { - Ok(result) => result, - Err(e) => { - tracing::error!("Failed to deserialize tool result from {}: {:?}", client.server_name, e); - return; - }, + let tool_list_result = 'tool_list_result: { + let resp = match client.request("tools/list", None).await { + Ok(resp) => resp, + Err(e) => break 'tool_list_result Err(e.into()), + }; + if let Some(error) = resp.error { + let msg = format!("Failed to retrieve tool list for {}: {:?}", client.server_name, error); + break 'tool_list_result Err(eyre::eyre!(msg)); + } + let Some(result) = resp.result else { + let msg = format!("Tool list response from {} is missing result", client.server_name); + break 'tool_list_result Err(eyre::eyre!(msg)); + }; + let tool_list_result = match serde_json::from_value::(result) { + Ok(result) => result, + Err(e) => { + let msg = format!("Failed to deserialize tool result from {}: {:?}", client.server_name, e); + break 'tool_list_result Err(eyre::eyre!(msg)); + }, + }; + Ok::(tool_list_result) }; if let Some(messenger) = messenger { let _ = messenger diff --git a/crates/chat-cli/src/mcp_client/messenger.rs b/crates/chat-cli/src/mcp_client/messenger.rs index efd49617ab..14f79e518a 100644 --- a/crates/chat-cli/src/mcp_client/messenger.rs +++ b/crates/chat-cli/src/mcp_client/messenger.rs @@ -16,21 +16,22 @@ use super::{ pub trait Messenger: std::fmt::Debug + Send + Sync + 'static { /// Sends the result of a tools list operation to the consumer /// This function is used to deliver information about available tools - async fn send_tools_list_result(&self, result: ToolsListResult) -> Result<(), MessengerError>; + async fn send_tools_list_result(&self, result: eyre::Result) -> Result<(), MessengerError>; /// Sends the result of a prompts list operation to the consumer /// This function is used to deliver information about available prompts - async fn send_prompts_list_result(&self, result: PromptsListResult) -> Result<(), MessengerError>; + async fn send_prompts_list_result(&self, result: eyre::Result) -> Result<(), MessengerError>; /// Sends the result of a resources list operation to the consumer /// This function is used to deliver information about available resources - async fn send_resources_list_result(&self, result: ResourcesListResult) -> Result<(), MessengerError>; + async fn send_resources_list_result(&self, result: eyre::Result) + -> Result<(), MessengerError>; /// Sends the result of a resource templates list operation to the consumer /// This function is used to deliver information about available resource templates async fn send_resource_templates_list_result( &self, - result: ResourceTemplatesListResult, + result: eyre::Result, ) -> Result<(), MessengerError>; /// Signals to the orchestrator that a server has started initializing @@ -52,21 +53,24 @@ pub struct NullMessenger; #[async_trait::async_trait] impl Messenger for NullMessenger { - async fn send_tools_list_result(&self, _result: ToolsListResult) -> Result<(), MessengerError> { + async fn send_tools_list_result(&self, _result: eyre::Result) -> Result<(), MessengerError> { Ok(()) } - async fn send_prompts_list_result(&self, _result: PromptsListResult) -> Result<(), MessengerError> { + async fn send_prompts_list_result(&self, _result: eyre::Result) -> Result<(), MessengerError> { Ok(()) } - async fn send_resources_list_result(&self, _result: ResourcesListResult) -> Result<(), MessengerError> { + async fn send_resources_list_result( + &self, + _result: eyre::Result, + ) -> Result<(), MessengerError> { Ok(()) } async fn send_resource_templates_list_result( &self, - _result: ResourceTemplatesListResult, + _result: eyre::Result, ) -> Result<(), MessengerError> { Ok(()) }