3
3
This module handles connections to one or more MCP servers, including setup,
4
4
initialization, and communication.
5
5
"""
6
-
7
6
import os
8
7
import shutil
9
8
from contextlib import AsyncExitStack
10
- from typing import Dict , List , Any , Optional , Tuple
11
- from rich .console import Console
12
- from rich .panel import Panel
9
+ from typing import Any , Dict , List , Optional , Tuple
10
+
11
+ import mcp .types
12
+ import ollama
13
+ import rich .json
13
14
from mcp import ClientSession , Tool
14
- from mcp .client .stdio import stdio_client , StdioServerParameters
15
+ from mcp .client .session import LoggingFnT , SamplingFnT
15
16
from mcp .client .sse import sse_client
17
+ from mcp .client .stdio import StdioServerParameters , stdio_client
16
18
from mcp .client .streamable_http import streamablehttp_client
19
+ from mcp .shared .context import LifespanContextT , RequestContext
20
+ from rich .console import Console , Group
21
+ from rich .markdown import Markdown
22
+ from rich .panel import Panel
17
23
18
24
from .discovery import process_server_paths , parse_server_configs , auto_discover_servers
19
25
@@ -25,14 +31,16 @@ class ServerConnector:
25
31
tools provided by those servers.
26
32
"""
27
33
28
- def __init__ (self , exit_stack : AsyncExitStack , console : Optional [Console ] = None ):
34
+ def __init__ (self , exit_stack : AsyncExitStack , default_model : str , ollama_client : ollama . AsyncClient , console : Optional [Console ] = None ):
29
35
"""Initialize the ServerConnector.
30
36
31
37
Args:
32
38
exit_stack: AsyncExitStack to manage server connections
33
39
console: Rich console for output (optional)
34
40
"""
35
41
self .exit_stack = exit_stack
42
+ self .ollama = ollama_client
43
+ self .default_model = default_model
36
44
self .console = console or Console ()
37
45
self .sessions = {} # Dict to store multiple sessions
38
46
self .available_tools = [] # List to store all available tools
@@ -97,6 +105,45 @@ async def connect_to_servers(self, server_paths=None, config_path=None, auto_dis
97
105
98
106
return self .sessions , self .available_tools , self .enabled_tools
99
107
108
+ def create_log_callback (self , server_name : str ) -> LoggingFnT :
109
+ async def log_callback (params : mcp .types .LoggingMessageNotificationParams ) -> None :
110
+ self .console .log (f"[green]\[{ params .level .upper ()} ] - { server_name } :[/green]" , params .data )
111
+
112
+ return log_callback
113
+
114
+ def create_sampling_callback (self ) -> SamplingFnT :
115
+ async def _sampling_handler (
116
+ context : RequestContext [ClientSession , LifespanContextT ],
117
+ params : mcp .types .CreateMessageRequestParams ,
118
+ ) -> mcp .types .CreateMessageResult | mcp .types .ErrorData :
119
+ messages = [{"role" : "system" , "content" : params .systemPrompt }] + [{'role' : msg .role , 'content' : msg .content .text } for msg in params .messages ]
120
+ self .console .print (Panel (Group (* (Panel (Markdown (msg ["content" ]), title = msg ["role" ], ) for msg in messages )), title = "🧠 Handling sampling request..." , border_style = "cyan" , expand = False ))
121
+ try :
122
+ response = await self .ollama .chat (
123
+ self .default_model ,
124
+ messages ,
125
+ options = {
126
+ "temperature" : params .temperature ,
127
+ "num_predict" : params .maxTokens ,
128
+ "stop" : params .stopSequences ,
129
+ }
130
+ )
131
+ except Exception as e :
132
+ self .console .print_exception ()
133
+ return mcp .types .ErrorData (
134
+ code = mcp .types .INTERNAL_ERROR ,
135
+ message = str (e ),
136
+ )
137
+ else :
138
+ return mcp .CreateMessageResult (
139
+ role = "assistant" ,
140
+ model = "fastmcp-client" ,
141
+ content = mcp .types .TextContent (type = "text" , text = response .message .content ),
142
+ )
143
+
144
+ return _sampling_handler
145
+
146
+
100
147
async def _connect_to_server (self , server : Dict [str , Any ]) -> bool :
101
148
"""Connect to a single MCP server
102
149
@@ -126,7 +173,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
126
173
# Connect using SSE transport
127
174
sse_transport = await self .exit_stack .enter_async_context (sse_client (url , headers = headers ))
128
175
read_stream , write_stream = sse_transport
129
- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
176
+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
130
177
131
178
elif server_type == "streamable_http" :
132
179
# Connect to Streamable HTTP server
@@ -142,7 +189,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
142
189
streamablehttp_client (url , headers = headers )
143
190
)
144
191
read_stream , write_stream , session_info = transport
145
- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
192
+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
146
193
147
194
# Store session ID if provided
148
195
if hasattr (session_info , 'session_id' ) and session_info .session_id :
@@ -156,7 +203,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
156
203
157
204
stdio_transport = await self .exit_stack .enter_async_context (stdio_client (server_params ))
158
205
read_stream , write_stream = stdio_transport
159
- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
206
+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
160
207
161
208
else :
162
209
# Connect to config-based server using STDIO
@@ -166,7 +213,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
166
213
167
214
stdio_transport = await self .exit_stack .enter_async_context (stdio_client (server_params ))
168
215
read_stream , write_stream = stdio_transport
169
- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
216
+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
170
217
171
218
# Initialize the session
172
219
await session .initialize ()
0 commit comments