88)
99from claude_agent_sdk .types import (
1010 AssistantMessage ,
11+ Message ,
1112 SystemMessage ,
1213 TextBlock ,
1314 ToolUseBlock ,
2324 get_sdk_config ,
2425)
2526from agent_chat_cli .utils .enums import AgentMessageType , ContentType , ControlCommand
26- from agent_chat_cli .core .mcp_inference import infer_mcp_servers
2727from agent_chat_cli .utils .logger import log_json
28+ from agent_chat_cli .utils .mcp_server_status import MCPServerStatus
2829
2930if TYPE_CHECKING :
3031 from agent_chat_cli .app import AgentChatCLIApp
@@ -46,7 +47,6 @@ def __init__(
4647 self .config = load_config ()
4748 self .session_id = session_id
4849 self .available_servers = get_available_servers ()
49- self .inferred_servers : set [str ] = set ()
5050
5151 self .client : ClaudeSDKClient
5252
@@ -58,78 +58,33 @@ def __init__(
5858 self .interrupting = False
5959
6060 async def start (self ) -> None :
61- # Boot MCP servers lazily
62- if self .config .mcp_server_inference :
63- await self ._initialize_client (mcp_servers = {})
64- else :
65- # Boot MCP servers all at once
66- mcp_servers = {
67- name : config .model_dump ()
68- for name , config in self .available_servers .items ()
69- }
70-
71- await self ._initialize_client (mcp_servers = mcp_servers )
61+ mcp_servers = {
62+ name : config .model_dump () for name , config in self .available_servers .items ()
63+ }
64+
65+ await self ._initialize_client (mcp_servers = mcp_servers )
7266
7367 self ._running = True
7468
7569 while self ._running :
7670 user_input = await self .query_queue .get ()
7771
78- # Check for new convo flags
7972 if isinstance (user_input , ControlCommand ):
8073 if user_input == ControlCommand .NEW_CONVERSATION :
81- self .inferred_servers .clear ()
82-
83- await self .client .disconnect ()
84-
85- # Reset MCP servers based on config settings
86- if self .config .mcp_server_inference :
87- await self ._initialize_client (mcp_servers = {})
88- else :
89- mcp_servers = {
90- name : config .model_dump ()
91- for name , config in self .available_servers .items ()
92- }
93-
94- await self ._initialize_client (mcp_servers = mcp_servers )
95- continue
96-
97- # Infer MCP servers based on user messages in chat
98- if self .config .mcp_server_inference :
99- inference_result = await infer_mcp_servers (
100- user_message = user_input ,
101- available_servers = self .available_servers ,
102- inferred_servers = self .inferred_servers ,
103- session_id = self .session_id ,
104- )
105-
106- # If there are new results, create an updated mcp_server list
107- if inference_result ["new_servers" ]:
108- server_list = ", " .join (inference_result ["new_servers" ])
109-
110- self .app .actions .post_system_message (
111- f"Connecting to { server_list } ..."
112- )
113-
114- await asyncio .sleep (0.1 )
115-
116- # If there's updates, we reinitialize the agent SDK (with the
117- # persisted session_id from the turn, stored in the instance)
11874 await self .client .disconnect ()
11975
12076 mcp_servers = {
12177 name : config .model_dump ()
122- for name , config in inference_result [ "selected_servers" ] .items ()
78+ for name , config in self . available_servers .items ()
12379 }
12480
12581 await self ._initialize_client (mcp_servers = mcp_servers )
82+ continue
12683
12784 self .interrupting = False
12885
129- # Send query
13086 await self .client .query (user_input )
13187
132- # Wait for messages from Claude
13388 async for message in self .client .receive_response ():
13489 if self .interrupting :
13590 continue
@@ -154,7 +109,7 @@ async def _initialize_client(self, mcp_servers: dict) -> None:
154109
155110 await self .client .connect ()
156111
157- async def _handle_message (self , message : Any ) -> None :
112+ async def _handle_message (self , message : Message ) -> None :
158113 if isinstance (message , SystemMessage ):
159114 log_json (message .data )
160115
@@ -164,6 +119,9 @@ async def _handle_message(self, message: Any) -> None:
164119 # When initializing the chat, we store the session_id for later
165120 self .session_id = message .data ["session_id" ]
166121
122+ # Report connected / error status back to UI
123+ MCPServerStatus .update (message .data ["mcp_servers" ])
124+
167125 # Handle streaming messages
168126 if hasattr (message , "event" ):
169127 event = message .event # type: ignore[attr-defined]
@@ -215,7 +173,7 @@ async def _can_use_tool(
215173 self ,
216174 tool_name : str ,
217175 tool_input : dict [str , Any ],
218- context : ToolPermissionContext ,
176+ _context : ToolPermissionContext ,
219177 ) -> PermissionResult :
220178 """Agent SDK handler for tool use permissions"""
221179
0 commit comments