Skip to content

Commit f767058

Browse files
committed
adds slash command to show mcp server load messages
1 parent 8649192 commit f767058

File tree

3 files changed

+113
-10
lines changed

3 files changed

+113
-10
lines changed

crates/chat-cli/src/cli/chat/command.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ pub enum Command {
5858
path: String,
5959
force: bool,
6060
},
61+
Mcp,
6162
}
6263

6364
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -837,6 +838,7 @@ impl Command {
837838
}
838839
Self::Save { path, force }
839840
},
841+
"mcp" => Self::Mcp,
840842
unknown_command => {
841843
let looks_like_path = {
842844
let after_slash_command_str = parts[1..].join(" ");

crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2948,6 +2948,44 @@ impl ChatContext {
29482948
skip_printing_tools: true,
29492949
}
29502950
},
2951+
Command::Mcp => {
2952+
let terminal_width = self.terminal_width();
2953+
let loaded_servers = self.conversation_state.tool_manager.mcp_load_record.lock().await;
2954+
let still_loading = self
2955+
.conversation_state
2956+
.tool_manager
2957+
.pending_clients()
2958+
.await
2959+
.into_iter()
2960+
.map(|name| format!(" - {name}\n"))
2961+
.collect::<Vec<_>>()
2962+
.join("");
2963+
for (server_name, msg) in loaded_servers.iter() {
2964+
queue!(
2965+
self.output,
2966+
style::Print(server_name),
2967+
style::Print("\n"),
2968+
style::Print(format!("{}\n", "▔".repeat(terminal_width))),
2969+
style::Print(msg),
2970+
style::Print("\n")
2971+
)?;
2972+
}
2973+
if !still_loading.is_empty() {
2974+
queue!(
2975+
self.output,
2976+
style::Print("Still loading:\n"),
2977+
style::Print(format!("{}\n", "▔".repeat(terminal_width))),
2978+
style::Print(still_loading),
2979+
style::Print("\n")
2980+
)?;
2981+
}
2982+
self.output.flush()?;
2983+
ChatState::PromptUser {
2984+
tool_uses: None,
2985+
pending_tool_index: None,
2986+
skip_printing_tools: true,
2987+
}
2988+
},
29512989
})
29522990
}
29532991

