diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py
index 1853ce7c1..26f06261b 100644
--- a/src/mcp/client/session.py
+++ b/src/mcp/client/session.py
@@ -5,7 +5,7 @@
 import anyio.lowlevel
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 from jsonschema import SchemaError, ValidationError, validate
-from pydantic import AnyUrl, TypeAdapter
+from pydantic import AnyUrl, BaseModel, Field, TypeAdapter
 
 import mcp.types as types
 from mcp.shared.context import RequestContext
@@ -18,6 +18,17 @@
 logger = logging.getLogger("client")
 
 
+class ValidationOptions(BaseModel):
+    """Options for controlling validation behavior in MCP client sessions."""
+
+    strict_output_validation: bool = Field(
+        default=True,
+        description="Whether to raise exceptions when tools don't return structured "
+        "content as specified by their output schema. When False, validation "
+        "errors are logged as warnings and execution continues.",
+    )
+
+
 class SamplingFnT(Protocol):
     async def __call__(
         self,
@@ -118,6 +129,7 @@ def __init__(
         logging_callback: LoggingFnT | None = None,
         message_handler: MessageHandlerFnT | None = None,
         client_info: types.Implementation | None = None,
+        validation_options: ValidationOptions | None = None,
     ) -> None:
         super().__init__(
             read_stream,
@@ -133,6 +145,7 @@
         self._logging_callback = logging_callback or _default_logging_callback
         self._message_handler = message_handler or _default_message_handler
         self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
+        self._validation_options = validation_options or ValidationOptions()
 
     async def initialize(self) -> types.InitializeResult:
         sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
@@ -324,13 +337,27 @@ async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None:
 
         if output_schema is not None:
             if result.structuredContent is None:
-                raise RuntimeError(f"Tool {name} has an output schema but did not return structured content")
-            try:
-                validate(result.structuredContent, output_schema)
-            except ValidationError as e:
-                raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}")
-            except SchemaError as e:
-                raise RuntimeError(f"Invalid schema for tool {name}: {e}")
+                if self._validation_options.strict_output_validation:
+                    raise RuntimeError(f"Tool {name} has an output schema but did not return structured content")
+                else:
+                    logger.warning(
+                        f"Tool {name} has an output schema but did not return structured content. "
+                        f"Continuing without structured content validation due to lenient validation mode."
+                    )
+            else:
+                try:
+                    validate(result.structuredContent, output_schema)
+                except ValidationError as e:
+                    if self._validation_options.strict_output_validation:
+                        raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}")
+                    else:
+                        logger.warning(
+                            f"Invalid structured content returned by tool {name}: {e}. "
+                            f"Continuing due to lenient validation mode."
+                        )
+                except SchemaError as e:
+                    # Schema errors are always raised - they indicate a problem with the schema itself
+                    raise RuntimeError(f"Invalid schema for tool {name}: {e}")
 
     async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult:
         """Send a prompts/list request."""
diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py
index c94e5e6ac..a89aad500 100644
--- a/src/mcp/shared/memory.py
+++ b/src/mcp/shared/memory.py
@@ -61,6 +61,7 @@
     client_info: types.Implementation | None = None,
     raise_exceptions: bool = False,
     elicitation_callback: ElicitationFnT | None = None,
+    validation_options: Any | None = None,
 ) -> AsyncGenerator[ClientSession, None]:
     """Creates a ClientSession that is connected to a running MCP server."""
     async with create_client_server_memory_streams() as (
@@ -92,6 +93,7 @@
             message_handler=message_handler,
             client_info=client_info,
             elicitation_callback=elicitation_callback,
+            validation_options=validation_options,
         ) as client_session:
             await client_session.initialize()
             yield client_session
