
Commit 4019b29

Resolve #42 -- Add logging and sampling callbacks
1 parent b41b807 commit 4019b29

4 files changed: 359 additions, 283 deletions

mcp_client_for_ollama/client.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def __init__(self, model: str = DEFAULT_MODEL, host: str = DEFAULT_OLLAMA_HOST):
         self.console = Console()
         self.config_manager = ConfigManager(self.console)
         # Initialize the server connector
-        self.server_connector = ServerConnector(self.exit_stack, self.console)
+        self.server_connector = ServerConnector(self.exit_stack, default_model=model, ollama_client=self.ollama, console=self.console)
         # Initialize the model manager
         self.model_manager = ModelManager(console=self.console, default_model=model, ollama=self.ollama)
         # Initialize the model config manager
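Note: ServerConnector now takes two additional required arguments, because the connector itself calls Ollama when a server issues a sampling request. A minimal construction sketch (the model name and host are illustrative assumptions, not values from this repo):

from contextlib import AsyncExitStack

import ollama
from rich.console import Console

from mcp_client_for_ollama.server.connector import ServerConnector

# Illustrative values: any model pulled into the local Ollama instance works,
# and ollama.AsyncClient() defaults to http://localhost:11434.
connector = ServerConnector(
    AsyncExitStack(),
    default_model="llama3.2",
    ollama_client=ollama.AsyncClient(),
    console=Console(),
)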

mcp_client_for_ollama/server/connector.py

Lines changed: 66 additions & 10 deletions
@@ -3,17 +3,23 @@
 This module handles connections to one or more MCP servers, including setup,
 initialization, and communication.
 """
-
 import os
 import shutil
 from contextlib import AsyncExitStack
-from typing import Dict, List, Any, Optional, Tuple
-from rich.console import Console
-from rich.panel import Panel
+from typing import Any, Dict, List, Optional, Tuple
+
+import mcp.types
+import ollama
+import rich.json
 from mcp import ClientSession, Tool
-from mcp.client.stdio import stdio_client, StdioServerParameters
+from mcp.client.session import LoggingFnT, SamplingFnT
 from mcp.client.sse import sse_client
+from mcp.client.stdio import StdioServerParameters, stdio_client
 from mcp.client.streamable_http import streamablehttp_client
+from mcp.shared.context import LifespanContextT, RequestContext
+from rich.console import Console, Group
+from rich.markdown import Markdown
+from rich.panel import Panel
 
 from .discovery import process_server_paths, parse_server_configs, auto_discover_servers
 
@@ -25,14 +31,16 @@ class ServerConnector:
     tools provided by those servers.
     """
 
-    def __init__(self, exit_stack: AsyncExitStack, console: Optional[Console] = None):
+    def __init__(self, exit_stack: AsyncExitStack, default_model: str, ollama_client: ollama.AsyncClient, console: Optional[Console] = None):
         """Initialize the ServerConnector.
 
         Args:
             exit_stack: AsyncExitStack to manage server connections
             console: Rich console for output (optional)
         """
         self.exit_stack = exit_stack
+        self.ollama = ollama_client
+        self.default_model = default_model
         self.console = console or Console()
         self.sessions = {}  # Dict to store multiple sessions
         self.available_tools = []  # List to store all available tools
@@ -97,6 +105,54 @@ async def connect_to_servers(self, server_paths=None, config_path=None, auto_dis
 
         return self.sessions, self.available_tools, self.enabled_tools
 
+    def create_log_callback(self, server_name: str) -> LoggingFnT:
+        async def log_callback(params: mcp.types.LoggingMessageNotificationParams) -> None:
+            self.console.log(f"[green]\[{params.level.upper()}] - {server_name}:[/green]", params.data)
+
+        return log_callback
+
+    def create_sampling_callback(self) -> SamplingFnT:
+        async def _sampling_handler(
+            context: RequestContext[ClientSession, LifespanContextT],
+            params: mcp.types.CreateMessageRequestParams,
+        ) -> mcp.types.CreateMessageResult | mcp.types.ErrorData:
+            messages = [
+                {"role": "system", "content": params.systemPrompt}
+            ] + [
+                {'role': msg.role, 'content': msg.content.text} for msg in params.messages
+            ]
+            self.console.print(Panel(
+                Group(*(Panel(Markdown(msg["content"]), title=msg["role"]) for msg in messages)),
+                title="🧠 Handling sampling request...",
+                border_style="cyan",
+                expand=False
+            ))
+            try:
+                response = await self.ollama.chat(
+                    self.default_model,
+                    messages,
+                    options={
+                        "temperature": params.temperature,
+                        "num_predict": params.maxTokens,
+                        "stop": params.stopSequences,
+                    }
+                )
+            except Exception as e:
+                self.console.print_exception()
+                return mcp.types.ErrorData(
+                    code=mcp.types.INTERNAL_ERROR,
+                    message=str(e),
+                )
+            else:
+                return mcp.CreateMessageResult(
+                    role="assistant",
+                    model="fastmcp-client",
+                    content=mcp.types.TextContent(type="text", text=response.message.content),
+                )
+
+        return _sampling_handler
+
+
     async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
         """Connect to a single MCP server
 
@@ -126,7 +182,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
                 # Connect using SSE transport
                 sse_transport = await self.exit_stack.enter_async_context(sse_client(url, headers=headers))
                 read_stream, write_stream = sse_transport
-                session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
+                session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream, logging_callback=self.create_log_callback(server_name), sampling_callback=self.create_sampling_callback()))
 
             elif server_type == "streamable_http":
                 # Connect to Streamable HTTP server
@@ -142,7 +198,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
                     streamablehttp_client(url, headers=headers)
                 )
                 read_stream, write_stream, session_info = transport
-                session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
+                session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream, logging_callback=self.create_log_callback(server_name), sampling_callback=self.create_sampling_callback()))
 
                 # Store session ID if provided
                 if hasattr(session_info, 'session_id') and session_info.session_id:
@@ -156,7 +212,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
 
                 stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
                 read_stream, write_stream = stdio_transport
-                session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
+                session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream, logging_callback=self.create_log_callback(server_name), sampling_callback=self.create_sampling_callback()))
 
             else:
                 # Connect to config-based server using STDIO
@@ -166,7 +222,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
 
                 stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
                 read_stream, write_stream = stdio_transport
-                session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
+                session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream, logging_callback=self.create_log_callback(server_name), sampling_callback=self.create_sampling_callback()))
 
             # Initialize the session
             await session.initialize()
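The two factory methods above can be exercised in isolation, which makes the behavior easy to see: the logging callback pretty-prints the server's notification, while the sampling callback forwards the request to Ollama and wraps the reply in a CreateMessageResult. A rough sketch, assuming a running local Ollama instance with the named model pulled; the request context is passed as None here only because the handler never reads it:

import asyncio
from contextlib import AsyncExitStack

import mcp.types
import ollama
from rich.console import Console

from mcp_client_for_ollama.server.connector import ServerConnector


async def main() -> None:
    connector = ServerConnector(
        AsyncExitStack(),
        default_model="llama3.2",            # assumed: any locally pulled model
        ollama_client=ollama.AsyncClient(),  # assumed: Ollama at localhost:11434
        console=Console(),
    )

    # Logging callback: logs something like "[INFO] - demo-server: tool started"
    log = connector.create_log_callback("demo-server")
    await log(mcp.types.LoggingMessageNotificationParams(level="info", data="tool started"))

    # Sampling callback: sends the messages to Ollama and returns a CreateMessageResult
    sample = connector.create_sampling_callback()
    params = mcp.types.CreateMessageRequestParams(
        messages=[
            mcp.types.SamplingMessage(
                role="user",
                content=mcp.types.TextContent(type="text", text="Say hello in one short sentence."),
            )
        ],
        systemPrompt="You are a terse assistant.",
        maxTokens=64,
    )
    result = await sample(None, params)  # context argument is unused by the handler
    print(result)


asyncio.run(main())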

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ authors = [
     {name = "Jonathan Löwenstern"}
 ]
 dependencies = [
-    "mcp>=1.6.0",
+    "mcp>=1.11.0",
     "ollama==0.5.1",
     "prompt-toolkit>=3.0.51",
     "rich>=14.0.0",

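For the server side of the picture, the new callbacks only fire when a connected MCP server actually sends logging notifications or sampling requests. A hypothetical stdio server built with the FastMCP helper from the mcp SDK could trigger both, roughly as sketched below; the tool, its prompt, and the server name are invented for illustration, and the create_message call is used here as this sketch's best understanding of the SDK's server-session sampling request:

# demo_server.py -- illustrative only; run as a stdio server and register it
# with the client (e.g. via its server config) to exercise the new callbacks.
import mcp.types
from mcp.server.fastmcp import Context, FastMCP

server = FastMCP("demo")


@server.tool()
async def summarize(text: str, ctx: Context) -> str:
    """Ask the connected client to summarize text via sampling."""
    await ctx.info("summarize called")  # arrives at the client's logging callback
    result = await ctx.session.create_message(  # arrives at the sampling callback
        messages=[
            mcp.types.SamplingMessage(
                role="user",
                content=mcp.types.TextContent(type="text", text=f"Summarize:\n\n{text}"),
            )
        ],
        max_tokens=200,
    )
    return result.content.text if isinstance(result.content, mcp.types.TextContent) else ""


if __name__ == "__main__":
    server.run()  # stdio transport by default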