Skip to content

Commit 71c4755

Browse files
ochafikfelixweinbergerclaude
authored
Implement SEP-1577 - Sampling With Tools (#1594)
Co-authored-by: Felix Weinberger <[email protected]> Co-authored-by: Claude <[email protected]>
1 parent c51936f commit 71c4755

File tree

9 files changed

+674
-19
lines changed

9 files changed

+674
-19
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -886,8 +886,8 @@ async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str:
886886
max_tokens=100,
887887
)
888888

889-
if result.content.type == "text":
890-
return result.content.text
889+
if all(c.type == "text" for c in result.content_as_list):
890+
return "\n".join(c.text for c in result.content_as_list if c.type == "text")
891891
return str(result.content)
892892
```
893893

examples/servers/everything-server/mcp_everything_server/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ async def test_sampling(prompt: str, ctx: Context[ServerSession, None]) -> str:
134134
max_tokens=100,
135135
)
136136

137-
if result.content.type == "text":
138-
model_response = result.content.text
137+
if any(c.type == "text" for c in result.content_as_list):
138+
model_response = "\n".join(c.text for c in result.content_as_list if c.type == "text")
139139
else:
140140
model_response = "No response"
141141

examples/snippets/servers/sampling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@ async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str:
2020
max_tokens=100,
2121
)
2222

23-
if result.content.type == "text":
24-
return result.content.text
23+
if all(c.type == "text" for c in result.content_as_list):
24+
return "\n".join(c.text for c in result.content_as_list if c.type == "text")
2525
return str(result.content)

src/mcp/__init__.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,11 @@
4141
ResourcesCapability,
4242
ResourceUpdatedNotification,
4343
RootsCapability,
44+
SamplingCapability,
45+
SamplingContextCapability,
4446
SamplingMessage,
47+
SamplingMessageContentBlock,
48+
SamplingToolsCapability,
4549
ServerCapabilities,
4650
ServerNotification,
4751
ServerRequest,
@@ -50,7 +54,10 @@
5054
StopReason,
5155
SubscribeRequest,
5256
Tool,
57+
ToolChoice,
58+
ToolResultContent,
5359
ToolsCapability,
60+
ToolUseContent,
5461
UnsubscribeRequest,
5562
)
5663
from .types import (
@@ -65,6 +72,7 @@
6572
"ClientResult",
6673
"ClientSession",
6774
"ClientSessionGroup",
75+
"CompleteRequest",
6876
"CreateMessageRequest",
6977
"CreateMessageResult",
7078
"ErrorData",
@@ -77,6 +85,7 @@
7785
"InitializedNotification",
7886
"JSONRPCError",
7987
"JSONRPCRequest",
88+
"JSONRPCResponse",
8089
"ListPromptsRequest",
8190
"ListPromptsResult",
8291
"ListResourcesRequest",
@@ -91,12 +100,16 @@
91100
"PromptsCapability",
92101
"ReadResourceRequest",
93102
"ReadResourceResult",
103+
"Resource",
94104
"ResourcesCapability",
95105
"ResourceUpdatedNotification",
96-
"Resource",
97106
"RootsCapability",
107+
"SamplingCapability",
108+
"SamplingContextCapability",
98109
"SamplingMessage",
110+
"SamplingMessageContentBlock",
99111
"SamplingRole",
112+
"SamplingToolsCapability",
100113
"ServerCapabilities",
101114
"ServerNotification",
102115
"ServerRequest",
@@ -107,10 +120,11 @@
107120
"StopReason",
108121
"SubscribeRequest",
109122
"Tool",
123+
"ToolChoice",
124+
"ToolResultContent",
110125
"ToolsCapability",
126+
"ToolUseContent",
111127
"UnsubscribeRequest",
112128
"stdio_client",
113129
"stdio_server",
114-
"CompleteRequest",
115-
"JSONRPCResponse",
116130
]

src/mcp/server/session.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
4747

4848
import mcp.types as types
4949
from mcp.server.models import InitializationOptions
50+
from mcp.shared.exceptions import McpError
5051
from mcp.shared.message import ServerMessageMetadata, SessionMessage
5152
from mcp.shared.session import (
5253
BaseSession,
@@ -120,6 +121,12 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
120121
if capability.sampling is not None:
121122
if client_caps.sampling is None:
122123
return False
124+
if capability.sampling.context is not None:
125+
if client_caps.sampling.context is None:
126+
return False
127+
if capability.sampling.tools is not None:
128+
if client_caps.sampling.tools is None:
129+
return False
123130

124131
if capability.elicitation is not None:
125132
if client_caps.elicitation is None:
@@ -223,9 +230,75 @@ async def create_message(
223230
stop_sequences: list[str] | None = None,
224231
metadata: dict[str, Any] | None = None,
225232
model_preferences: types.ModelPreferences | None = None,
233+
tools: list[types.Tool] | None = None,
234+
tool_choice: types.ToolChoice | None = None,
226235
related_request_id: types.RequestId | None = None,
227236
) -> types.CreateMessageResult:
228-
"""Send a sampling/create_message request."""
237+
"""Send a sampling/create_message request.
238+
239+
Args:
240+
messages: The conversation messages to send.
241+
max_tokens: Maximum number of tokens to generate.
242+
system_prompt: Optional system prompt.
243+
include_context: Optional context inclusion setting.
244+
Should only be set to "thisServer" or "allServers"
245+
if the client has sampling.context capability.
246+
temperature: Optional sampling temperature.
247+
stop_sequences: Optional stop sequences.
248+
metadata: Optional metadata to pass through to the LLM provider.
249+
model_preferences: Optional model selection preferences.
250+
tools: Optional list of tools the LLM can use during sampling.
251+
Requires client to have sampling.tools capability.
252+
tool_choice: Optional control over tool usage behavior.
253+
Requires client to have sampling.tools capability.
254+
related_request_id: Optional ID of a related request.
255+
256+
Returns:
257+
The sampling result from the client.
258+
259+
Raises:
260+
McpError: If tool_use or tool_result blocks are misused when tools are provided.
261+
"""
262+
263+
if tools is not None or tool_choice is not None:
264+
has_tools_cap = self.check_client_capability(
265+
types.ClientCapabilities(sampling=types.SamplingCapability(tools=types.SamplingToolsCapability()))
266+
)
267+
if not has_tools_cap:
268+
raise McpError(
269+
types.ErrorData(
270+
code=types.INVALID_PARAMS,
271+
message="Client does not support sampling tools capability",
272+
)
273+
)
274+
275+
# Validate tool_use/tool_result message structure per SEP-1577:
276+
# https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1577
277+
# This validation runs regardless of whether `tools` is in this request,
278+
# since a tool loop continuation may omit `tools` while still containing
279+
# tool_result content that must match previous tool_use.
280+
if messages:
281+
last_content = messages[-1].content_as_list
282+
has_tool_results = any(c.type == "tool_result" for c in last_content)
283+
284+
previous_content = messages[-2].content_as_list if len(messages) >= 2 else None
285+
has_previous_tool_use = previous_content and any(c.type == "tool_use" for c in previous_content)
286+
287+
if has_tool_results:
288+
# Per spec: "SamplingMessage with tool result content blocks
289+
# MUST NOT contain other content types."
290+
if any(c.type != "tool_result" for c in last_content):
291+
raise ValueError("The last message must contain only tool_result content if any is present")
292+
if previous_content is None:
293+
raise ValueError("tool_result requires a previous message containing tool_use")
294+
if not has_previous_tool_use:
295+
raise ValueError("tool_result blocks do not match any tool_use in the previous message")
296+
if has_previous_tool_use and previous_content:
297+
tool_use_ids = {c.id for c in previous_content if c.type == "tool_use"}
298+
tool_result_ids = {c.toolUseId for c in last_content if c.type == "tool_result"}
299+
if tool_use_ids != tool_result_ids:
300+
raise ValueError("ids of tool_result blocks and tool_use blocks from previous message do not match")
301+
229302
return await self.send_request(
230303
request=types.ServerRequest(
231304
types.CreateMessageRequest(
@@ -238,6 +311,8 @@ async def create_message(
238311
stopSequences=stop_sequences,
239312
metadata=metadata,
240313
modelPreferences=model_preferences,
314+
tools=tools,
315+
toolChoice=tool_choice,
241316
),
242317
)
243318
),

0 commit comments

Comments
 (0)