Skip to content

Commit f688866

Browse files
committed
changes display task to use async task instead of spawn blocking
1 parent e0c755e commit f688866

File tree

1 file changed

+129
-126
lines changed

1 file changed

+129
-126
lines changed

crates/chat-cli/src/cli/chat/tool_manager.rs

Lines changed: 129 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ use std::sync::atomic::{
1717
AtomicBool,
1818
Ordering,
1919
};
20-
use std::sync::mpsc::RecvTimeoutError;
2120
use std::sync::{
2221
Arc,
2322
RwLock as SyncRwLock,
2423
};
24+
use std::time::Duration;
2525

2626
use convert_case::Casing;
2727
use crossterm::{
@@ -290,122 +290,124 @@ impl ToolManagerBuilder {
290290
}
291291
let total = loading_servers.len();
292292

293-
// Send up task to update user on server loading status
294-
let (tx, rx) = std::sync::mpsc::channel::<LoadingMsg>();
295-
// TODO: rather than using it as an "anchor" to determine the progress of server loads, we
296-
// should make this task optional (and it is defined as an optional right now. There is
297-
// just no code path with it being None). When ran with no-interactive mode, we really do
298-
// not have a need to run this task.
299-
let loading_display_task = tokio::task::spawn_blocking(move || {
300-
if total == 0 {
301-
return Ok::<_, eyre::Report>(());
302-
}
303-
let mut spinner_logo_idx: usize = 0;
304-
let mut complete: usize = 0;
305-
let mut failed: usize = 0;
306-
queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?;
307-
loop {
308-
match rx.recv_timeout(std::time::Duration::from_millis(50)) {
309-
Ok(recv_result) => match recv_result {
310-
LoadingMsg::Done(name) => {
311-
if let Some(status_line) = loading_servers.get_mut(&name) {
312-
status_line.is_done = true;
313-
complete += 1;
314-
let time_taken =
315-
(std::time::Instant::now() - status_line.init_time).as_secs_f64().abs();
316-
let time_taken = format!("{:.2}", time_taken);
317-
execute!(
318-
output,
319-
cursor::MoveToColumn(0),
320-
cursor::MoveUp(1),
321-
terminal::Clear(terminal::ClearType::CurrentLine),
322-
)?;
323-
queue_success_message(&name, &time_taken, &mut output)?;
324-
queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?;
325-
output.flush()?;
326-
}
327-
if loading_servers.iter().all(|(_, status)| status.is_done) {
328-
break;
329-
}
330-
},
331-
LoadingMsg::Error { name, msg } => {
332-
if let Some(status_line) = loading_servers.get_mut(&name) {
333-
status_line.is_done = true;
334-
failed += 1;
335-
execute!(
336-
output,
337-
cursor::MoveToColumn(0),
338-
cursor::MoveUp(1),
339-
terminal::Clear(terminal::ClearType::CurrentLine),
340-
)?;
341-
queue_failure_message(&name, &msg, &mut output)?;
342-
queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?;
343-
}
344-
if loading_servers.iter().all(|(_, status)| status.is_done) {
345-
break;
346-
}
347-
},
348-
LoadingMsg::Warn { name, msg } => {
349-
if let Some(status_line) = loading_servers.get_mut(&name) {
350-
status_line.is_done = true;
351-
complete += 1;
352-
execute!(
353-
output,
354-
cursor::MoveToColumn(0),
355-
cursor::MoveUp(1),
356-
terminal::Clear(terminal::ClearType::CurrentLine),
357-
)?;
358-
let msg = eyre::eyre!(msg.to_string());
359-
queue_warn_message(&name, &msg, &mut output)?;
360-
queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?;
361-
output.flush()?;
362-
}
363-
if loading_servers.iter().all(|(_, status)| status.is_done) {
364-
break;
365-
}
366-
},
367-
LoadingMsg::Terminate => {
368-
if loading_servers.iter().any(|(_, status)| !status.is_done) {
293+
// Spawn a task for displaying the mcp loading statuses.
294+
// This is only necessary when we are in interactive mode AND there are servers to load.
295+
// Otherwise we do not need to be spawning this.
296+
let (loading_display_task, loading_status_sender) = if is_interactive && total > 0 {
297+
let (tx, mut rx) = tokio::sync::mpsc::channel::<LoadingMsg>(50);
298+
(
299+
Some(tokio::task::spawn(async move {
300+
let mut spinner_logo_idx: usize = 0;
301+
let mut complete: usize = 0;
302+
let mut failed: usize = 0;
303+
queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?;
304+
loop {
305+
match tokio::time::timeout(Duration::from_millis(50), rx.recv()).await {
306+
Ok(Some(recv_result)) => match recv_result {
307+
LoadingMsg::Done(name) => {
308+
if let Some(status_line) = loading_servers.get_mut(&name) {
309+
status_line.is_done = true;
310+
complete += 1;
311+
let time_taken =
312+
(std::time::Instant::now() - status_line.init_time).as_secs_f64().abs();
313+
let time_taken = format!("{:.2}", time_taken);
314+
execute!(
315+
output,
316+
cursor::MoveToColumn(0),
317+
cursor::MoveUp(1),
318+
terminal::Clear(terminal::ClearType::CurrentLine),
319+
)?;
320+
queue_success_message(&name, &time_taken, &mut output)?;
321+
queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?;
322+
output.flush()?;
323+
}
324+
if loading_servers.iter().all(|(_, status)| status.is_done) {
325+
break;
326+
}
327+
},
328+
LoadingMsg::Error { name, msg } => {
329+
if let Some(status_line) = loading_servers.get_mut(&name) {
330+
status_line.is_done = true;
331+
failed += 1;
332+
execute!(
333+
output,
334+
cursor::MoveToColumn(0),
335+
cursor::MoveUp(1),
336+
terminal::Clear(terminal::ClearType::CurrentLine),
337+
)?;
338+
queue_failure_message(&name, &msg, &mut output)?;
339+
queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?;
340+
}
341+
if loading_servers.iter().all(|(_, status)| status.is_done) {
342+
break;
343+
}
344+
},
345+
LoadingMsg::Warn { name, msg } => {
346+
if let Some(status_line) = loading_servers.get_mut(&name) {
347+
status_line.is_done = true;
348+
complete += 1;
349+
execute!(
350+
output,
351+
cursor::MoveToColumn(0),
352+
cursor::MoveUp(1),
353+
terminal::Clear(terminal::ClearType::CurrentLine),
354+
)?;
355+
let msg = eyre::eyre!(msg.to_string());
356+
queue_warn_message(&name, &msg, &mut output)?;
357+
queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?;
358+
output.flush()?;
359+
}
360+
if loading_servers.iter().all(|(_, status)| status.is_done) {
361+
break;
362+
}
363+
},
364+
LoadingMsg::Terminate => {
365+
if loading_servers.iter().any(|(_, status)| !status.is_done) {
366+
execute!(
367+
output,
368+
cursor::MoveToColumn(0),
369+
cursor::MoveUp(1),
370+
terminal::Clear(terminal::ClearType::CurrentLine),
371+
)?;
372+
let msg = loading_servers.iter().fold(
373+
String::new(),
374+
|mut acc, (server_name, status)| {
375+
if !status.is_done {
376+
acc.push_str(format!("\n - {server_name}").as_str());
377+
}
378+
acc
379+
},
380+
);
381+
let msg = eyre::eyre!(msg);
382+
queue_incomplete_load_message(complete, total, &msg, &mut output)?;
383+
}
384+
execute!(output, style::Print("\n"),)?;
385+
break;
386+
},
387+
},
388+
Err(_e) => {
389+
spinner_logo_idx = (spinner_logo_idx + 1) % SPINNER_CHARS.len();
369390
execute!(
370391
output,
392+
cursor::SavePosition,
371393
cursor::MoveToColumn(0),
372394
cursor::MoveUp(1),
373-
terminal::Clear(terminal::ClearType::CurrentLine),
395+
style::Print(SPINNER_CHARS[spinner_logo_idx]),
396+
cursor::RestorePosition
374397
)?;
375-
let msg =
376-
loading_servers
377-
.iter()
378-
.fold(String::new(), |mut acc, (server_name, status)| {
379-
if !status.is_done {
380-
acc.push_str(format!("\n - {server_name}").as_str());
381-
}
382-
acc
383-
});
384-
let msg = eyre::eyre!(msg);
385-
queue_incomplete_load_message(complete, total, &msg, &mut output)?;
386-
}
387-
execute!(output, style::Print("\n"),)?;
388-
break;
389-
},
390-
},
391-
Err(RecvTimeoutError::Timeout) => {
392-
spinner_logo_idx = (spinner_logo_idx + 1) % SPINNER_CHARS.len();
393-
execute!(
394-
output,
395-
cursor::SavePosition,
396-
cursor::MoveToColumn(0),
397-
cursor::MoveUp(1),
398-
style::Print(SPINNER_CHARS[spinner_logo_idx]),
399-
cursor::RestorePosition
400-
)?;
401-
},
402-
_ => break,
403-
}
404-
}
405-
Ok::<_, eyre::Report>(())
406-
});
398+
},
399+
_ => break,
400+
}
401+
}
402+
Ok::<_, eyre::Report>(())
403+
})),
404+
Some(tx),
405+
)
406+
} else {
407+
(None, None)
408+
};
407409
let mut clients = HashMap::<String, Arc<CustomToolClient>>::new();
408-
let mut load_msg_sender = Some(tx.clone());
410+
let mut loading_status_sender_clone = loading_status_sender.clone();
409411
let conv_id_clone = conversation_id.clone();
410412
let regex = Arc::new(Regex::new(VALID_TOOL_NAME)?);
411413
let new_tool_specs = Arc::new(Mutex::new(HashMap::new()));
@@ -434,15 +436,15 @@ impl ToolManagerBuilder {
434436
if let Some(load_msg) = process_tool_specs(
435437
conv_id_clone.as_str(),
436438
&server_name,
437-
load_msg_sender.is_some(),
439+
loading_status_sender_clone.is_some(),
438440
&mut specs,
439441
&mut sanitized_mapping,
440442
&regex,
441443
&telemetry_clone,
442444
) {
443445
let mut has_errored = false;
444-
if let Some(sender) = &load_msg_sender {
445-
if let Err(e) = sender.send(load_msg) {
446+
if let Some(sender) = &loading_status_sender_clone {
447+
if let Err(e) = sender.send(load_msg).await {
446448
warn!(
447449
"Error sending update message to display task: {:?}\nAssume display task has completed",
448450
e
@@ -451,15 +453,15 @@ impl ToolManagerBuilder {
451453
}
452454
}
453455
if has_errored {
454-
load_msg_sender.take();
456+
loading_status_sender_clone.take();
455457
}
456458
}
457459
new_tool_specs_clone
458460
.lock()
459461
.await
460462
.insert(server_name, (sanitized_mapping, specs));
461463
// We only want to set this flag when the display task has ended
462-
if load_msg_sender.is_none() {
464+
if loading_status_sender_clone.is_none() {
463465
has_new_stuff_clone.store(true, Ordering::Release);
464466
}
465467
},
@@ -499,16 +501,17 @@ impl ToolManagerBuilder {
499501
telemetry
500502
.send_mcp_server_init(conversation_id.clone(), Some(e.to_string()), 0)
501503
.ok();
502-
503-
let _ = tx.send(LoadingMsg::Error {
504-
name: name.clone(),
505-
msg: e,
506-
});
504+
if let Some(tx) = &loading_status_sender {
505+
let _ = tx
506+
.send(LoadingMsg::Error {
507+
name: name.clone(),
508+
msg: e,
509+
})
510+
.await;
511+
}
507512
},
508513
}
509514
}
510-
let loading_display_task = Some(loading_display_task);
511-
let loading_status_sender = Some(tx);
512515

513516
// Set up task to handle prompt requests
514517
let sender = self.prompt_list_sender.take();
@@ -678,7 +681,7 @@ pub struct ToolManager {
678681

679682
/// Channel sender for communicating with the loading display thread.
680683
/// Used to send status updates about tool initialization progress.
681-
loading_status_sender: Option<std::sync::mpsc::Sender<LoadingMsg>>,
684+
loading_status_sender: Option<tokio::sync::mpsc::Sender<LoadingMsg>>,
682685

683686
/// Mapping from sanitized tool names to original tool names.
684687
/// This is used to handle tool name transformations that may occur during initialization
@@ -764,13 +767,13 @@ impl ToolManager {
764767
_ = display_fut => {},
765768
_ = timeout_fut => {
766769
if let Some(tx) = tx {
767-
let _ = tx.send(LoadingMsg::Terminate);
770+
let _ = tx.send(LoadingMsg::Terminate).await;
768771
}
769772
},
770773
_ = ctrl_c() => {
771774
if self.is_interactive {
772775
if let Some(tx) = tx {
773-
let _ = tx.send(LoadingMsg::Terminate);
776+
let _ = tx.send(LoadingMsg::Terminate).await;
774777
}
775778
} else {
776779
return Err(eyre::eyre!("User interrupted mcp server loading in non-interactive mode. Ending."));

0 commit comments

Comments
 (0)