Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
validate_structured_outputs: bool = True,
) -> None:
super().__init__(
read_stream,
Expand All @@ -133,6 +134,7 @@ def __init__(
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
self._validate_structured_outputs = validate_structured_outputs

async def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
Expand Down Expand Up @@ -324,13 +326,26 @@ async def _validate_tool_result(self, name: str, result: types.CallToolResult) -

if output_schema is not None:
if result.structuredContent is None:
raise RuntimeError(f"Tool {name} has an output schema but did not return structured content")
try:
validate(result.structuredContent, output_schema)
except ValidationError as e:
raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}")
except SchemaError as e:
raise RuntimeError(f"Invalid schema for tool {name}: {e}")
if self._validate_structured_outputs:
Comment on lines 327 to +329
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: getting quite heavily nested here, could at least partially simplify with

if output_schema is None:
    return
    
if result.structuredContent is None:
    # raise / log depending
    
# handle the base case of try / except with raise / log depending

raise RuntimeError(f"Tool {name} has an output schema but did not return structured content")
else:
logger.warning(
f"Tool {name} has an output schema but did not return structured content. "
f"Continuing without structured content validation."
)
else:
try:
validate(result.structuredContent, output_schema)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What error would None produce here, do we need the two branches?

except ValidationError as e:
if self._validate_structured_outputs:
raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}") from e
else:
logger.warning(
f"Invalid structured content returned by tool {name}: {e}. Continuing without validation."
)
except SchemaError as e:
# Schema errors are always raised - they indicate a problem with the schema itself
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove redundant comment

raise RuntimeError(f"Invalid schema for tool {name}: {e}") from e

async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult:
"""Send a prompts/list request."""
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/server/auth/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ async def revoke_token(

def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str:
parsed_uri = urlparse(redirect_uri_base)
query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query) for v in vs]
query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query).items() for v in vs]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

intentional? seems unrelated

