Skip to content

Commit ee04b6d

Browse files
committed
apply review suggestions
1 parent 4103c93 commit ee04b6d

File tree

2 files changed

+32
-9
lines changed

2 files changed

+32
-9
lines changed

src/mcp/server/session.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
4747

4848
import mcp.types as types
4949
from mcp.server.models import InitializationOptions
50+
from mcp.shared.exceptions import McpError
5051
from mcp.shared.message import ServerMessageMetadata, SessionMessage
5152
from mcp.shared.session import (
5253
BaseSession,
@@ -264,27 +265,34 @@ async def create_message(
264265
types.ClientCapabilities(sampling=types.SamplingCapability(tools=types.SamplingToolsCapability()))
265266
)
266267
if not has_tools_cap:
267-
from mcp.shared.exceptions import McpError
268-
269268
raise McpError(
270269
types.ErrorData(
271270
code=types.INVALID_PARAMS,
272271
message="Client does not support sampling tools capability",
273272
)
274273
)
275274

276-
if messages and tools:
275+
# Validate tool_use/tool_result message structure per SEP-1577:
276+
# https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1577
277+
# This validation runs regardless of whether `tools` is in this request,
278+
# since a tool loop continuation may omit `tools` while still containing
279+
# tool_result content that must match previous tool_use.
280+
if messages:
277281
last_content = messages[-1].content_as_list
278282
has_tool_results = any(c.type == "tool_result" for c in last_content)
279283

280284
previous_content = messages[-2].content_as_list if len(messages) >= 2 else None
281285
has_previous_tool_use = previous_content and any(c.type == "tool_use" for c in previous_content)
282286

283287
if has_tool_results:
288+
# Per spec: "SamplingMessage with tool result content blocks
289+
# MUST NOT contain other content types."
284290
if any(c.type != "tool_result" for c in last_content):
285291
raise ValueError("The last message must contain only tool_result content if any is present")
292+
if previous_content is None:
293+
raise ValueError("tool_result requires a previous message containing tool_use")
286294
if not has_previous_tool_use:
287-
raise ValueError("tool_result blocks are not matching any tool_use from the previous message")
295+
raise ValueError("tool_result blocks do not match any tool_use in the previous message")
288296
if has_previous_tool_use and previous_content:
289297
tool_use_ids = {c.id for c in previous_content if c.type == "tool_use"}
290298
tool_result_ids = {c.toolUseId for c in last_content if c.type == "tool_result"}

tests/server/test_session.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -337,15 +337,15 @@ async def test_create_message_tool_result_validation():
337337
)
338338

339339
# Case 2: tool_result without previous message
340-
with pytest.raises(ValueError, match="not matching any tool_use"):
340+
with pytest.raises(ValueError, match="requires a previous message"):
341341
await session.create_message(
342342
messages=[types.SamplingMessage(role="user", content=tool_result)],
343343
max_tokens=100,
344344
tools=[tool],
345345
)
346346

347347
# Case 3: tool_result without previous tool_use
348-
with pytest.raises(ValueError, match="not matching any tool_use"):
348+
with pytest.raises(ValueError, match="do not match any tool_use"):
349349
await session.create_message(
350350
messages=[
351351
types.SamplingMessage(role="user", content=text),
@@ -371,8 +371,10 @@ async def test_create_message_tool_result_validation():
371371
)
372372

373373
# Case 5: text-only message with tools (no tool_results) - passes validation
374-
# This covers branch 261->266 (has_tool_results=False) and 266->272
375-
# We use move_on_after since send_request will block waiting for response
374+
# Covers has_tool_results=False branch.
375+
# We use move_on_after because validation happens synchronously before
376+
# send_request, which would block indefinitely waiting for a response.
377+
# The timeout lets validation pass, then cancels the blocked send.
376378
with anyio.move_on_after(0.01):
377379
await session.create_message(
378380
messages=[types.SamplingMessage(role="user", content=text)],
@@ -381,7 +383,8 @@ async def test_create_message_tool_result_validation():
381383
)
382384

383385
# Case 6: valid matching tool_result/tool_use IDs - passes validation
384-
# This covers branch 269->272 (IDs match, no error raised)
386+
# Covers tool_use_ids == tool_result_ids branch.
387+
# (see Case 5 comment for move_on_after explanation)
385388
with anyio.move_on_after(0.01):
386389
await session.create_message(
387390
messages=[
@@ -393,6 +396,18 @@ async def test_create_message_tool_result_validation():
393396
tools=[tool],
394397
)
395398

399+
# Case 7: validation runs even without `tools` parameter
400+
# (tool loop continuation may omit tools while containing tool_result)
401+
with pytest.raises(ValueError, match="do not match any tool_use"):
402+
await session.create_message(
403+
messages=[
404+
types.SamplingMessage(role="user", content=text),
405+
types.SamplingMessage(role="user", content=tool_result),
406+
],
407+
max_tokens=100,
408+
# Note: no tools parameter
409+
)
410+
396411

397412
@pytest.mark.anyio
398413
async def test_create_message_without_tools_capability():

0 commit comments

Comments
 (0)