Skip to content

Commit c4526cf

Browse files
committed
uses a hash set of server names to keep track of the number of initialized servers
1 parent cc4b435 commit c4526cf

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

crates/chat-cli/src/cli/chat/tool_manager.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ use std::path::{
1818
use std::pin::Pin;
1919
use std::sync::atomic::{
2020
AtomicBool,
21-
AtomicU32,
2221
Ordering,
2322
};
2423
use std::sync::{
@@ -303,7 +302,6 @@ impl ToolManagerBuilder {
303302
loading_servers.insert(server_name.clone(), init_time);
304303
}
305304
let total = loading_servers.len();
306-
let completed = Arc::new(AtomicU32::new(0));
307305

308306
// Spawn a task for displaying the mcp loading statuses.
309307
// This is only necessary when we are in interactive mode AND there are servers to load.
@@ -408,11 +406,11 @@ impl ToolManagerBuilder {
408406
let telemetry_clone = telemetry.clone();
409407
let notify = Arc::new(Notify::new());
410408
let notify_weak = Arc::downgrade(&notify);
411-
let completed_clone = completed.clone();
412409
let load_record = Arc::new(Mutex::new(HashMap::<String, Vec<LoadingRecord>>::new()));
413410
let load_record_clone = load_record.clone();
414411
tokio::spawn(async move {
415412
let mut record_temp_buf = Vec::<u8>::new();
413+
let mut initialized = HashSet::<String>::new();
416414
while let Some(msg) = msg_rx.recv().await {
417415
record_temp_buf.clear();
418416
// For now we will treat every list result as if they contain the
@@ -498,7 +496,7 @@ impl ToolManagerBuilder {
498496
load_record_clone
499497
.lock()
500498
.await
501-
.entry(server_name)
499+
.entry(server_name.clone())
502500
.and_modify(|load_record| {
503501
load_record.push(record.clone());
504502
})
@@ -526,7 +524,7 @@ impl ToolManagerBuilder {
526524
// is called) are fatals and should be considered errors
527525
if let Some(sender) = &loading_status_sender_clone {
528526
let msg = LoadingMsg::Error {
529-
name: server_name,
527+
name: server_name.clone(),
530528
msg: e,
531529
time: time_taken,
532530
};
@@ -541,8 +539,8 @@ impl ToolManagerBuilder {
541539
},
542540
}
543541
if let Some(notify) = notify_weak.upgrade() {
544-
let completed = completed_clone.fetch_add(1, Ordering::AcqRel);
545-
if completed + 1 >= (total as u32) {
542+
initialized.insert(server_name);
543+
if initialized.len() >= total {
546544
notify.notify_one();
547545
}
548546
}
@@ -585,7 +583,6 @@ impl ToolManagerBuilder {
585583
.send_mcp_server_init(conversation_id.clone(), Some(e.to_string()), 0)
586584
.ok();
587585
let _ = messenger.send_tools_list_result(Err(e)).await;
588-
completed.fetch_add(1, Ordering::AcqRel);
589586
},
590587
}
591588
}
@@ -858,6 +855,7 @@ impl ToolManager {
858855
let still_loading = self.pending_clients.read().await.iter().cloned().collect::<Vec<_>>();
859856
let _ = tx.send(LoadingMsg::Terminate { still_loading }).await;
860857
}
858+
error!("## timeout: timed out");
861859
if !self.clients.is_empty() {
862860
let _ = queue!(
863861
output,
@@ -869,12 +867,14 @@ impl ToolManager {
869867
}
870868
},
871869
_ = server_loading_fut => {
870+
error!("## timeout: server load finish");
872871
if let Some(tx) = tx {
873872
let still_loading = self.pending_clients.read().await.iter().cloned().collect::<Vec<_>>();
874873
let _ = tx.send(LoadingMsg::Terminate { still_loading }).await;
875874
}
876875
}
877876
_ = ctrl_c() => {
877+
error!("## timeout: ctrl c");
878878
if self.is_interactive {
879879
if let Some(tx) = tx {
880880
let still_loading = self.pending_clients.read().await.iter().cloned().collect::<Vec<_>>();

0 commit comments

Comments
 (0)