Skip to content

Commit e90f590

Browse files
authored
feat: Sampling (#170)
* Removed unused _setup_client * Implement sampling in client and connectors * Tests * Nicer names in action * Update previous tests * Docstrings * Fixed naming and created other primitives tests * Supported features table * docs
1 parent 60b0a03 commit e90f590

File tree

21 files changed

+468
-39
lines changed

21 files changed

+468
-39
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
name: Primitives
2+
on:
3+
push:
4+
branches: [main]
5+
pull_request:
6+
branches: [main]
7+
8+
jobs:
9+
primitives:
10+
name: ${{ matrix.primitive }}
11+
runs-on: ubuntu-latest
12+
strategy:
13+
fail-fast: false
14+
matrix:
15+
primitive: [sampling, tools, resources, prompts]
16+
python-version: ["3.11"]
17+
18+
steps:
19+
- uses: actions/checkout@v3
20+
21+
- name: Set up Python ${{ matrix.python-version }}
22+
uses: actions/setup-python@v4
23+
with:
24+
python-version: ${{ matrix.python-version }}
25+
26+
- name: Install uv
27+
run: |
28+
pip install uv
29+
30+
- name: Install dependencies
31+
run: |
32+
# Install project with dev and optional extras similar to unit tests environment
33+
uv pip install --system .[dev,anthropic,openai,search,e2b]
34+
35+
- name: Lint with ruff
36+
run: |
37+
ruff check .
38+
39+
- name: Run integration tests for ${{ matrix.primitive }} primitive
40+
run: |
41+
pytest tests/integration/primitives/test_${{ matrix.primitive }}.py

README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,12 @@
4444

4545
💬 Get started quickly - chat with your servers on our <b>hosted version</b>! [Try mcp-use chat (beta)](https://chat.mcp-use.com).
4646

47-
# Features
47+
| Supports | |
48+
| :--- | :--- |
49+
| **Primitives** | [![Tools](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/testprimitives.yml?job=tools&label=Tools&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/testprimitives.yml) [![Resources](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/testprimitives.yml?job=resources&label=Resources&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/testprimitives.yml) [![Prompts](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/testprimitives.yml?job=prompts&label=Prompts&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/testprimitives.yml) [![Sampling](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/testprimitives.yml?job=sampling&label=Sampling&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/testprimitives.yml) |
50+
| **Transports** | [![Stdio](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/transportstests.yml?job=stdio&label=Stdio&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/transportstests.yml) [![SSE](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/transportstests.yml?job=sse&label=SSE&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/transportstests.yml) [![Streamable HTTP](https://img.shields.io/github/actions/workflow/status/pietrozullo/mcp-use/transportstests.yml?job=streamableHttp&label=Streamable%20HTTP&style=flat)](https://github.com/pietrozullo/mcp-use/actions/workflows/transportstests.yml) |
51+
52+
## Features
4853

4954
<table>
5055
<tr>

docs/docs.json

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,14 @@
3030
"pages": [
3131
"essentials/configuration",
3232
"essentials/client-configuration",
33-
"essentials/connection-types"
33+
{
34+
"group": "Client Features",
35+
"icon": "list-plus",
36+
"pages": [
37+
"essentials/connection-types",
38+
"essentials/sampling"
39+
]
40+
}
3441
]
3542
},
3643
{

docs/essentials/sampling.mdx

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
---
2+
title: "Sampling"
3+
description: "Enable LLM sampling capabilities for MCP tools"
4+
icon: "pipette"
5+
---
6+
7+
# Sampling
8+
9+
<Info>
10+
Sampling allows MCP tools to request LLM completions during their execution.
11+
</Info>
12+
13+
## Configuration
14+
15+
To enable sampling, provide a `sampling_callback` function when initializing the MCPClient:
16+
17+
```python
18+
from mcp_use.client import MCPClient
19+
from mcp.client.session import ClientSession
20+
from mcp.types import (
21+
CreateMessageRequestParams,
22+
CreateMessageResult,
23+
ErrorData,
24+
TextContent
25+
)
26+
27+
async def sampling_callback(
28+
context: ClientSession,
29+
params: CreateMessageRequestParams
30+
) -> CreateMessageResult | ErrorData:
31+
"""
32+
Your sampling callback implementation.
33+
This function receives a prompt and returns an LLM response.
34+
"""
35+
# Integrate with your LLM of choice (OpenAI, Anthropic, etc.)
36+
response = await your_llm.complete(params.messages[-1].content.text)
37+
38+
return CreateMessageResult(
39+
content=TextContent(text=response, type="text"),
40+
model="your-model-name",
41+
role="assistant"
42+
)
43+
44+
# Initialize client with sampling support
45+
client = MCPClient(
46+
config="config.json",
47+
sampling_callback=sampling_callback
48+
)
49+
```
50+
51+
52+
## Creating Sampling-Enabled Tools
53+
54+
When building MCP servers, tools can request sampling using the context parameter:
55+
56+
```python
57+
from fastmcp import Context, FastMCP
58+
59+
mcp = FastMCP(name="MyServer")
60+
61+
@mcp.tool
62+
async def analyze_sentiment(text: str, ctx: Context) -> str:
63+
"""Analyze the sentiment of text using the client's LLM."""
64+
prompt = f"""Analyze the sentiment of the following text as positive, negative, or neutral.
65+
Just output a single word - 'positive', 'negative', or 'neutral'.
66+
67+
Text to analyze: {text}"""
68+
69+
# Request LLM analysis through sampling
70+
response = await ctx.sample(prompt)
71+
72+
return response.text.strip()
73+
```
74+
75+
## Error Handling
76+
77+
If no sampling callback is provided but a tool requests sampling:
78+
79+
```python
80+
# Without sampling callback
81+
client = MCPClient(config="config.json") # No sampling_callback
82+
83+
# Tool that requires sampling will return an error
84+
result = await session.call_tool("analyze_sentiment", {"text": "Hello"})
85+
# result.isError will be True
86+
```

docs/essentials/server-manager.mdx

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
21
---
32
title: 'Server Manager'
43
description: 'Intelligent management of multiple MCP servers and tool discovery'
5-
icon: 'server'
4+
icon: 'server-cog'
65
---
76

87
# Server Manager

mcp_use/client.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import warnings
1010
from typing import Any
1111

12+
from mcp.client.session import SamplingFnT
13+
1214
from mcp_use.types.sandbox import SandboxOptions
1315

1416
from .config import create_connector_from_config, load_config_file
@@ -28,6 +30,7 @@ def __init__(
2830
config: str | dict[str, Any] | None = None,
2931
sandbox: bool = False,
3032
sandbox_options: SandboxOptions | None = None,
33+
sampling_callback: SamplingFnT | None = None,
3134
) -> None:
3235
"""Initialize a new MCP client.
3336
@@ -36,13 +39,14 @@ def __init__(
3639
If None, an empty configuration is used.
3740
sandbox: Whether to use sandboxed execution mode for running MCP servers.
3841
sandbox_options: Optional sandbox configuration options.
42+
sampling_callback: Optional sampling callback function.
3943
"""
4044
self.config: dict[str, Any] = {}
4145
self.sandbox = sandbox
4246
self.sandbox_options = sandbox_options
4347
self.sessions: dict[str, MCPSession] = {}
4448
self.active_sessions: list[str] = []
45-
49+
self.sampling_callback = sampling_callback
4650
# Load configuration if provided
4751
if config is not None:
4852
if isinstance(config, str):
@@ -151,7 +155,10 @@ async def create_session(self, server_name: str, auto_initialize: bool = True) -
151155

152156
# Create connector with options
153157
connector = create_connector_from_config(
154-
server_config, sandbox=self.sandbox, sandbox_options=self.sandbox_options
158+
server_config,
159+
sandbox=self.sandbox,
160+
sandbox_options=self.sandbox_options,
161+
sampling_callback=self.sampling_callback,
155162
)
156163

157164
# Create the session

mcp_use/config.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import json
88
from typing import Any
99

10+
from mcp.client.session import SamplingFnT
11+
1012
from mcp_use.types.sandbox import SandboxOptions
1113

1214
from .connectors import (
@@ -36,6 +38,7 @@ def create_connector_from_config(
3638
server_config: dict[str, Any],
3739
sandbox: bool = False,
3840
sandbox_options: SandboxOptions | None = None,
41+
sampling_callback: SamplingFnT | None = None,
3942
) -> BaseConnector:
4043
"""Create a connector based on server configuration.
4144
This function can be called with just the server_config parameter:
@@ -44,7 +47,7 @@ def create_connector_from_config(
4447
server_config: The server configuration section
4548
sandbox: Whether to use sandboxed execution mode for running MCP servers.
4649
sandbox_options: Optional sandbox configuration options.
47-
50+
sampling_callback: Optional sampling callback function.
4851
Returns:
4952
A configured connector instance
5053
"""
@@ -55,6 +58,7 @@ def create_connector_from_config(
5558
command=server_config["command"],
5659
args=server_config["args"],
5760
env=server_config.get("env", None),
61+
sampling_callback=sampling_callback,
5862
)
5963

6064
# Sandboxed connector
@@ -64,6 +68,7 @@ def create_connector_from_config(
6468
args=server_config["args"],
6569
env=server_config.get("env", None),
6670
e2b_options=sandbox_options,
71+
sampling_callback=sampling_callback,
6772
)
6873

6974
# HTTP connector
@@ -72,6 +77,7 @@ def create_connector_from_config(
7277
base_url=server_config["url"],
7378
headers=server_config.get("headers", None),
7479
auth_token=server_config.get("auth_token", None),
80+
sampling_callback=sampling_callback,
7581
)
7682

7783
# WebSocket connector

mcp_use/connectors/base.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@
88
from abc import ABC, abstractmethod
99
from typing import Any
1010

11-
from mcp import ClientSession
11+
from mcp import ClientSession, Implementation
12+
from mcp.client.session import SamplingFnT
1213
from mcp.shared.exceptions import McpError
1314
from mcp.types import CallToolResult, GetPromptResult, Prompt, ReadResourceResult, Resource, Tool
1415
from pydantic import AnyUrl
1516

17+
import mcp_use
18+
1619
from ..logging import logger
1720
from ..task_managers import ConnectionManager
1821

@@ -23,7 +26,7 @@ class BaseConnector(ABC):
2326
This class defines the interface that all MCP connectors must implement.
2427
"""
2528

26-
def __init__(self):
29+
def __init__(self, sampling_callback: SamplingFnT | None = None):
2730
"""Initialize base connector with common attributes."""
2831
self.client_session: ClientSession | None = None
2932
self._connection_manager: ConnectionManager | None = None
@@ -33,6 +36,16 @@ def __init__(self):
3336
self._connected = False
3437
self._initialized = False # Track if client_session.initialize() has been called
3538
self.auto_reconnect = True # Whether to automatically reconnect on connection loss (not configurable for now)
39+
self.sampling_callback = sampling_callback
40+
41+
@property
42+
def client_info(self) -> Implementation:
43+
"""Get the client info for the connector."""
44+
return Implementation(
45+
name="mcp-use",
46+
version=mcp_use.__version__,
47+
url="https://github.com/mcp-use/mcp-use",
48+
)
3649

3750
@abstractmethod
3851
async def connect(self) -> None:
@@ -143,7 +156,8 @@ async def initialize(self) -> dict[str, Any]:
143156

144157
logger.debug(
145158
f"MCP session initialized with {len(self._tools)} tools, "
146-
"{len(self._resources)} resources, and {len(self._prompts)} prompts"
159+
f"{len(self._resources)} resources, "
160+
f"and {len(self._prompts)} prompts"
147161
)
148162

149163
return result

mcp_use/connectors/http.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77

88
import httpx
99
from mcp import ClientSession
10+
from mcp.client.session import SamplingFnT
1011

1112
from ..logging import logger
12-
from ..task_managers import ConnectionManager, SseConnectionManager, StreamableHttpConnectionManager
13+
from ..task_managers import SseConnectionManager, StreamableHttpConnectionManager
1314
from .base import BaseConnector
1415

1516

@@ -27,6 +28,7 @@ def __init__(
2728
headers: dict[str, str] | None = None,
2829
timeout: float = 5,
2930
sse_read_timeout: float = 60 * 5,
31+
sampling_callback: SamplingFnT | None = None,
3032
):
3133
"""Initialize a new HTTP connector.
3234
@@ -36,8 +38,9 @@ def __init__(
3638
headers: Optional additional headers.
3739
timeout: Timeout for HTTP operations in seconds.
3840
sse_read_timeout: Timeout for SSE read operations in seconds.
41+
sampling_callback: Optional sampling callback.
3942
"""
40-
super().__init__()
43+
super().__init__(sampling_callback=sampling_callback)
4144
self.base_url = base_url.rstrip("/")
4245
self.auth_token = auth_token
4346
self.headers = headers or {}
@@ -46,14 +49,6 @@ def __init__(
4649
self.timeout = timeout
4750
self.sse_read_timeout = sse_read_timeout
4851

49-
async def _setup_client(self, connection_manager: ConnectionManager) -> None:
50-
"""Set up the client session with the provided connection manager."""
51-
52-
self._connection_manager = connection_manager
53-
read_stream, write_stream = await self._connection_manager.start()
54-
self.client_session = ClientSession(read_stream, write_stream, sampling_callback=None)
55-
await self.client_session.__aenter__()
56-
5752
async def connect(self) -> None:
5853
"""Establish a connection to the MCP implementation."""
5954
if self._connected:
@@ -76,7 +71,9 @@ async def connect(self) -> None:
7671
read_stream, write_stream = await connection_manager.start()
7772

7873
# Test if this actually works by trying to create a client session and initialize it
79-
test_client = ClientSession(read_stream, write_stream, sampling_callback=None)
74+
test_client = ClientSession(
75+
read_stream, write_stream, sampling_callback=self.sampling_callback, client_info=self.client_info
76+
)
8077
await test_client.__aenter__()
8178

8279
try:
@@ -154,7 +151,12 @@ async def connect(self) -> None:
154151
read_stream, write_stream = await connection_manager.start()
155152

156153
# Create the client session for SSE
157-
self.client_session = ClientSession(read_stream, write_stream, sampling_callback=None)
154+
self.client_session = ClientSession(
155+
read_stream,
156+
write_stream,
157+
sampling_callback=self.sampling_callback,
158+
client_info=self.client_info,
159+
)
158160
await self.client_session.__aenter__()
159161
self.transport_type = "SSE"
160162

0 commit comments

Comments
 (0)