Skip to content

Commit 420ea0a

Browse files
committed
update check_client_capability
1 parent 90993b3 commit 420ea0a

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

src/mcp/server/session.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
120120
if capability.sampling is not None:
121121
if client_caps.sampling is None:
122122
return False
123+
if capability.sampling.context is not None:
124+
if client_caps.sampling.context is None:
125+
return False
126+
if capability.sampling.tools is not None:
127+
if client_caps.sampling.tools is None:
128+
return False
123129

124130
if capability.elicitation is not None:
125131
if client_caps.elicitation is None:
@@ -234,6 +240,7 @@ async def create_message(
234240
max_tokens: Maximum number of tokens to generate.
235241
system_prompt: Optional system prompt.
236242
include_context: Optional context inclusion setting.
243+
Requires client to have sampling.context capability.
237244
temperature: Optional sampling temperature.
238245
stop_sequences: Optional stop sequences.
239246
metadata: Optional metadata to pass through to the LLM provider.
@@ -251,6 +258,20 @@ async def create_message(
251258
McpError: If tool_use or tool_result blocks are misused when tools are provided.
252259
"""
253260

261+
if tools is not None or tool_choice is not None:
262+
has_tools_cap = self.check_client_capability( \
263+
types.ClientCapabilities(sampling=types.SamplingCapability(tools=types.SamplingToolsCapability()))
264+
)
265+
if not has_tools_cap:
266+
from mcp.shared.exceptions import McpError
267+
268+
raise McpError(
269+
types.ErrorData(
270+
code=types.INVALID_PARAMS,
271+
message="Client does not support sampling tools capability",
272+
)
273+
)
274+
254275
if messages and tools:
255276
last_content = messages[-1].content_as_list
256277
has_tool_results = any(c.type == "tool_result" for c in last_content)

tests/server/test_session.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,15 @@ async def test_create_message_tool_result_validation():
309309
capabilities=ServerCapabilities(),
310310
),
311311
) as session:
312+
# Set up client params with sampling.tools capability for the test
313+
session._client_params = types.InitializeRequestParams(
314+
protocolVersion=types.LATEST_PROTOCOL_VERSION,
315+
capabilities=types.ClientCapabilities(
316+
sampling=types.SamplingCapability(tools=types.SamplingToolsCapability())
317+
),
318+
clientInfo=types.Implementation(name="test", version="1.0"),
319+
)
320+
312321
tool = types.Tool(name="test_tool", inputSchema={"type": "object"})
313322
text = types.TextContent(type="text", text="hello")
314323
tool_use = types.ToolUseContent(type="tool_use", id="call_1", name="test_tool", input={})

0 commit comments

Comments
 (0)