Skip to content

Commit d3fe3fd

Browse files
committed
fixes agent swap
fixes agent swap
1 parent 123decd commit d3fe3fd

File tree

4 files changed

+276
-91
lines changed

4 files changed

+276
-91
lines changed

crates/chat-cli/src/cli/chat/cli/prompts.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use crate::cli::chat::{
2525
ChatSession,
2626
ChatState,
2727
};
28+
use crate::mcp_client::McpClientError;
2829

2930
#[derive(Debug, Error)]
3031
pub enum GetPromptError {
@@ -45,7 +46,9 @@ pub enum GetPromptError {
4546
#[error("Missing channel")]
4647
MissingChannel,
4748
#[error(transparent)]
48-
ServiceError(#[from] rmcp::ServiceError),
49+
McpClient(#[from] McpClientError),
50+
#[error(transparent)]
51+
Service(#[from] rmcp::ServiceError),
4952
}
5053

5154
/// Command-line arguments for prompt operations

crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ mod line_tracker;
1010
mod parser;
1111
mod prompt;
1212
mod prompt_parser;
13-
mod server_messenger;
13+
pub mod server_messenger;
1414
#[cfg(unix)]
1515
mod skim_integration;
1616
mod token_counter;
@@ -2424,7 +2424,7 @@ impl ChatSession {
24242424
.set_tool_use_id(tool_use_id.clone())
24252425
.set_tool_name(tool_use.name.clone())
24262426
.utterance_id(self.conversation.message_id().map(|s| s.to_string()));
2427-
match self.conversation.tool_manager.get_tool_from_tool_use(tool_use) {
2427+
match self.conversation.tool_manager.get_tool_from_tool_use(tool_use).await {
24282428
Ok(mut tool) => {
24292429
// Apply non-Q-generated context to tools
24302430
self.contextualize_tool(&mut tool);

crates/chat-cli/src/cli/chat/tool_manager.rs

Lines changed: 65 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ use tokio::sync::{
4949
use tokio::task::JoinHandle;
5050
use tracing::{
5151
error,
52+
info,
5253
warn,
5354
};
5455

@@ -66,7 +67,6 @@ use crate::cli::chat::cli::prompts::GetPromptError;
6667
use crate::cli::chat::consts::DUMMY_TOOL_NAME;
6768
use crate::cli::chat::message::AssistantToolUse;
6869
use crate::cli::chat::server_messenger::{
69-
ServerMessenger,
7070
ServerMessengerBuilder,
7171
UpdateEventMessage,
7272
};
@@ -87,8 +87,8 @@ use crate::database::Database;
8787
use crate::database::settings::Setting;
8888
use crate::mcp_client::messenger::Messenger;
8989
use crate::mcp_client::{
90-
McpClient,
91-
RunningClient,
90+
InitializedMcpClient,
91+
UninitMcpClient,
9292
};
9393
use crate::os::Os;
9494
use crate::telemetry::TelemetryThread;
@@ -267,16 +267,6 @@ impl ToolManagerBuilder {
267267
.map(|(server_name, _)| server_name.clone())
268268
.collect();
269269

270-
let mut clients = HashMap::<String, RunningClient<ServerMessenger>>::new();
271-
let new_tool_specs = self.new_tool_specs;
272-
let has_new_stuff = self.has_new_stuff;
273-
let pending = Arc::new(RwLock::new(HashSet::<String>::new()));
274-
let notify = Arc::new(Notify::new());
275-
let load_record = self.mcp_load_record;
276-
let agent = self.agent.unwrap_or_default();
277-
let database = os.database.clone();
278-
let mut messenger_builder = self.messenger_builder.take();
279-
280270
let pre_initialized = enabled_servers
281271
.iter()
282272
.filter(|(server_name, _)| {
@@ -301,6 +291,20 @@ impl ToolManagerBuilder {
301291
})
302292
.collect::<Vec<_>>();
303293

294+
let mut clients = HashMap::<String, InitializedMcpClient>::new();
295+
let new_tool_specs = self.new_tool_specs;
296+
let has_new_stuff = self.has_new_stuff;
297+
let pending = Arc::new(RwLock::new({
298+
let mut pending = HashSet::<String>::new();
299+
pending.extend(pre_initialized.iter().map(|(name, _)| name.clone()));
300+
pending
301+
}));
302+
let notify = Arc::new(Notify::new());
303+
let load_record = self.mcp_load_record;
304+
let agent = self.agent.unwrap_or_default();
305+
let database = os.database.clone();
306+
let mut messenger_builder = self.messenger_builder.take();
307+
304308
let mut loading_servers = HashMap::<String, Instant>::new();
305309
for (server_name, _) in &pre_initialized {
306310
let init_time = std::time::Instant::now();
@@ -359,7 +363,7 @@ impl ToolManagerBuilder {
359363
.map(|(server_name, server_config)| {
360364
(
361365
server_name.clone(),
362-
McpClient::new(
366+
UninitMcpClient::new(
363367
server_name.clone(),
364368
server_config,
365369
messenger_builder.build_with_name(server_name),
@@ -519,7 +523,7 @@ pub struct ToolManager {
519523

520524
/// Map of server names to their corresponding client instances.
521525
/// These clients are used to communicate with MCP servers.
522-
pub clients: HashMap<String, RunningClient<ServerMessenger>>,
526+
pub clients: HashMap<String, InitializedMcpClient>,
523527

524528
/// A list of client names that are still in the process of being initialized
525529
pub pending_clients: Arc<RwLock<HashSet<String>>>,
@@ -612,7 +616,32 @@ impl ToolManager {
612616
/// function)
613617
/// - Calling load tools
614618
pub async fn swap_agent(&mut self, os: &mut Os, output: &mut impl Write, agent: &Agent) -> eyre::Result<()> {
615-
self.clients.clear();
619+
let to_evict = self.clients.drain().collect::<Vec<_>>();
620+
tokio::spawn(async move {
621+
for (server_name, initialized_client) in to_evict {
622+
info!("Evicting {server_name} due to agent swap");
623+
match initialized_client {
624+
InitializedMcpClient::Pending(handle) => {
625+
let server_name_clone = server_name.clone();
626+
tokio::spawn(async move {
627+
match handle.await {
628+
Ok(Ok(client)) => match client.cancel().await {
629+
Ok(_) => info!("Server {server_name_clone} evicted due to agent swap"),
630+
Err(e) => error!("Server {server_name_clone} has failed to cancel: {e}"),
631+
},
632+
Ok(Err(_)) | Err(_) => {
633+
error!("Server {server_name_clone} has failed to cancel");
634+
},
635+
}
636+
});
637+
},
638+
InitializedMcpClient::Ready(running_service) => match running_service.cancel().await {
639+
Ok(_) => info!("Server {server_name} evicted due to agent swap"),
640+
Err(e) => error!("Server {server_name} has failed to cancel: {e}"),
641+
},
642+
}
643+
}
644+
});
616645

617646
let mut agent_lock = self.agent.lock().await;
618647
*agent_lock = agent.clone();
@@ -624,9 +653,7 @@ impl ToolManager {
624653
let mut new_tool_manager = builder.build(os, Box::new(std::io::sink()), true).await?;
625654
std::mem::swap(self, &mut new_tool_manager);
626655

627-
// we can discard the output here and let background server load take care of getting the
628-
// new tools
629-
let _ = self.load_tools(os, output).await?;
656+
self.load_tools(os, output).await?;
630657

631658
Ok(())
632659
}
@@ -778,7 +805,7 @@ impl ToolManager {
778805
Ok(self.schema.clone())
779806
}
780807

781-
pub fn get_tool_from_tool_use(&self, value: AssistantToolUse) -> Result<Tool, ToolResult> {
808+
pub async fn get_tool_from_tool_use(&mut self, value: AssistantToolUse) -> Result<Tool, ToolResult> {
782809
let map_err = |parse_error| ToolResult {
783810
tool_use_id: value.id.clone(),
784811
content: vec![ToolResultContentBlock::Text(format!(
@@ -822,7 +849,7 @@ impl ToolManager {
822849
})
823850
},
824851
}?;
825-
let Some(client) = self.clients.get(server_name) else {
852+
let Some(client) = self.clients.get_mut(server_name) else {
826853
return Err(ToolResult {
827854
tool_use_id: value.id,
828855
content: vec![ToolResultContentBlock::Text(format!(
@@ -832,13 +859,19 @@ impl ToolManager {
832859
});
833860
};
834861

835-
let custom_tool = CustomTool {
862+
let running_service = (*client.get_running_service().await.map_err(|e| ToolResult {
863+
tool_use_id: value.id.clone(),
864+
content: vec![ToolResultContentBlock::Text(format!("Mcp tool client not ready: {e}"))],
865+
status: ToolResultStatus::Error,
866+
})?)
867+
.clone();
868+
869+
Tool::Custom(CustomTool {
836870
name: tool_name.to_owned(),
837-
server_name: server_name.clone(),
838-
client: (*client).clone(),
871+
server_name: server_name.to_owned(),
872+
client: running_service,
839873
params: value.args.as_object().cloned(),
840-
};
841-
Tool::Custom(custom_tool)
874+
})
842875
},
843876
})
844877
}
@@ -934,7 +967,7 @@ impl ToolManager {
934967
}
935968

936969
pub async fn get_prompt(
937-
&self,
970+
&mut self,
938971
name: String,
939972
arguments: Option<Vec<String>>,
940973
) -> Result<GetPromptResult, GetPromptError> {
@@ -996,7 +1029,7 @@ impl ToolManager {
9961029
};
9971030

9981031
let server_name = &bundle.server_name;
999-
let client = self.clients.get(server_name).ok_or(GetPromptError::MissingClient)?;
1032+
let client = self.clients.get_mut(server_name).ok_or(GetPromptError::MissingClient)?;
10001033
let PromptBundle { prompt_get, .. } = bundle;
10011034
let arguments = if let (Some(schema), Some(value)) = (&prompt_get.arguments, &arguments) {
10021035
let params = schema.iter().zip(value.iter()).fold(
@@ -1015,8 +1048,11 @@ impl ToolManager {
10151048
} else {
10161049
None
10171050
};
1051+
10181052
let params = GetPromptRequestParam { name, arguments };
1019-
let resp = client.get_prompt(params).await?;
1053+
let running_service = client.get_running_service().await?;
1054+
let resp = running_service.get_prompt(params).await?;
1055+
10201056
Ok(resp)
10211057
},
10221058
(None, _) => Err(GetPromptError::PromptNotFound(prompt_name)),

0 commit comments

Comments
 (0)