7
7
import os
8
8
import shutil
9
9
from contextlib import AsyncExitStack
10
- from typing import Dict , List , Any , Optional , Tuple
11
- from rich .console import Console
12
- from rich .panel import Panel
10
+ from typing import Any , Dict , List , Optional , Tuple
11
+
12
+ import mcp .types
13
+ import ollama
13
14
from mcp import ClientSession , Tool
14
- from mcp .client .stdio import stdio_client , StdioServerParameters
15
+ from mcp .client .session import LoggingFnT , SamplingFnT
15
16
from mcp .client .sse import sse_client
17
+ from mcp .client .stdio import StdioServerParameters , stdio_client
16
18
from mcp .client .streamable_http import streamablehttp_client
19
+ from mcp .shared .context import LifespanContextT , RequestContext
20
+ from rich .console import Console
21
+ from rich .panel import Panel
17
22
18
23
from .discovery import process_server_paths , parse_server_configs , auto_discover_servers
19
24
@@ -25,14 +30,16 @@ class ServerConnector:
25
30
tools provided by those servers.
26
31
"""
27
32
28
- def __init__ (self , exit_stack : AsyncExitStack , console : Optional [Console ] = None ):
33
+ def __init__ (self , exit_stack : AsyncExitStack , default_model : str , ollama_client : ollama . AsyncClient , console : Optional [Console ] = None ):
29
34
"""Initialize the ServerConnector.
30
35
31
36
Args:
32
37
exit_stack: AsyncExitStack to manage server connections
33
38
console: Rich console for output (optional)
34
39
"""
35
40
self .exit_stack = exit_stack
41
+ self .ollama = ollama_client
42
+ self .default_model = default_model
36
43
self .console = console or Console ()
37
44
self .sessions = {} # Dict to store multiple sessions
38
45
self .available_tools = [] # List to store all available tools
@@ -97,6 +104,40 @@ async def connect_to_servers(self, server_paths=None, config_path=None, auto_dis
97
104
98
105
return self .sessions , self .available_tools , self .enabled_tools
99
106
107
+ def create_log_callback (self , server_name : str ) -> LoggingFnT :
108
+ async def log_callback (params : mcp .types .LoggingMessageNotificationParams ) -> None :
109
+ self .console .log (f"[green]\[{ params .level .upper ()} ] - { server_name } :[/green]" , params .data )
110
+
111
+ return log_callback
112
+
113
+ def create_sampling_callback (self ) -> SamplingFnT :
114
+ async def _sampling_handler (
115
+ context : RequestContext [ClientSession , LifespanContextT ],
116
+ params : mcp .types .CreateMessageRequestParams ,
117
+ ) -> mcp .types .CreateMessageResult | mcp .types .ErrorData :
118
+ self .console .print ("[cyan]Handling sampling request...[/cyan]" )
119
+ try :
120
+ response = await self .ollama .chat (self .default_model , [{'role' : msg .role , 'content' : msg .content .text } for msg in params .messages ], options = {
121
+ "temperature" : params .temperature ,
122
+ "num_predict" : params .maxTokens ,
123
+ "stop" : params .stopSequences ,
124
+ })
125
+ except Exception as e :
126
+ self .console .print_exception ()
127
+ return mcp .types .ErrorData (
128
+ code = mcp .types .INTERNAL_ERROR ,
129
+ message = str (e ),
130
+ )
131
+ else :
132
+ return mcp .CreateMessageResult (
133
+ role = "assistant" ,
134
+ model = "fastmcp-client" ,
135
+ content = mcp .types .TextContent (type = "text" , text = response .message .content ),
136
+ )
137
+
138
+ return _sampling_handler
139
+
140
+
100
141
async def _connect_to_server (self , server : Dict [str , Any ]) -> bool :
101
142
"""Connect to a single MCP server
102
143
@@ -126,7 +167,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
126
167
# Connect using SSE transport
127
168
sse_transport = await self .exit_stack .enter_async_context (sse_client (url , headers = headers ))
128
169
read_stream , write_stream = sse_transport
129
- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
170
+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
130
171
131
172
elif server_type == "streamable_http" :
132
173
# Connect to Streamable HTTP server
@@ -142,7 +183,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
142
183
streamablehttp_client (url , headers = headers )
143
184
)
144
185
read_stream , write_stream , session_info = transport
145
- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
186
+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
146
187
147
188
# Store session ID if provided
148
189
if hasattr (session_info , 'session_id' ) and session_info .session_id :
@@ -156,7 +197,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
156
197
157
198
stdio_transport = await self .exit_stack .enter_async_context (stdio_client (server_params ))
158
199
read_stream , write_stream = stdio_transport
159
- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
200
+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
160
201
161
202
else :
162
203
# Connect to config-based server using STDIO
@@ -166,7 +207,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
166
207
167
208
stdio_transport = await self .exit_stack .enter_async_context (stdio_client (server_params ))
168
209
read_stream , write_stream = stdio_transport
169
- session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream ))
210
+ session = await self .exit_stack .enter_async_context (ClientSession (read_stream , write_stream , logging_callback = self . create_log_callback ( server_name ), sampling_callback = self . create_sampling_callback () ))
170
211
171
212
# Initialize the session
172
213
await session .initialize ()
0 commit comments