Skip to content

Commit 550b316

Browse files
committed
Resolve #42 -- Add logging and sampling callbacks
1 parent b41b807 commit 550b316

File tree

4 files changed

+350
-283
lines changed

4 files changed

+350
-283
lines changed

mcp_client_for_ollama/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(self, model: str = DEFAULT_MODEL, host: str = DEFAULT_OLLAMA_HOST):
3737
self.console = Console()
3838
self.config_manager = ConfigManager(self.console)
3939
# Initialize the server connector
40-
self.server_connector = ServerConnector(self.exit_stack, self.console)
40+
self.server_connector = ServerConnector(self.exit_stack, default_model=model, ollama_client=self.ollama, console=self.console)
4141
# Initialize the model manager
4242
self.model_manager = ModelManager(console=self.console, default_model=model, ollama=self.ollama)
4343
# Initialize the model config manager

mcp_client_for_ollama/server/connector.py

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,23 @@
33
This module handles connections to one or more MCP servers, including setup,
44
initialization, and communication.
55
"""
6-
76
import os
87
import shutil
98
from contextlib import AsyncExitStack
10-
from typing import Dict, List, Any, Optional, Tuple
11-
from rich.console import Console
12-
from rich.panel import Panel
9+
from typing import Any, Dict, List, Optional, Tuple
10+
11+
import mcp.types
12+
import ollama
13+
import rich.json
1314
from mcp import ClientSession, Tool
14-
from mcp.client.stdio import stdio_client, StdioServerParameters
15+
from mcp.client.session import LoggingFnT, SamplingFnT
1516
from mcp.client.sse import sse_client
17+
from mcp.client.stdio import StdioServerParameters, stdio_client
1618
from mcp.client.streamable_http import streamablehttp_client
19+
from mcp.shared.context import LifespanContextT, RequestContext
20+
from rich.console import Console, Group
21+
from rich.markdown import Markdown
22+
from rich.panel import Panel
1723

1824
from .discovery import process_server_paths, parse_server_configs, auto_discover_servers
1925

@@ -25,14 +31,16 @@ class ServerConnector:
2531
tools provided by those servers.
2632
"""
2733

28-
def __init__(self, exit_stack: AsyncExitStack, console: Optional[Console] = None):
34+
def __init__(self, exit_stack: AsyncExitStack, default_model: str, ollama_client: ollama.AsyncClient, console: Optional[Console] = None):
2935
"""Initialize the ServerConnector.
3036
3137
Args:
3238
exit_stack: AsyncExitStack to manage server connections
3339
console: Rich console for output (optional)
3440
"""
3541
self.exit_stack = exit_stack
42+
self.ollama = ollama_client
43+
self.default_model = default_model
3644
self.console = console or Console()
3745
self.sessions = {} # Dict to store multiple sessions
3846
self.available_tools = [] # List to store all available tools
@@ -97,6 +105,45 @@ async def connect_to_servers(self, server_paths=None, config_path=None, auto_dis
97105

98106
return self.sessions, self.available_tools, self.enabled_tools
99107

108+
def create_log_callback(self, server_name: str) -> LoggingFnT:
    """Create a logging callback bound to a specific server.

    Args:
        server_name: Name used to label console log lines from this server.

    Returns:
        An async callback that prints MCP log notifications to the console.
    """
    async def log_callback(params: mcp.types.LoggingMessageNotificationParams) -> None:
        # "\\[" emits the literal "\[" Rich-markup escape for "[";
        # a bare "\[" in source is an invalid escape sequence
        # (SyntaxWarning from Python 3.12 onward).
        self.console.log(f"[green]\\[{params.level.upper()}] - {server_name}:[/green]", params.data)

    return log_callback
113+
114+
def create_sampling_callback(self) -> SamplingFnT:
    """Create a sampling callback that serves MCP sampling requests via Ollama.

    Returns:
        An async handler that converts an MCP ``CreateMessageRequestParams``
        into a chat call against ``self.default_model`` and wraps the reply
        (or a failure) in the matching MCP result/error type.
    """
    async def _sampling_handler(
        context: RequestContext[ClientSession, LifespanContextT],
        params: mcp.types.CreateMessageRequestParams,
    ) -> mcp.types.CreateMessageResult | mcp.types.ErrorData:
        # systemPrompt is optional in the MCP spec; only prepend it when set
        # (a None content entry would be rejected by Ollama).
        messages: List[Dict[str, Any]] = []
        if params.systemPrompt is not None:
            messages.append({"role": "system", "content": params.systemPrompt})
        # NOTE(review): assumes every sampling message carries TextContent;
        # image content would raise AttributeError here — confirm upstream.
        messages.extend(
            {"role": msg.role, "content": msg.content.text} for msg in params.messages
        )

        self.console.print(Panel(
            Group(*(Panel(Markdown(msg["content"]), title=msg["role"]) for msg in messages)),
            title="🧠 Handling sampling request...",
            border_style="cyan",
            expand=False,
        ))

        try:
            response = await self.ollama.chat(
                self.default_model,
                messages,
                options={
                    "temperature": params.temperature,
                    "num_predict": params.maxTokens,
                    "stop": params.stopSequences,
                },
            )
        except Exception as e:
            self.console.print_exception()
            return mcp.types.ErrorData(
                code=mcp.types.INTERNAL_ERROR,
                message=str(e),
            )
        else:
            # Fix: original referenced mcp.CreateMessageResult; the type lives
            # in mcp.types, consistent with ErrorData/TextContent above.
            return mcp.types.CreateMessageResult(
                role="assistant",
                # Report the model that actually generated the reply rather
                # than a hard-coded client label.
                model=self.default_model,
                content=mcp.types.TextContent(type="text", text=response.message.content),
            )

    return _sampling_handler
145+
146+
100147
async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
101148
"""Connect to a single MCP server
102149
@@ -126,7 +173,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
126173
# Connect using SSE transport
127174
sse_transport = await self.exit_stack.enter_async_context(sse_client(url, headers=headers))
128175
read_stream, write_stream = sse_transport
129-
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
176+
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream, logging_callback=self.create_log_callback(server_name), sampling_callback=self.create_sampling_callback()))
130177

131178
elif server_type == "streamable_http":
132179
# Connect to Streamable HTTP server
@@ -142,7 +189,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
142189
streamablehttp_client(url, headers=headers)
143190
)
144191
read_stream, write_stream, session_info = transport
145-
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
192+
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream, logging_callback=self.create_log_callback(server_name), sampling_callback=self.create_sampling_callback()))
146193

147194
# Store session ID if provided
148195
if hasattr(session_info, 'session_id') and session_info.session_id:
@@ -156,7 +203,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
156203

157204
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
158205
read_stream, write_stream = stdio_transport
159-
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
206+
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream, logging_callback=self.create_log_callback(server_name), sampling_callback=self.create_sampling_callback()))
160207

161208
else:
162209
# Connect to config-based server using STDIO
@@ -166,7 +213,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
166213

167214
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
168215
read_stream, write_stream = stdio_transport
169-
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
216+
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream, logging_callback=self.create_log_callback(server_name), sampling_callback=self.create_sampling_callback()))
170217

171218
# Initialize the session
172219
await session.initialize()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ authors = [
99
{name = "Jonathan Löwenstern"}
1010
]
1111
dependencies = [
12-
"mcp>=1.6.0",
12+
"mcp>=1.11.0",
1313
"ollama==0.5.1",
1414
"prompt-toolkit>=3.0.51",
1515
"rich>=14.0.0",

0 commit comments

Comments
 (0)