@@ -49,6 +49,7 @@ use tokio::sync::{
4949use tokio:: task:: JoinHandle ;
5050use tracing:: {
5151 error,
52+ info,
5253 warn,
5354} ;
5455
@@ -66,7 +67,6 @@ use crate::cli::chat::cli::prompts::GetPromptError;
6667use crate :: cli:: chat:: consts:: DUMMY_TOOL_NAME ;
6768use crate :: cli:: chat:: message:: AssistantToolUse ;
6869use crate :: cli:: chat:: server_messenger:: {
69- ServerMessenger ,
7070 ServerMessengerBuilder ,
7171 UpdateEventMessage ,
7272} ;
@@ -87,8 +87,8 @@ use crate::database::Database;
8787use crate :: database:: settings:: Setting ;
8888use crate :: mcp_client:: messenger:: Messenger ;
8989use crate :: mcp_client:: {
90- McpClient ,
91- RunningClient ,
90+ InitializedMcpClient ,
91+ UninitMcpClient ,
9292} ;
9393use crate :: os:: Os ;
9494use crate :: telemetry:: TelemetryThread ;
@@ -267,16 +267,6 @@ impl ToolManagerBuilder {
267267 . map ( |( server_name, _) | server_name. clone ( ) )
268268 . collect ( ) ;
269269
270- let mut clients = HashMap :: < String , RunningClient < ServerMessenger > > :: new ( ) ;
271- let new_tool_specs = self . new_tool_specs ;
272- let has_new_stuff = self . has_new_stuff ;
273- let pending = Arc :: new ( RwLock :: new ( HashSet :: < String > :: new ( ) ) ) ;
274- let notify = Arc :: new ( Notify :: new ( ) ) ;
275- let load_record = self . mcp_load_record ;
276- let agent = self . agent . unwrap_or_default ( ) ;
277- let database = os. database . clone ( ) ;
278- let mut messenger_builder = self . messenger_builder . take ( ) ;
279-
280270 let pre_initialized = enabled_servers
281271 . iter ( )
282272 . filter ( |( server_name, _) | {
@@ -301,6 +291,20 @@ impl ToolManagerBuilder {
301291 } )
302292 . collect :: < Vec < _ > > ( ) ;
303293
294+ let mut clients = HashMap :: < String , InitializedMcpClient > :: new ( ) ;
295+ let new_tool_specs = self . new_tool_specs ;
296+ let has_new_stuff = self . has_new_stuff ;
297+ let pending = Arc :: new ( RwLock :: new ( {
298+ let mut pending = HashSet :: < String > :: new ( ) ;
299+ pending. extend ( pre_initialized. iter ( ) . map ( |( name, _) | name. clone ( ) ) ) ;
300+ pending
301+ } ) ) ;
302+ let notify = Arc :: new ( Notify :: new ( ) ) ;
303+ let load_record = self . mcp_load_record ;
304+ let agent = self . agent . unwrap_or_default ( ) ;
305+ let database = os. database . clone ( ) ;
306+ let mut messenger_builder = self . messenger_builder . take ( ) ;
307+
304308 let mut loading_servers = HashMap :: < String , Instant > :: new ( ) ;
305309 for ( server_name, _) in & pre_initialized {
306310 let init_time = std:: time:: Instant :: now ( ) ;
@@ -359,7 +363,7 @@ impl ToolManagerBuilder {
359363 . map ( |( server_name, server_config) | {
360364 (
361365 server_name. clone ( ) ,
362- McpClient :: new (
366+ UninitMcpClient :: new (
363367 server_name. clone ( ) ,
364368 server_config,
365369 messenger_builder. build_with_name ( server_name) ,
@@ -519,7 +523,7 @@ pub struct ToolManager {
519523
520524 /// Map of server names to their corresponding client instances.
521525 /// These clients are used to communicate with MCP servers.
522- pub clients : HashMap < String , RunningClient < ServerMessenger > > ,
526+ pub clients : HashMap < String , InitializedMcpClient > ,
523527
524528 /// A list of client names that are still in the process of being initialized
525529 pub pending_clients : Arc < RwLock < HashSet < String > > > ,
@@ -612,7 +616,32 @@ impl ToolManager {
612616 /// function)
613617 /// - Calling load tools
614618 pub async fn swap_agent ( & mut self , os : & mut Os , output : & mut impl Write , agent : & Agent ) -> eyre:: Result < ( ) > {
615- self . clients . clear ( ) ;
619+ let to_evict = self . clients . drain ( ) . collect :: < Vec < _ > > ( ) ;
620+ tokio:: spawn ( async move {
621+ for ( server_name, initialized_client) in to_evict {
622+ info ! ( "Evicting {server_name} due to agent swap" ) ;
623+ match initialized_client {
624+ InitializedMcpClient :: Pending ( handle) => {
625+ let server_name_clone = server_name. clone ( ) ;
626+ tokio:: spawn ( async move {
627+ match handle. await {
628+ Ok ( Ok ( client) ) => match client. cancel ( ) . await {
629+ Ok ( _) => info ! ( "Server {server_name_clone} evicted due to agent swap" ) ,
630+ Err ( e) => error ! ( "Server {server_name_clone} has failed to cancel: {e}" ) ,
631+ } ,
632+ Ok ( Err ( _) ) | Err ( _) => {
633+ error ! ( "Server {server_name_clone} has failed to cancel" ) ;
634+ } ,
635+ }
636+ } ) ;
637+ } ,
638+ InitializedMcpClient :: Ready ( running_service) => match running_service. cancel ( ) . await {
639+ Ok ( _) => info ! ( "Server {server_name} evicted due to agent swap" ) ,
640+ Err ( e) => error ! ( "Server {server_name} has failed to cancel: {e}" ) ,
641+ } ,
642+ }
643+ }
644+ } ) ;
616645
617646 let mut agent_lock = self . agent . lock ( ) . await ;
618647 * agent_lock = agent. clone ( ) ;
@@ -624,9 +653,7 @@ impl ToolManager {
624653 let mut new_tool_manager = builder. build ( os, Box :: new ( std:: io:: sink ( ) ) , true ) . await ?;
625654 std:: mem:: swap ( self , & mut new_tool_manager) ;
626655
627- // we can discard the output here and let background server load take care of getting the
628- // new tools
629- let _ = self . load_tools ( os, output) . await ?;
656+ self . load_tools ( os, output) . await ?;
630657
631658 Ok ( ( ) )
632659 }
@@ -778,7 +805,7 @@ impl ToolManager {
778805 Ok ( self . schema . clone ( ) )
779806 }
780807
781- pub fn get_tool_from_tool_use ( & self , value : AssistantToolUse ) -> Result < Tool , ToolResult > {
808+ pub async fn get_tool_from_tool_use ( & mut self , value : AssistantToolUse ) -> Result < Tool , ToolResult > {
782809 let map_err = |parse_error| ToolResult {
783810 tool_use_id : value. id . clone ( ) ,
784811 content : vec ! [ ToolResultContentBlock :: Text ( format!(
@@ -822,7 +849,7 @@ impl ToolManager {
822849 } )
823850 } ,
824851 } ?;
825- let Some ( client) = self . clients . get ( server_name) else {
852+ let Some ( client) = self . clients . get_mut ( server_name) else {
826853 return Err ( ToolResult {
827854 tool_use_id : value. id ,
828855 content : vec ! [ ToolResultContentBlock :: Text ( format!(
@@ -832,13 +859,19 @@ impl ToolManager {
832859 } ) ;
833860 } ;
834861
835- let custom_tool = CustomTool {
862+ let running_service = ( * client. get_running_service ( ) . await . map_err ( |e| ToolResult {
863+ tool_use_id : value. id . clone ( ) ,
864+ content : vec ! [ ToolResultContentBlock :: Text ( format!( "Mcp tool client not ready: {e}" ) ) ] ,
865+ status : ToolResultStatus :: Error ,
866+ } ) ?)
867+ . clone ( ) ;
868+
869+ Tool :: Custom ( CustomTool {
836870 name : tool_name. to_owned ( ) ,
837- server_name : server_name. clone ( ) ,
838- client : ( * client ) . clone ( ) ,
871+ server_name : server_name. to_owned ( ) ,
872+ client : running_service ,
839873 params : value. args . as_object ( ) . cloned ( ) ,
840- } ;
841- Tool :: Custom ( custom_tool)
874+ } )
842875 } ,
843876 } )
844877 }
@@ -934,7 +967,7 @@ impl ToolManager {
934967 }
935968
936969 pub async fn get_prompt (
937- & self ,
970+ & mut self ,
938971 name : String ,
939972 arguments : Option < Vec < String > > ,
940973 ) -> Result < GetPromptResult , GetPromptError > {
@@ -996,7 +1029,7 @@ impl ToolManager {
9961029 } ;
9971030
9981031 let server_name = & bundle. server_name ;
999- let client = self . clients . get ( server_name) . ok_or ( GetPromptError :: MissingClient ) ?;
1032+ let client = self . clients . get_mut ( server_name) . ok_or ( GetPromptError :: MissingClient ) ?;
10001033 let PromptBundle { prompt_get, .. } = bundle;
10011034 let arguments = if let ( Some ( schema) , Some ( value) ) = ( & prompt_get. arguments , & arguments) {
10021035 let params = schema. iter ( ) . zip ( value. iter ( ) ) . fold (
@@ -1015,8 +1048,11 @@ impl ToolManager {
10151048 } else {
10161049 None
10171050 } ;
1051+
10181052 let params = GetPromptRequestParam { name, arguments } ;
1019- let resp = client. get_prompt ( params) . await ?;
1053+ let running_service = client. get_running_service ( ) . await ?;
1054+ let resp = running_service. get_prompt ( params) . await ?;
1055+
10201056 Ok ( resp)
10211057 } ,
10221058 ( None , _) => Err ( GetPromptError :: PromptNotFound ( prompt_name) ) ,
0 commit comments