crates/chat-cli/src/cli/chat/tool_manager.rs

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ use std::hash::{
77
DefaultHasher,
88
Hasher,
99
};
10-
use std::io::Write;
10+
use std::io::{
11+
BufWriter,
12+
Write,
13+
};
1114
use std::path::{
1215
Path,
1316
PathBuf,
@@ -395,18 +398,24 @@ impl ToolManagerBuilder {
395398
let notify = Arc::new(Notify::new());
396399
let notify_weak = Arc::downgrade(&notify);
397400
let completed_clone = completed.clone();
401+
let load_record = Arc::new(Mutex::new(HashMap::<String, String>::new()));
402+
let load_record_clone = load_record.clone();
398403
tokio::spawn(async move {
404+
let mut record_temp_buf = Vec::<u8>::new();
399405
while let Some(msg) = msg_rx.recv().await {
406+
record_temp_buf.clear();
400407
// For now we will treat every list result as if they contain the
401408
// complete set of tools. This is not necessarily true in the future when
402409
// request method on the mcp client no longer buffers all the pages from
403410
// list calls.
404411
match msg {
405412
UpdateEventMessage::ToolsListResult { server_name, result } => {
406-
let time_taken = loading_servers.get(&server_name).map_or("0.0".to_owned(), |init_time| {
407-
let time_taken = (std::time::Instant::now() - *init_time).as_secs_f64().abs();
408-
format!("{:.2}", time_taken)
409-
});
413+
let time_taken = loading_servers
414+
.remove(&server_name)
415+
.map_or("0.0".to_owned(), |init_time| {
416+
let time_taken = (std::time::Instant::now() - init_time).as_secs_f64().abs();
417+
format!("{:.2}", time_taken)
418+
});
410419
pending_clone.write().await.remove(&server_name);
411420
match result {
412421
Ok(result) => {
@@ -430,12 +439,12 @@ impl ToolManagerBuilder {
430439
let msg = match process_result {
431440
Ok(_) => LoadingMsg::Done {
432441
name: server_name.clone(),
433-
time: time_taken,
442+
time: time_taken.clone(),
434443
},
435-
Err(e) => LoadingMsg::Warn {
444+
Err(ref e) => LoadingMsg::Warn {
436445
name: server_name.clone(),
437-
msg: e,
438-
time: time_taken,
446+
msg: eyre::eyre!(e.to_string()),
447+
time: time_taken.clone(),
439448
},
440449
};
441450
if let Err(e) = sender.send(msg).await {
@@ -449,10 +458,53 @@ impl ToolManagerBuilder {
449458
new_tool_specs_clone
450459
.lock()
451460
.await
452-
.insert(server_name, (sanitized_mapping, specs));
461+
.insert(server_name.clone(), (sanitized_mapping, specs));
453462
has_new_stuff_clone.store(true, Ordering::Release);
463+
// Maintain a record of the server load:
464+
let mut buf_writer = BufWriter::new(&mut record_temp_buf);
465+
if let Err(e) = process_result {
466+
let _ = queue_warn_message(
467+
server_name.as_str(),
468+
&e,
469+
time_taken.as_str(),
470+
&mut buf_writer,
471+
);
472+
} else {
473+
let _ = queue_success_message(
474+
server_name.as_str(),
475+
time_taken.as_str(),
476+
&mut buf_writer,
477+
);
478+
}
479+
let _ = buf_writer.flush();
480+
drop(buf_writer);
481+
let record = String::from_utf8_lossy(&record_temp_buf).to_string();
482+
load_record_clone
483+
.lock()
484+
.await
485+
.entry(server_name)
486+
.and_modify(|load_record| {
487+
load_record.push_str("\n--- tools refreshed ---\n");
488+
load_record.push_str(record.as_str());
489+
})
490+
.or_insert(record);
454491
},
455492
Err(e) => {
493+
// Maintain a record of the server load:
494+
let mut buf_writer = BufWriter::new(&mut record_temp_buf);
495+
let _ = queue_failure_message(server_name.as_str(), &e, &time_taken, &mut buf_writer);
496+
let _ = buf_writer.flush();
497+
drop(buf_writer);
498+
let record = String::from_utf8_lossy(&record_temp_buf).to_string();
499+
load_record_clone
500+
.lock()
501+
.await
502+
.entry(server_name.clone())
503+
.and_modify(|load_record| {
504+
load_record.push_str("\n--- tools refreshed ---\n");
505+
load_record.push_str(record.as_str());
506+
})
507+
.or_insert(record);
456508
// Errors surfaced at this point (i.e. before [process_tool_specs]
457509
// is called) are fatals and should be considered errors
458510
if let Some(sender) = &loading_status_sender_clone {
@@ -492,6 +544,7 @@ impl ToolManagerBuilder {
492544
} => {},
493545
UpdateEventMessage::InitStart { server_name } => {
494546
pending_clone.write().await.insert(server_name.clone());
547+
loading_servers.insert(server_name, std::time::Instant::now());
495548
},
496549
}
497550
}
@@ -627,6 +680,7 @@ impl ToolManagerBuilder {
627680
new_tool_specs,
628681
has_new_stuff,
629682
is_interactive,
683+
mcp_load_record: load_record,
630684
..Default::default()
631685
})
632686
}
@@ -709,6 +763,13 @@ pub struct ToolManager {
709763
pub schema: HashMap<String, ToolSpec>,
710764

711765
is_interactive: bool,
766+
767+
/// This serves as a record of the loading of mcp servers.
768+
/// The key of which is the server name as they are recognized by the current instance of chat
769+
/// (which may be different than how it is written in the config, depending of the presence of
770+
/// invalid characters).
771+
/// The value is the load message (i.e. load time, warnings, and errors)
772+
pub mcp_load_record: Arc<Mutex<HashMap<String, String>>>,
712773
}
713774

714775
impl Clone for ToolManager {
@@ -722,6 +783,7 @@ impl Clone for ToolManager {
722783
tn_map: self.tn_map.clone(),
723784
schema: self.schema.clone(),
724785
is_interactive: self.is_interactive,
786+
mcp_load_record: self.mcp_load_record.clone(),
725787
..Default::default()
726788
}
727789
}
@@ -1253,6 +1315,7 @@ fn queue_success_message(name: &str, time_taken: &str, output: &mut impl Write)
12531315
style::Print(" loaded in "),
12541316
style::SetForegroundColor(style::Color::Yellow),
12551317
style::Print(format!("{time_taken} s\n")),
1318+
style::ResetColor,
12561319
)?)
12571320
}
12581321

0 commit comments

Comments
 (0)