3
3
This module handles connections to one or more MCP servers, including setup,
4
4
initialization, and communication.
5
5
"""
6
-
7
6
import os
8
7
import shutil
9
8
from contextlib import AsyncExitStack
10
- from typing import Dict , List , Any , Optional , Tuple
11
- from rich .console import Console
12
- from rich .panel import Panel
9
+ from typing import Any , Dict , List , Optional , Tuple
10
+
11
+ import mcp .types
12
+ import ollama
13
+ import rich .json
13
14
from mcp import ClientSession , Tool
14
- from mcp .client .stdio import stdio_client , StdioServerParameters
15
+ from mcp .client .session import LoggingFnT , SamplingFnT
15
16
from mcp .client .sse import sse_client
17
+ from mcp .client .stdio import StdioServerParameters , stdio_client
16
18
from mcp .client .streamable_http import streamablehttp_client
19
+ from mcp .shared .context import LifespanContextT , RequestContext
20
+ from rich .console import Console , Group
21
+ from rich .markdown import Markdown
22
+ from rich .panel import Panel
17
23
18
24
from .discovery import process_server_paths , parse_server_configs , auto_discover_servers
19
25
@@ -25,14 +31,16 @@ class ServerConnector:
25
31
tools provided by those servers.
26
32
"""
27
33
28
- def __init__ (self , exit_stack : AsyncExitStack , console : Optional [Console ] = None ):
34
+ def __init__ (self , exit_stack : AsyncExitStack , default_model : str , ollama_client : ollama . AsyncClient , console : Optional [Console ] = None ):
29
35
"""Initialize the ServerConnector.
30
36
31
37
Args:
32
38
exit_stack: AsyncExitStack to manage server connections
33
39
console: Rich console for output (optional)
34
40
"""
35
41
self .exit_stack = exit_stack
42
+ self .ollama = ollama_client
43
+ self .default_model = default_model
36
44
self .console = console or Console ()
37
45
self .sessions = {} # Dict to store multiple sessions
38
46
self .available_tools = [] # List to store all available tools
@@ -97,6 +105,54 @@ async def connect_to_servers(self, server_paths=None, config_path=None, auto_dis
97
105
98
106
return self .sessions , self .available_tools , self .enabled_tools
99
107
108
+ def create_log_callback (self , server_name : str ) -> LoggingFnT :
109
+ async def log_callback (params : mcp .types .LoggingMessageNotificationParams ) -> None :
110
+ self .console .log (f"[green]\[{ params .level .upper ()} ] - { server_name } :[/green]" , params .data )
111
+
112
+ return log_callback
113
+
114
+ def create_sampling_callback (self ) -> SamplingFnT :
115
+ async def _sampling_handler (
116
+ context : RequestContext [ClientSession , LifespanContextT ],
117
+ params : mcp .types .CreateMessageRequestParams ,
118
+ ) -> mcp .types .CreateMessageResult | mcp .types .ErrorData :
119
+ messages = [
120
+ {"role" : "system" , "content" : params .systemPrompt }
121
+ ] + [
122
+ {'role' : msg .role , 'content' : msg .content .text } for msg in params .messages
123
+ ]
124
+ self .console .print (Panel (
125
+ Group (* (Panel (Markdown (msg ["content" ]), title = msg ["role" ], ) for msg in messages )),
126
+ title = "🧠 Handling sampling request..." ,
127
+ border_style = "cyan" ,
128
+ expand = False
129
+ ))
130
+ try :
131
+ response = await self .ollama .chat (
132
+ self .default_model ,
133
+ messages ,
134
+ options = {
135
+ "temperature" : params .temperature ,
136
+ "num_predict" : params .maxTokens ,
137
+ "stop" : params .stopSequences ,
138
+ }
139
+ )
140
+ except Exception as e :
141
+ self .console .print_exception ()
142
+ return mcp .types .ErrorData (
143
+ code = mcp .types .INTERNAL_ERROR ,
144
+ message = str (e ),
145
+ )
146
+ else :
147
+ return mcp .CreateMessageResult (
148
+ role = "assistant" ,
149
+ model = "fastmcp-client" ,
150
+ content = mcp .types .TextContent (type = "text" , text = response .message .content ),
151
+ )
152
+
153
+ return _sampling_handler
154
+
155
+
100
156
async def _connect_to_server (self , server : Dict [str , Any ]) -> bool :
101
157
"""Connect to a single MCP server
102
158
@@ -126,7 +182,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
126
182
# Connect using SSE transport
127
183
sse_transport = await self .exit_stack .enter_async_context (sse_client (url , headers = headers ))
128
184
read_stream , write_stream = sse_transport
129
- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
185
+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
130
186
131
187
elif server_type == "streamable_http" :
132
188
# Connect to Streamable HTTP server
@@ -142,7 +198,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
142
198
streamablehttp_client (url , headers = headers )
143
199
)
144
200
read_stream , write_stream , session_info = transport
145
- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
201
+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
146
202
147
203
# Store session ID if provided
148
204
if hasattr (session_info , 'session_id' ) and session_info .session_id :
@@ -156,7 +212,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
156
212
157
213
stdio_transport = await self .exit_stack .enter_async_context (stdio_client (server_params ))
158
214
read_stream , write_stream = stdio_transport
159
- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
215
+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
160
216
161
217
else :
162
218
# Connect to config-based server using STDIO
@@ -166,7 +222,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
166
222
167
223
stdio_transport = await self .exit_stack .enter_async_context (stdio_client (server_params ))
168
224
read_stream , write_stream = stdio_transport
169
- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
225
+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
170
226
171
227
# Initialize the session
172
228
await session .initialize ()
0 commit comments