Skip to content

Commit a00d22a

Browse files
committed
updates tools info per try chat loop
1 parent 0808038 commit a00d22a

File tree

5 files changed

+63
-6
lines changed

5 files changed

+63
-6
lines changed

crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use std::process::{
3333
ExitCode,
3434
};
3535
use std::sync::Arc;
36+
use std::sync::atomic::Ordering;
3637
use std::time::Duration;
3738
use std::{
3839
env,
@@ -795,6 +796,16 @@ impl ChatContext {
795796
let ctrl_c_stream = ctrl_c();
796797
debug!(?chat_state, "changing to state");
797798

799+
// Update conversation state with new tool information
800+
if self
801+
.conversation_state
802+
.tool_manager
803+
.has_new_stuff
804+
.load(Ordering::Relaxed)
805+
{
806+
self.conversation_state.update_state().await;
807+
}
808+
798809
let result = match chat_state {
799810
ChatState::PromptUser {
800811
tool_uses,

crates/chat-cli/src/cli/chat/server_messenger.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ pub enum UpdateEventMessage {
3232
server_name: String,
3333
result: ResourceTemplatesListResult,
3434
},
35-
DisplayTaskEnded,
35+
InitStart {
36+
server_name: String,
37+
},
3638
}
3739

3840
#[derive(Clone, Debug)]
@@ -112,6 +114,16 @@ impl Messenger for ServerMessenger {
112114
.map_err(|e| MessengerError::Custom(e.to_string()))?)
113115
}
114116

117+
async fn send_init_msg(&self) -> Result<(), MessengerError> {
118+
Ok(self
119+
.update_event_sender
120+
.send(UpdateEventMessage::InitStart {
121+
server_name: self.server_name.clone(),
122+
})
123+
.await
124+
.map_err(|e| MessengerError::Custom(e.to_string()))?)
125+
}
126+
115127
fn duplicate(&self) -> Box<dyn Messenger> {
116128
Box::new(self.clone())
117129
}

crates/chat-cli/src/cli/chat/tool_manager.rs

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ use serde::{
4040
};
4141
use thiserror::Error;
4242
use tokio::signal::ctrl_c;
43-
use tokio::sync::Mutex;
43+
use tokio::sync::{
44+
Mutex,
45+
RwLock,
46+
};
4447
use tracing::{
4548
error,
4649
warn,
@@ -406,6 +409,8 @@ impl ToolManagerBuilder {
406409
let new_tool_specs_clone = new_tool_specs.clone();
407410
let has_new_stuff = Arc::new(AtomicBool::new(false));
408411
let has_new_stuff_clone = has_new_stuff.clone();
412+
let pending = Arc::new(RwLock::new(HashSet::<String>::new()));
413+
let pending_clone = pending.clone();
409414
let (mut msg_rx, messenger_builder) = ServerMessengerBuilder::new(20);
410415
tokio::spawn(async move {
411416
while let Some(msg) = msg_rx.recv().await {
@@ -415,6 +420,7 @@ impl ToolManagerBuilder {
415420
// list calls.
416421
match msg {
417422
UpdateEventMessage::ToolsListResult { server_name, result } => {
423+
pending_clone.write().await.remove(&server_name);
418424
let mut specs = result
419425
.tools
420426
.into_iter()
@@ -464,14 +470,16 @@ impl ToolManagerBuilder {
464470
server_name: _,
465471
result: _,
466472
} => {},
467-
UpdateEventMessage::DisplayTaskEnded => {
468-
load_msg_sender.take();
473+
UpdateEventMessage::InitStart { server_name } => {
474+
pending_clone.write().await.insert(server_name.clone());
475+
if let Some(sender) = &load_msg_sender {
476+
let _ = sender.send(LoadingMsg::Add(server_name));
477+
}
469478
},
470479
}
471480
}
472481
});
473482
for (mut name, init_res) in pre_initialized {
474-
let _ = tx.send(LoadingMsg::Add(name.clone()));
475483
match init_res {
476484
Ok(mut client) => {
477485
let messenger = messenger_builder.build_with_name(client.get_server_name().to_owned());
@@ -592,6 +600,7 @@ impl ToolManagerBuilder {
592600
clients,
593601
prompts,
594602
loading_display_task,
603+
pending_clients: pending,
595604
loading_status_sender,
596605
new_tool_specs,
597606
has_new_stuff,
@@ -638,8 +647,19 @@ pub struct ToolManager {
638647
/// These clients are used to communicate with MCP servers.
639648
pub clients: HashMap<String, Arc<CustomToolClient>>,
640649

650+
#[allow(dead_code)]
651+
/// A list of client names that are still in the process of being initialized
652+
pub pending_clients: Arc<RwLock<HashSet<String>>>,
653+
654+
/// Flag indicating whether new tool specifications have been added since the last update.
655+
/// When set to true, it signals that the tool manager needs to refresh its internal state
656+
/// to incorporate newly available tools from MCP servers.
641657
pub has_new_stuff: Arc<AtomicBool>,
642658

659+
/// Storage for newly discovered tool specifications from MCP servers that haven't yet been
660+
/// integrated into the main tool registry. This field holds a thread-safe reference to a map
661+
/// of server names to their tool specifications and name mappings, allowing concurrent updates
662+
/// from server initialization processes.
643663
new_tool_specs: NewToolSpecs,
644664

645665
/// Cache for prompts collected from different servers.
@@ -1059,6 +1079,10 @@ impl ToolManager {
10591079
);
10601080
Ok(())
10611081
}
1082+
1083+
pub async fn _pending_clients(&self) -> Vec<String> {
1084+
self.pending_clients.read().await.iter().cloned().collect::<Vec<_>>()
1085+
}
10621086
}
10631087

10641088
#[inline]
@@ -1150,7 +1174,7 @@ fn process_tool_specs(
11501174
server_name, msg
11511175
);
11521176
if is_in_display {
1153-
Some(LoadingMsg::Error {
1177+
Some(LoadingMsg::Warn {
11541178
name: server_name.to_string(),
11551179
msg: eyre::eyre!(msg),
11561180
})

crates/chat-cli/src/cli/chat/tools/custom_tool.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ impl CustomToolClient {
9292
server_capabilities,
9393
..
9494
} => {
95+
if let Some(messenger) = &client.messenger {
96+
let _ = messenger.send_init_msg().await;
97+
}
9598
// We'll need to first initialize. This is the handshake every client and server
9699
// needs to do before proceeding to anything else
97100
let cap = client.init().await?;

crates/chat-cli/src/mcp_client/messenger.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ pub trait Messenger: std::fmt::Debug + Send + Sync + 'static {
3333
result: ResourceTemplatesListResult,
3434
) -> Result<(), MessengerError>;
3535

36+
/// Signals to the orchestrator that a server has started initializing
37+
async fn send_init_msg(&self) -> Result<(), MessengerError>;
38+
3639
/// Creates a duplicate of the messenger object
3740
/// This function is used to create a new instance of the messenger with the same configuration
3841
fn duplicate(&self) -> Box<dyn Messenger>;
@@ -68,6 +71,10 @@ impl Messenger for NullMessenger {
6871
Ok(())
6972
}
7073

74+
async fn send_init_msg(&self) -> Result<(), MessengerError> {
75+
Ok(())
76+
}
77+
7178
fn duplicate(&self) -> Box<dyn Messenger> {
7279
Box::new(NullMessenger)
7380
}

0 commit comments

Comments
 (0)