@@ -120,6 +120,12 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
120120 if capability .sampling is not None :
121121 if client_caps .sampling is None :
122122 return False
123+ if capability .sampling .context is not None :
124+ if client_caps .sampling .context is None :
125+ return False
126+ if capability .sampling .tools is not None :
127+ if client_caps .sampling .tools is None :
128+ return False
123129
124130 if capability .elicitation is not None :
125131 if client_caps .elicitation is None :
@@ -234,6 +240,7 @@ async def create_message(
234240 max_tokens: Maximum number of tokens to generate.
235241 system_prompt: Optional system prompt.
236242 include_context: Optional context inclusion setting.
243+ Requires client to have sampling.context capability.
237244 temperature: Optional sampling temperature.
238245 stop_sequences: Optional stop sequences.
239246 metadata: Optional metadata to pass through to the LLM provider.
@@ -251,6 +258,20 @@ async def create_message(
251258 McpError: If tool_use or tool_result blocks are misused when tools are provided.
252259 """
253260
261+ if tools is not None or tool_choice is not None :
262+ has_tools_cap = self .check_client_capability ( \
263+ types .ClientCapabilities (sampling = types .SamplingCapability (tools = types .SamplingToolsCapability ()))
264+ )
265+ if not has_tools_cap :
266+ from mcp .shared .exceptions import McpError
267+
268+ raise McpError (
269+ types .ErrorData (
270+ code = types .INVALID_PARAMS ,
271+ message = "Client does not support sampling tools capability" ,
272+ )
273+ )
274+
254275 if messages and tools :
255276 last_content = messages [- 1 ].content_as_list
256277 has_tool_results = any (c .type == "tool_result" for c in last_content )
0 commit comments