Skip to content

Commit 0184cef

Browse files
committed
Resolve #42 -- Add logging and sampling callbacks
1 parent b41b807 commit 0184cef

File tree

4 files changed

+343
-282
lines changed

4 files changed

+343
-282
lines changed

mcp_client_for_ollama/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(self, model: str = DEFAULT_MODEL, host: str = DEFAULT_OLLAMA_HOST):
3737
self.console = Console()
3838
self.config_manager = ConfigManager(self.console)
3939
# Initialize the server connector
40-
self.server_connector = ServerConnector(self.exit_stack, self.console)
40+
self.server_connector = ServerConnector(self.exit_stack, default_model=model, ollama_client=self.ollama, console=self.console)
4141
# Initialize the model manager
4242
self.model_manager = ModelManager(console=self.console, default_model=model, ollama=self.ollama)
4343
# Initialize the model config manager

mcp_client_for_ollama/server/connector.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,18 @@
77
import os
88
import shutil
99
from contextlib import AsyncExitStack
10-
from typing import Dict, List, Any, Optional, Tuple
11-
from rich.console import Console
12-
from rich.panel import Panel
10+
from typing import Any, Dict, List, Optional, Tuple
11+
12+
import mcp.types
13+
import ollama
1314
from mcp import ClientSession, Tool
14-
from mcp.client.stdio import stdio_client, StdioServerParameters
15+
from mcp.client.session import LoggingFnT, SamplingFnT
1516
from mcp.client.sse import sse_client
17+
from mcp.client.stdio import StdioServerParameters, stdio_client
1618
from mcp.client.streamable_http import streamablehttp_client
19+
from mcp.shared.context import LifespanContextT, RequestContext
20+
from rich.console import Console
21+
from rich.panel import Panel
1722

1823
from .discovery import process_server_paths, parse_server_configs, auto_discover_servers
1924

@@ -25,14 +30,16 @@ class ServerConnector:
2530
tools provided by those servers.
2631
"""
2732

28-
def __init__(self, exit_stack: AsyncExitStack, console: Optional[Console] = None):
33+
def __init__(self, exit_stack: AsyncExitStack, default_model: str, ollama_client: ollama.AsyncClient, console: Optional[Console] = None):
2934
"""Initialize the ServerConnector.
3035
3136
Args:
3237
exit_stack: AsyncExitStack to manage server connections
3338
console: Rich console for output (optional)
3439
"""
3540
self.exit_stack = exit_stack
41+
self.ollama = ollama_client
42+
self.default_model = default_model
3643
self.console = console or Console()
3744
self.sessions = {} # Dict to store multiple sessions
3845
self.available_tools = [] # List to store all available tools
@@ -97,6 +104,40 @@ async def connect_to_servers(self, server_paths=None, config_path=None, auto_dis
97104

98105
return self.sessions, self.available_tools, self.enabled_tools
99106

107+
def create_log_callback(self, server_name: str) -> LoggingFnT:
108+
async def log_callback(params: mcp.types.LoggingMessageNotificationParams) -> None:
109+
self.console.log(f"[green]\[{params.level.upper()}] - {server_name}:[/green]", params.data)
110+
111+
return log_callback
112+
113+
def create_sampling_callback(self) -> SamplingFnT:
114+
async def _sampling_handler(
115+
context: RequestContext[ClientSession, LifespanContextT],
116+
params: mcp.types.CreateMessageRequestParams,
117+
) -> mcp.types.CreateMessageResult | mcp.types.ErrorData:
118+
self.console.print("[cyan]Handling sampling request...[/cyan]")
119+
try:
120+
response = await self.ollama.chat(self.default_model, [{'role': msg.role, 'content': msg.content.text} for msg in params.messages], options={
121+
"temperature": params.temperature,
122+
"num_predict": params.maxTokens,
123+
"stop": params.stopSequences,
124+
})
125+
except Exception as e:
126+
self.console.print_exception()
127+
return mcp.types.ErrorData(
128+
code=mcp.types.INTERNAL_ERROR,
129+
message=str(e),
130+
)
131+
else:
132+
return mcp.CreateMessageResult(
133+
role="assistant",
134+
model="fastmcp-client",
135+
content=mcp.types.TextContent(type="text", text=response.message.content),
136+
)
137+
138+
return _sampling_handler
139+
140+
100141
async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
101142
"""Connect to a single MCP server
102143
@@ -126,7 +167,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
126167
# Connect using SSE transport
127168
sse_transport = await self.exit_stack.enter_async_context(sse_client(url, headers=headers))
128169
read_stream, write_stream = sse_transport
129-
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
170+
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream, logging_callback=self.create_log_callback(server_name), sampling_callback=self.create_sampling_callback()))
130171

131172
elif server_type == "streamable_http":
132173
# Connect to Streamable HTTP server
@@ -142,7 +183,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
142183
streamablehttp_client(url, headers=headers)
143184
)
144185
read_stream, write_stream, session_info = transport
145-
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
186+
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream, logging_callback=self.create_log_callback(server_name), sampling_callback=self.create_sampling_callback()))
146187

147188
# Store session ID if provided
148189
if hasattr(session_info, 'session_id') and session_info.session_id:
@@ -156,7 +197,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
156197

157198
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
158199
read_stream, write_stream = stdio_transport
159-
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
200+
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream, logging_callback=self.create_log_callback(server_name), sampling_callback=self.create_sampling_callback()))
160201

161202
else:
162203
# Connect to config-based server using STDIO
@@ -166,7 +207,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
166207

167208
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
168209
read_stream, write_stream = stdio_transport
169-
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
210+
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream, logging_callback=self.create_log_callback(server_name), sampling_callback=self.create_sampling_callback()))
170211

171212
# Initialize the session
172213
await session.initialize()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ authors = [
99
{name = "Jonathan Löwenstern"}
1010
]
1111
dependencies = [
12-
"mcp>=1.6.0",
12+
"mcp>=1.11.0",
1313
"ollama==0.5.1",
1414
"prompt-toolkit>=3.0.51",
1515
"rich>=14.0.0",

0 commit comments

Comments
 (0)