for k, v in params.items():
if v is not None:
query_params.append((k, v))
Expand Down
2 changes: 2 additions & 0 deletions src/mcp/shared/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ async def create_connected_server_and_client_session(
client_info: types.Implementation | None = None,
raise_exceptions: bool = False,
elicitation_callback: ElicitationFnT | None = None,
validate_structured_outputs: bool = True,
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
async with create_client_server_memory_streams() as (
Expand Down Expand Up @@ -92,6 +93,7 @@ async def create_connected_server_and_client_session(
message_handler=message_handler,
client_info=client_info,
elicitation_callback=elicitation_callback,
validate_structured_outputs=validate_structured_outputs,
) as client_session:
await client_session.initialize()
yield client_session
Expand Down
240 changes: 240 additions & 0 deletions tests/client/test_validation_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
"""Tests for client-side validation options."""

import logging
from unittest.mock import AsyncMock, MagicMock

import pytest

from mcp.client.session import ClientSession
from mcp.types import CallToolResult, TextContent


class TestValidationOptions:
"""Test validation options for MCP client sessions."""

@pytest.mark.anyio
async def test_strict_validation_default(self) -> None:
"""Test that strict validation is enabled by default."""
# Create a mock client session
read_stream = MagicMock()
write_stream = MagicMock()

client = ClientSession(read_stream, write_stream)

# Set up tool with output schema
client._tool_output_schemas = {
"test_tool": {
"type": "object",
"properties": {"result": {"type": "integer"}},
"required": ["result"],
}
}

# Mock send_request to return a result without structured content
mock_result = CallToolResult(
content=[TextContent(type="text", text="This is unstructured text content")],
structuredContent=None,
isError=False,
)

client.send_request = AsyncMock(return_value=mock_result)

# Should raise by default when structured content is missing
with pytest.raises(RuntimeError) as exc_info:
await client.call_tool("test_tool", {})
assert "has an output schema but did not return structured content" in str(exc_info.value)

@pytest.mark.anyio
async def test_lenient_validation_missing_content(self, caplog: pytest.LogCaptureFixture) -> None:
"""Test lenient validation when structured content is missing."""
# Set logging level to capture warnings
caplog.set_level(logging.WARNING)

# Create client with lenient validation
read_stream = MagicMock()
write_stream = MagicMock()

client = ClientSession(read_stream, write_stream, validate_structured_outputs=False)

# Set up tool with output schema
client._tool_output_schemas = {
"test_tool": {
"type": "object",
"properties": {"result": {"type": "integer"}},
"required": ["result"],
}
}

# Mock send_request to return a result without structured content
mock_result = CallToolResult(
content=[TextContent(type="text", text="This is unstructured text content")],
structuredContent=None,
isError=False,
)

client.send_request = AsyncMock(return_value=mock_result)

# Should not raise with lenient validation
result = await client.call_tool("test_tool", {})

# Should have logged a warning
assert "has an output schema but did not return structured content" in caplog.text
assert "Continuing without structured content validation" in caplog.text

# Result should still be returned
assert result.isError is False
assert result.structuredContent is None

@pytest.mark.anyio
async def test_lenient_validation_invalid_content(self, caplog: pytest.LogCaptureFixture) -> None:
"""Test lenient validation when structured content is invalid."""
# Set logging level to capture warnings
caplog.set_level(logging.WARNING)

# Create client with lenient validation

read_stream = MagicMock()
write_stream = MagicMock()

client = ClientSession(read_stream, write_stream, validate_structured_outputs=False)

# Set up tool with output schema
client._tool_output_schemas = {
"test_tool": {
"type": "object",
"properties": {"result": {"type": "integer"}},
"required": ["result"],
}
}

# Mock send_request to return a result with invalid structured content
mock_result = CallToolResult(
content=[TextContent(type="text", text='{"result": "not_an_integer"}')],
structuredContent={"result": "not_an_integer"}, # Invalid: string instead of integer
isError=False,
)

client.send_request = AsyncMock(return_value=mock_result)

# Should not raise with lenient validation
result = await client.call_tool("test_tool", {})

# Should have logged a warning
assert "Invalid structured content returned by tool test_tool" in caplog.text
assert "Continuing without validation" in caplog.text

# Result should still be returned with the invalid content
assert result.isError is False
assert result.structuredContent == {"result": "not_an_integer"}

@pytest.mark.anyio
async def test_strict_validation_with_valid_content(self) -> None:
"""Test that valid structured content passes validation."""
read_stream = MagicMock()
write_stream = MagicMock()

client = ClientSession(read_stream, write_stream)

# Set up tool with output schema
client._tool_output_schemas = {
"test_tool": {
"type": "object",
"properties": {"result": {"type": "integer"}},
"required": ["result"],
}
}

# Mock send_request to return a result with valid structured content
mock_result = CallToolResult(
content=[TextContent(type="text", text='{"result": 42}')], structuredContent={"result": 42}, isError=False
)

client.send_request = AsyncMock(return_value=mock_result)

# Should not raise with valid content
result = await client.call_tool("test_tool", {})
assert result.isError is False
assert result.structuredContent == {"result": 42}

@pytest.mark.anyio
async def test_schema_errors_always_raised(self) -> None:
"""Test that schema errors are always raised regardless of validation mode."""
# Create client with lenient validation

read_stream = MagicMock()
write_stream = MagicMock()

client = ClientSession(read_stream, write_stream, validate_structured_outputs=False)

# Set up tool with invalid output schema
client._tool_output_schemas = {
"test_tool": "not a valid schema" # type: ignore # Invalid schema for testing
}

# Mock send_request to return a result with structured content
mock_result = CallToolResult(
content=[TextContent(type="text", text='{"result": 42}')], structuredContent={"result": 42}, isError=False
)

client.send_request = AsyncMock(return_value=mock_result)

# Should still raise for schema errors even in lenient mode
with pytest.raises(RuntimeError) as exc_info:
await client.call_tool("test_tool", {})
assert "Invalid schema for tool test_tool" in str(exc_info.value)

@pytest.mark.anyio
async def test_error_results_not_validated(self) -> None:
"""Test that error results are not validated."""
read_stream = MagicMock()
write_stream = MagicMock()

client = ClientSession(read_stream, write_stream)

# Set up tool with output schema
client._tool_output_schemas = {
"test_tool": {
"type": "object",
"properties": {"result": {"type": "integer"}},
"required": ["result"],
}
}

# Mock send_request to return an error result
mock_result = CallToolResult(
content=[TextContent(type="text", text="Tool execution failed")],
structuredContent=None,
isError=True, # Error result
)

client.send_request = AsyncMock(return_value=mock_result)

# Should not validate error results
result = await client.call_tool("test_tool", {})
assert result.isError is True
# No exception should be raised

@pytest.mark.anyio
async def test_tool_without_output_schema(self) -> None:
"""Test that tools without output schema don't trigger validation."""
read_stream = MagicMock()
write_stream = MagicMock()

client = ClientSession(read_stream, write_stream)

# Tool has no output schema
client._tool_output_schemas = {"test_tool": None}

# Mock send_request to return a result without structured content
mock_result = CallToolResult(
content=[TextContent(type="text", text="This is unstructured text content")],
structuredContent=None,
isError=False,
)

client.send_request = AsyncMock(return_value=mock_result)

# Should not raise when there's no output schema
result = await client.call_tool("test_tool", {})
assert result.isError is False
assert result.structuredContent is None
Loading