@@ -47,6 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
4747
4848import mcp .types as types
4949from mcp .server .models import InitializationOptions
50+ from mcp .shared .exceptions import McpError
5051from mcp .shared .message import ServerMessageMetadata , SessionMessage
5152from mcp .shared .session import (
5253 BaseSession ,
@@ -120,6 +121,12 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
120121 if capability .sampling is not None :
121122 if client_caps .sampling is None :
122123 return False
124+ if capability .sampling .context is not None :
125+ if client_caps .sampling .context is None :
126+ return False
127+ if capability .sampling .tools is not None :
128+ if client_caps .sampling .tools is None :
129+ return False
123130
124131 if capability .elicitation is not None :
125132 if client_caps .elicitation is None :
@@ -223,9 +230,75 @@ async def create_message(
223230 stop_sequences : list [str ] | None = None ,
224231 metadata : dict [str , Any ] | None = None ,
225232 model_preferences : types .ModelPreferences | None = None ,
233+ tools : list [types .Tool ] | None = None ,
234+ tool_choice : types .ToolChoice | None = None ,
226235 related_request_id : types .RequestId | None = None ,
227236 ) -> types .CreateMessageResult :
228- """Send a sampling/create_message request."""
237+ """Send a sampling/create_message request.
238+
239+ Args:
240+ messages: The conversation messages to send.
241+ max_tokens: Maximum number of tokens to generate.
242+ system_prompt: Optional system prompt.
243+ include_context: Optional context inclusion setting.
244+ Should only be set to "thisServer" or "allServers"
245+ if the client has sampling.context capability.
246+ temperature: Optional sampling temperature.
247+ stop_sequences: Optional stop sequences.
248+ metadata: Optional metadata to pass through to the LLM provider.
249+ model_preferences: Optional model selection preferences.
250+ tools: Optional list of tools the LLM can use during sampling.
251+ Requires client to have sampling.tools capability.
252+ tool_choice: Optional control over tool usage behavior.
253+ Requires client to have sampling.tools capability.
254+ related_request_id: Optional ID of a related request.
255+
256+ Returns:
257+ The sampling result from the client.
258+
259+ Raises:
260+ McpError: If tool_use or tool_result blocks are misused when tools are provided.
261+ """
262+
263+ if tools is not None or tool_choice is not None :
264+ has_tools_cap = self .check_client_capability (
265+ types .ClientCapabilities (sampling = types .SamplingCapability (tools = types .SamplingToolsCapability ()))
266+ )
267+ if not has_tools_cap :
268+ raise McpError (
269+ types .ErrorData (
270+ code = types .INVALID_PARAMS ,
271+ message = "Client does not support sampling tools capability" ,
272+ )
273+ )
274+
275+ # Validate tool_use/tool_result message structure per SEP-1577:
276+ # https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1577
277+ # This validation runs regardless of whether `tools` is in this request,
278+ # since a tool loop continuation may omit `tools` while still containing
279+ # tool_result content that must match previous tool_use.
280+ if messages :
281+ last_content = messages [- 1 ].content_as_list
282+ has_tool_results = any (c .type == "tool_result" for c in last_content )
283+
284+ previous_content = messages [- 2 ].content_as_list if len (messages ) >= 2 else None
285+ has_previous_tool_use = previous_content and any (c .type == "tool_use" for c in previous_content )
286+
287+ if has_tool_results :
288+ # Per spec: "SamplingMessage with tool result content blocks
289+ # MUST NOT contain other content types."
290+ if any (c .type != "tool_result" for c in last_content ):
291+ raise ValueError ("The last message must contain only tool_result content if any is present" )
292+ if previous_content is None :
293+ raise ValueError ("tool_result requires a previous message containing tool_use" )
294+ if not has_previous_tool_use :
295+ raise ValueError ("tool_result blocks do not match any tool_use in the previous message" )
296+ if has_previous_tool_use and previous_content :
297+ tool_use_ids = {c .id for c in previous_content if c .type == "tool_use" }
298+ tool_result_ids = {c .toolUseId for c in last_content if c .type == "tool_result" }
299+ if tool_use_ids != tool_result_ids :
300+ raise ValueError ("ids of tool_result blocks and tool_use blocks from previous message do not match" )
301+
229302 return await self .send_request (
230303 request = types .ServerRequest (
231304 types .CreateMessageRequest (
@@ -238,6 +311,8 @@ async def create_message(
238311 stopSequences = stop_sequences ,
239312 metadata = metadata ,
240313 modelPreferences = model_preferences ,
314+ tools = tools ,
315+ toolChoice = tool_choice ,
241316 ),
242317 )
243318 ),
0 commit comments