Skip to content

Add ValidationOptions for lenient output schema validation #1260

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 35 additions & 8 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from jsonschema import SchemaError, ValidationError, validate
from pydantic import AnyUrl, TypeAdapter
from pydantic import AnyUrl, BaseModel, Field, TypeAdapter

import mcp.types as types
from mcp.shared.context import RequestContext
Expand All @@ -18,6 +18,17 @@
logger = logging.getLogger("client")


class ValidationOptions(BaseModel):
"""Options for controlling validation behavior in MCP client sessions."""

strict_output_validation: bool = Field(
default=True,
description="Whether to raise exceptions when tools don't return structured "
"content as specified by their output schema. When False, validation "
"errors are logged as warnings and execution continues.",
)


class SamplingFnT(Protocol):
async def __call__(
self,
Expand Down Expand Up @@ -118,6 +129,7 @@ def __init__(
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
validation_options: ValidationOptions | None = None,
) -> None:
super().__init__(
read_stream,
Expand All @@ -133,6 +145,7 @@ def __init__(
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
self._validation_options = validation_options or ValidationOptions()

async def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
Expand Down Expand Up @@ -324,13 +337,27 @@ async def _validate_tool_result(self, name: str, result: types.CallToolResult) -

if output_schema is not None:
if result.structuredContent is None:
raise RuntimeError(f"Tool {name} has an output schema but did not return structured content")
try:
validate(result.structuredContent, output_schema)
except ValidationError as e:
raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}")
except SchemaError as e:
raise RuntimeError(f"Invalid schema for tool {name}: {e}")
if self._validation_options.strict_output_validation:
raise RuntimeError(f"Tool {name} has an output schema but did not return structured content")
else:
logger.warning(
f"Tool {name} has an output schema but did not return structured content. "
f"Continuing without structured content validation due to lenient validation mode."
)
else:
try:
validate(result.structuredContent, output_schema)
except ValidationError as e:
if self._validation_options.strict_output_validation:
raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}")
else:
logger.warning(
f"Invalid structured content returned by tool {name}: {e}. "
f"Continuing due to lenient validation mode."
)
except SchemaError as e:
# Schema errors are always raised - they indicate a problem with the schema itself
raise RuntimeError(f"Invalid schema for tool {name}: {e}")

async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult:
"""Send a prompts/list request."""
Expand Down
2 changes: 2 additions & 0 deletions src/mcp/shared/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ async def create_connected_server_and_client_session(
client_info: types.Implementation | None = None,
raise_exceptions: bool = False,
elicitation_callback: ElicitationFnT | None = None,
validation_options: Any | None = None,
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
async with create_client_server_memory_streams() as (
Expand Down Expand Up @@ -92,6 +93,7 @@ async def create_connected_server_and_client_session(
message_handler=message_handler,
client_info=client_info,
elicitation_callback=elicitation_callback,
validation_options=validation_options,
) as client_session:
await client_session.initialize()
yield client_session
Expand Down
244 changes: 244 additions & 0 deletions tests/client/test_validation_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
"""Tests for client-side validation options."""

import logging
from unittest.mock import AsyncMock, MagicMock

import pytest

from mcp.client.session import ClientSession, ValidationOptions
from mcp.types import CallToolResult, TextContent


class TestValidationOptions:
"""Test validation options for MCP client sessions."""

@pytest.mark.anyio
async def test_strict_validation_default(self):
"""Test that strict validation is enabled by default."""
# Create a mock client session
read_stream = MagicMock()
write_stream = MagicMock()

client = ClientSession(read_stream, write_stream)

# Set up tool with output schema
client._tool_output_schemas = {
"test_tool": {
"type": "object",
"properties": {"result": {"type": "integer"}},
"required": ["result"],
}
}

# Mock send_request to return a result without structured content
mock_result = CallToolResult(
content=[TextContent(type="text", text="This is unstructured text content")],
structuredContent=None,
isError=False,
)

client.send_request = AsyncMock(return_value=mock_result)

# Should raise by default when structured content is missing
with pytest.raises(RuntimeError) as exc_info:
await client.call_tool("test_tool", {})
assert "has an output schema but did not return structured content" in str(exc_info.value)

@pytest.mark.anyio
async def test_lenient_validation_missing_content(self, caplog):
"""Test lenient validation when structured content is missing."""
# Set logging level to capture warnings
caplog.set_level(logging.WARNING)

# Create client with lenient validation
validation_options = ValidationOptions(strict_output_validation=False)

read_stream = MagicMock()
write_stream = MagicMock()

client = ClientSession(read_stream, write_stream, validation_options=validation_options)

# Set up tool with output schema
client._tool_output_schemas = {
"test_tool": {
"type": "object",
"properties": {"result": {"type": "integer"}},
"required": ["result"],
}
}

# Mock send_request to return a result without structured content
mock_result = CallToolResult(
content=[TextContent(type="text", text="This is unstructured text content")],
structuredContent=None,
isError=False,
)

client.send_request = AsyncMock(return_value=mock_result)

# Should not raise with lenient validation
result = await client.call_tool("test_tool", {})

# Should have logged a warning
assert "has an output schema but did not return structured content" in caplog.text
assert "Continuing without structured content validation" in caplog.text

# Result should still be returned
assert result.isError is False
assert result.structuredContent is None

@pytest.mark.anyio
async def test_lenient_validation_invalid_content(self, caplog):
"""Test lenient validation when structured content is invalid."""
# Set logging level to capture warnings
caplog.set_level(logging.WARNING)

# Create client with lenient validation
validation_options = ValidationOptions(strict_output_validation=False)

read_stream = MagicMock()
write_stream = MagicMock()

client = ClientSession(read_stream, write_stream, validation_options=validation_options)

# Set up tool with output schema
client._tool_output_schemas = {
"test_tool": {
"type": "object",
"properties": {"result": {"type": "integer"}},
"required": ["result"],
}
}

# Mock send_request to return a result with invalid structured content
mock_result = CallToolResult(
content=[TextContent(type="text", text='{"result": "not_an_integer"}')],
structuredContent={"result": "not_an_integer"}, # Invalid: string instead of integer
isError=False,
)

client.send_request = AsyncMock(return_value=mock_result)

# Should not raise with lenient validation
result = await client.call_tool("test_tool", {})

# Should have logged a warning
assert "Invalid structured content returned by tool test_tool" in caplog.text
assert "Continuing due to lenient validation mode" in caplog.text

# Result should still be returned with the invalid content
assert result.isError is False
assert result.structuredContent == {"result": "not_an_integer"}

@pytest.mark.anyio
async def test_strict_validation_with_valid_content(self):
"""Test that valid structured content passes validation."""
read_stream = MagicMock()
write_stream = MagicMock()

client = ClientSession(read_stream, write_stream)

# Set up tool with output schema
client._tool_output_schemas = {
"test_tool": {
"type": "object",
"properties": {"result": {"type": "integer"}},
"required": ["result"],
}
}

# Mock send_request to return a result with valid structured content
mock_result = CallToolResult(
content=[TextContent(type="text", text='{"result": 42}')], structuredContent={"result": 42}, isError=False
)

client.send_request = AsyncMock(return_value=mock_result)

# Should not raise with valid content
result = await client.call_tool("test_tool", {})
assert result.isError is False
assert result.structuredContent == {"result": 42}

@pytest.mark.anyio
async def test_schema_errors_always_raised(self):
"""Test that schema errors are always raised regardless of validation mode."""
# Create client with lenient validation
validation_options = ValidationOptions(strict_output_validation=False)

read_stream = MagicMock()
write_stream = MagicMock()

client = ClientSession(read_stream, write_stream, validation_options=validation_options)

# Set up tool with invalid output schema
client._tool_output_schemas = {
"test_tool": "not a valid schema" # type: ignore # Invalid schema for testing
}

# Mock send_request to return a result with structured content
mock_result = CallToolResult(
content=[TextContent(type="text", text='{"result": 42}')], structuredContent={"result": 42}, isError=False
)

client.send_request = AsyncMock(return_value=mock_result)

# Should still raise for schema errors even in lenient mode
with pytest.raises(RuntimeError) as exc_info:
await client.call_tool("test_tool", {})
assert "Invalid schema for tool test_tool" in str(exc_info.value)

@pytest.mark.anyio
async def test_error_results_not_validated(self):
"""Test that error results are not validated."""
read_stream = MagicMock()
write_stream = MagicMock()

client = ClientSession(read_stream, write_stream)

# Set up tool with output schema
client._tool_output_schemas = {
"test_tool": {
"type": "object",
"properties": {"result": {"type": "integer"}},
"required": ["result"],
}
}

# Mock send_request to return an error result
mock_result = CallToolResult(
content=[TextContent(type="text", text="Tool execution failed")],
structuredContent=None,
isError=True, # Error result
)

client.send_request = AsyncMock(return_value=mock_result)

# Should not validate error results
result = await client.call_tool("test_tool", {})
assert result.isError is True
# No exception should be raised

@pytest.mark.anyio
async def test_tool_without_output_schema(self):
"""Test that tools without output schema don't trigger validation."""
read_stream = MagicMock()
write_stream = MagicMock()

client = ClientSession(read_stream, write_stream)

# Tool has no output schema
client._tool_output_schemas = {"test_tool": None}

# Mock send_request to return a result without structured content
mock_result = CallToolResult(
content=[TextContent(type="text", text="This is unstructured text content")],
structuredContent=None,
isError=False,
)

client.send_request = AsyncMock(return_value=mock_result)

# Should not raise when there's no output schema
result = await client.call_tool("test_tool", {})
assert result.isError is False
assert result.structuredContent is None
Loading