Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mcp_client_for_ollama/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, model: str = DEFAULT_MODEL, host: str = DEFAULT_OLLAMA_HOST):
self.console = Console()
self.config_manager = ConfigManager(self.console)
# Initialize the server connector
self.server_connector = ServerConnector(self.exit_stack, self.console)
self.server_connector = ServerConnector(self.exit_stack, default_model=model, ollama_client=self.ollama, console=self.console)
# Initialize the model manager
self.model_manager = ModelManager(console=self.console, default_model=model, ollama=self.ollama)
# Initialize the model config manager
Expand Down
76 changes: 66 additions & 10 deletions mcp_client_for_ollama/server/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,23 @@
This module handles connections to one or more MCP servers, including setup,
initialization, and communication.
"""

import os
import shutil
from contextlib import AsyncExitStack
from typing import Dict, List, Any, Optional, Tuple
from rich.console import Console
from rich.panel import Panel
from typing import Any, Dict, List, Optional, Tuple

import mcp.types
import ollama
import rich.json
from mcp import ClientSession, Tool
from mcp.client.stdio import stdio_client, StdioServerParameters
from mcp.client.session import LoggingFnT, SamplingFnT
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.context import LifespanContextT, RequestContext
from rich.console import Console, Group
from rich.markdown import Markdown
from rich.panel import Panel

from .discovery import process_server_paths, parse_server_configs, auto_discover_servers

Expand All @@ -25,14 +31,16 @@ class ServerConnector:
tools provided by those servers.
"""

def __init__(self, exit_stack: AsyncExitStack, console: Optional[Console] = None):
def __init__(self, exit_stack: AsyncExitStack, default_model: str, ollama_client: ollama.AsyncClient, console: Optional[Console] = None):
"""Initialize the ServerConnector.

Args:
exit_stack: AsyncExitStack to manage server connections
console: Rich console for output (optional)
"""
self.exit_stack = exit_stack
self.ollama = ollama_client
self.default_model = default_model
self.console = console or Console()
self.sessions = {} # Dict to store multiple sessions
self.available_tools = [] # List to store all available tools
Expand Down Expand Up @@ -97,6 +105,54 @@ async def connect_to_servers(self, server_paths=None, config_path=None, auto_dis

return self.sessions, self.available_tools, self.enabled_tools

def create_log_callback(self, server_name: str) -> LoggingFnT:
async def log_callback(params: mcp.types.LoggingMessageNotificationParams) -> None:
self.console.log(f"[green]\[{params.level.upper()}] - {server_name}:[/green]", params.data)

return log_callback

def create_sampling_callback(self) -> SamplingFnT:
async def _sampling_handler(
context: RequestContext[ClientSession, LifespanContextT],
params: mcp.types.CreateMessageRequestParams,
) -> mcp.types.CreateMessageResult | mcp.types.ErrorData:
messages = [
{"role": "system", "content": params.systemPrompt}
] + [
{'role': msg.role, 'content': msg.content.text} for msg in params.messages
]
self.console.print(Panel(
Group(*(Panel(Markdown(msg["content"]), title=msg["role"], ) for msg in messages)),
title="🧠 Handling sampling request...",
border_style="cyan",
expand=False
))
try:
response = await self.ollama.chat(
self.default_model,
messages,
options={
"temperature": params.temperature,
"num_predict": params.maxTokens,
"stop": params.stopSequences,
}
)
except Exception as e:
self.console.print_exception()
return mcp.types.ErrorData(
code=mcp.types.INTERNAL_ERROR,
message=str(e),
)
else:
return mcp.CreateMessageResult(
role="assistant",
model="fastmcp-client",
content=mcp.types.TextContent(type="text", text=response.message.content),
)

return _sampling_handler


async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
"""Connect to a single MCP server

Expand Down Expand Up @@ -126,7 +182,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
# Connect using SSE transport
sse_transport = await self.exit_stack.enter_async_context(sse_client(url, headers=headers))
read_stream, write_stream = sse_transport
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream, logging_callback=self.create_log_callback(server_name), sampling_callback=self.create_sampling_callback()))

elif server_type == "streamable_http":
# Connect to Streamable HTTP server
Expand All @@ -142,7 +198,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:
streamablehttp_client(url, headers=headers)
)
read_stream, write_stream, session_info = transport
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream, logging_callback=self.create_log_callback(server_name), sampling_callback=self.create_sampling_callback()))

# Store session ID if provided
if hasattr(session_info, 'session_id') and session_info.session_id:
Expand All @@ -156,7 +212,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:

stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
read_stream, write_stream = stdio_transport
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream, logging_callback=self.create_log_callback(server_name), sampling_callback=self.create_sampling_callback()))

else:
# Connect to config-based server using STDIO
Expand All @@ -166,7 +222,7 @@ async def _connect_to_server(self, server: Dict[str, Any]) -> bool:

stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
read_stream, write_stream = stdio_transport
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream, logging_callback=self.create_log_callback(server_name), sampling_callback=self.create_sampling_callback()))

# Initialize the session
await session.initialize()
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ authors = [
{name = "Jonathan Löwenstern"}
]
dependencies = [
"mcp>=1.6.0",
"mcp>=1.11.0",
"ollama==0.5.1",
"prompt-toolkit>=3.0.51",
"rich>=14.0.0",
Expand Down
Loading
Loading