diff --git a/tests/client/test_validation_options.py b/tests/client/test_validation_options.py
new file mode 100644
index 000000000..a80fd69ae
--- /dev/null
+++ b/tests/client/test_validation_options.py
@@ -0,0 +1,244 @@
+"""Tests for client-side validation options."""
+
+import logging
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from mcp.client.session import ClientSession, ValidationOptions
+from mcp.types import CallToolResult, TextContent
+
+
+class TestValidationOptions:
+    """Test validation options for MCP client sessions."""
+
+    @pytest.mark.anyio
+    async def test_strict_validation_default(self):
+        """Test that strict validation is enabled by default."""
+        # Create a mock client session
+        read_stream = MagicMock()
+        write_stream = MagicMock()
+
+        client = ClientSession(read_stream, write_stream)
+
+        # Set up tool with output schema
+        client._tool_output_schemas = {
+            "test_tool": {
+                "type": "object",
+                "properties": {"result": {"type": "integer"}},
+                "required": ["result"],
+            }
+        }
+
+        # Mock send_request to return a result without structured content
+        mock_result = CallToolResult(
+            content=[TextContent(type="text", text="This is unstructured text content")],
+            structuredContent=None,
+            isError=False,
+        )
+
+        client.send_request = AsyncMock(return_value=mock_result)
+
+        # Should raise by default when structured content is missing
+        with pytest.raises(RuntimeError) as exc_info:
+            await client.call_tool("test_tool", {})
+        assert "has an output schema but did not return structured content" in str(exc_info.value)
+
+    @pytest.mark.anyio
+    async def test_lenient_validation_missing_content(self, caplog):
+        """Test lenient validation when structured content is missing."""
+        # Set logging level to capture warnings
+        caplog.set_level(logging.WARNING)
+
+        # Create client with lenient validation
+        validation_options = ValidationOptions(strict_output_validation=False)
+
+        read_stream = MagicMock()
+        write_stream = MagicMock()
+
+        client = ClientSession(read_stream, write_stream, validation_options=validation_options)
+
+        # Set up tool with output schema
+        client._tool_output_schemas = {
+            "test_tool": {
+                "type": "object",
+                "properties": {"result": {"type": "integer"}},
+                "required": ["result"],
+            }
+        }
+
+        # Mock send_request to return a result without structured content
+        mock_result = CallToolResult(
+            content=[TextContent(type="text", text="This is unstructured text content")],
+            structuredContent=None,
+            isError=False,
+        )
+
+        client.send_request = AsyncMock(return_value=mock_result)
+
+        # Should not raise with lenient validation
+        result = await client.call_tool("test_tool", {})
+
+        # Should have logged a warning
+        assert "has an output schema but did not return structured content" in caplog.text
+        assert "Continuing without structured content validation" in caplog.text
+
+        # Result should still be returned
+        assert result.isError is False
+        assert result.structuredContent is None
+
+    @pytest.mark.anyio
+    async def test_lenient_validation_invalid_content(self, caplog):
+        """Test lenient validation when structured content is invalid."""
+        # Set logging level to capture warnings
+        caplog.set_level(logging.WARNING)
+
+        # Create client with lenient validation
+        validation_options = ValidationOptions(strict_output_validation=False)
+
+        read_stream = MagicMock()
+        write_stream = MagicMock()
+
+        client = ClientSession(read_stream, write_stream, validation_options=validation_options)
+
+        # Set up tool with output schema
+        client._tool_output_schemas = {
+            "test_tool": {
+                "type": "object",
+                "properties": {"result": {"type": "integer"}},
+                "required": ["result"],
+            }
+        }
+
+        # Mock send_request to return a result with invalid structured content
+        mock_result = CallToolResult(
+            content=[TextContent(type="text", text='{"result": "not_an_integer"}')],
+            structuredContent={"result": "not_an_integer"},  # Invalid: string instead of integer
+            isError=False,
+        )
+
+        client.send_request = AsyncMock(return_value=mock_result)
+
+        # Should not raise with lenient validation
+        result = await client.call_tool("test_tool", {})
+
+        # Should have logged a warning
+        assert "Invalid structured content returned by tool test_tool" in caplog.text
+        assert "Continuing due to lenient validation mode" in caplog.text
+
+        # Result should still be returned with the invalid content
+        assert result.isError is False
+        assert result.structuredContent == {"result": "not_an_integer"}
+
+    @pytest.mark.anyio
+    async def test_strict_validation_with_valid_content(self):
+        """Test that valid structured content passes validation."""
+        read_stream = MagicMock()
+        write_stream = MagicMock()
+
+        client = ClientSession(read_stream, write_stream)
+
+        # Set up tool with output schema
+        client._tool_output_schemas = {
+            "test_tool": {
+                "type": "object",
+                "properties": {"result": {"type": "integer"}},
+                "required": ["result"],
+            }
+        }
+
+        # Mock send_request to return a result with valid structured content
+        mock_result = CallToolResult(
+            content=[TextContent(type="text", text='{"result": 42}')], structuredContent={"result": 42}, isError=False
+        )
+
+        client.send_request = AsyncMock(return_value=mock_result)
+
+        # Should not raise with valid content
+        result = await client.call_tool("test_tool", {})
+        assert result.isError is False
+        assert result.structuredContent == {"result": 42}
+
+    @pytest.mark.anyio
+    async def test_schema_errors_always_raised(self):
+        """Test that schema errors are always raised regardless of validation mode."""
+        # Create client with lenient validation
+        validation_options = ValidationOptions(strict_output_validation=False)
+
+        read_stream = MagicMock()
+        write_stream = MagicMock()
+
+        client = ClientSession(read_stream, write_stream, validation_options=validation_options)
+
+        # Set up tool with invalid output schema
+        client._tool_output_schemas = {
+            "test_tool": "not a valid schema"  # type: ignore # Invalid schema for testing
+        }
+
+        # Mock send_request to return a result with structured content
+        mock_result = CallToolResult(
+            content=[TextContent(type="text", text='{"result": 42}')], structuredContent={"result": 42}, isError=False
+        )
+
+        client.send_request = AsyncMock(return_value=mock_result)
+
+        # Should still raise for schema errors even in lenient mode
+        with pytest.raises(RuntimeError) as exc_info:
+            await client.call_tool("test_tool", {})
+        assert "Invalid schema for tool test_tool" in str(exc_info.value)
+
+    @pytest.mark.anyio
+    async def test_error_results_not_validated(self):
+        """Test that error results are not validated."""
+        read_stream = MagicMock()
+        write_stream = MagicMock()
+
+        client = ClientSession(read_stream, write_stream)
+
+        # Set up tool with output schema
+        client._tool_output_schemas = {
+            "test_tool": {
+                "type": "object",
+                "properties": {"result": {"type": "integer"}},
+                "required": ["result"],
+            }
+        }
+
+        # Mock send_request to return an error result
+        mock_result = CallToolResult(
+            content=[TextContent(type="text", text="Tool execution failed")],
+            structuredContent=None,
+            isError=True,  # Error result
+        )
+
+        client.send_request = AsyncMock(return_value=mock_result)
+
+        # Should not validate error results
+        result = await client.call_tool("test_tool", {})
+        assert result.isError is True
+        # No exception should be raised
+
+    @pytest.mark.anyio
+    async def test_tool_without_output_schema(self):
+        """Test that tools without output schema don't trigger validation."""
+        read_stream = MagicMock()
+        write_stream = MagicMock()
+
+        client = ClientSession(read_stream, write_stream)
+
+        # Tool has no output schema
+        client._tool_output_schemas = {"test_tool": None}
+
+        # Mock send_request to return a result without structured content
+        mock_result = CallToolResult(
+            content=[TextContent(type="text", text="This is unstructured text content")],
+            structuredContent=None,
+            isError=False,
+        )
+
+        client.send_request = AsyncMock(return_value=mock_result)
+
+        # Should not raise when there's no output schema
+        result = await client.call_tool("test_tool", {})
+        assert result.isError is False
+        assert result.structuredContent is None