Skip to content

Commit 3943c45

Browse files
committed
validation options
1 parent ef4e167 commit 3943c45

File tree

3 files changed

+282
-8
lines changed

3 files changed

+282
-8
lines changed

src/mcp/client/session.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import anyio.lowlevel
66
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
77
from jsonschema import SchemaError, ValidationError, validate
8-
from pydantic import AnyUrl, TypeAdapter
8+
from pydantic import AnyUrl, BaseModel, Field, TypeAdapter
99

1010
import mcp.types as types
1111
from mcp.shared.context import RequestContext
@@ -18,6 +18,17 @@
1818
logger = logging.getLogger("client")
1919

2020

21+
class ValidationOptions(BaseModel):
22+
"""Options for controlling validation behavior in MCP client sessions."""
23+
24+
strict_output_validation: bool = Field(
25+
default=True,
26+
description="Whether to raise exceptions when tools don't return structured "
27+
"content as specified by their output schema. When False, validation "
28+
"errors are logged as warnings and execution continues.",
29+
)
30+
31+
2132
class SamplingFnT(Protocol):
2233
async def __call__(
2334
self,
@@ -118,6 +129,7 @@ def __init__(
118129
logging_callback: LoggingFnT | None = None,
119130
message_handler: MessageHandlerFnT | None = None,
120131
client_info: types.Implementation | None = None,
132+
validation_options: ValidationOptions | None = None,
121133
) -> None:
122134
super().__init__(
123135
read_stream,
@@ -133,6 +145,7 @@ def __init__(
133145
self._logging_callback = logging_callback or _default_logging_callback
134146
self._message_handler = message_handler or _default_message_handler
135147
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
148+
self._validation_options = validation_options or ValidationOptions()
136149

137150
async def initialize(self) -> types.InitializeResult:
138151
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
@@ -324,13 +337,27 @@ async def _validate_tool_result(self, name: str, result: types.CallToolResult) -
324337

325338
if output_schema is not None:
326339
if result.structuredContent is None:
327-
raise RuntimeError(f"Tool {name} has an output schema but did not return structured content")
328-
try:
329-
validate(result.structuredContent, output_schema)
330-
except ValidationError as e:
331-
raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}")
332-
except SchemaError as e:
333-
raise RuntimeError(f"Invalid schema for tool {name}: {e}")
340+
if self._validation_options.strict_output_validation:
341+
raise RuntimeError(f"Tool {name} has an output schema but did not return structured content")
342+
else:
343+
logger.warning(
344+
f"Tool {name} has an output schema but did not return structured content. "
345+
f"Continuing without structured content validation due to lenient validation mode."
346+
)
347+
else:
348+
try:
349+
validate(result.structuredContent, output_schema)
350+
except ValidationError as e:
351+
if self._validation_options.strict_output_validation:
352+
raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}")
353+
else:
354+
logger.warning(
355+
f"Invalid structured content returned by tool {name}: {e}. "
356+
f"Continuing due to lenient validation mode."
357+
)
358+
except SchemaError as e:
359+
# Schema errors are always raised - they indicate a problem with the schema itself
360+
raise RuntimeError(f"Invalid schema for tool {name}: {e}")
334361

335362
async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult:
336363
"""Send a prompts/list request."""

src/mcp/shared/memory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ async def create_connected_server_and_client_session(
6161
client_info: types.Implementation | None = None,
6262
raise_exceptions: bool = False,
6363
elicitation_callback: ElicitationFnT | None = None,
64+
validation_options: Any | None = None,
6465
) -> AsyncGenerator[ClientSession, None]:
6566
"""Creates a ClientSession that is connected to a running MCP server."""
6667
async with create_client_server_memory_streams() as (
@@ -92,6 +93,7 @@ async def create_connected_server_and_client_session(
9293
message_handler=message_handler,
9394
client_info=client_info,
9495
elicitation_callback=elicitation_callback,
96+
validation_options=validation_options,
9597
) as client_session:
9698
await client_session.initialize()
9799
yield client_session
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
"""Tests for client-side validation options."""
2+
3+
import logging
4+
from unittest.mock import AsyncMock, MagicMock
5+
6+
import pytest
7+
8+
from mcp.client.session import ValidationOptions, ClientSession
9+
from mcp.types import Tool, CallToolResult, TextContent
10+
import mcp.types as types
11+
12+
13+
class TestValidationOptions:
14+
"""Test validation options for MCP client sessions."""
15+
16+
@pytest.mark.anyio
17+
async def test_strict_validation_default(self):
18+
"""Test that strict validation is enabled by default."""
19+
# Create a mock client session
20+
read_stream = MagicMock()
21+
write_stream = MagicMock()
22+
23+
client = ClientSession(read_stream, write_stream)
24+
25+
# Set up tool with output schema
26+
client._tool_output_schemas = {
27+
"test_tool": {
28+
"type": "object",
29+
"properties": {"result": {"type": "integer"}},
30+
"required": ["result"],
31+
}
32+
}
33+
34+
# Mock send_request to return a result without structured content
35+
mock_result = CallToolResult(
36+
content=[TextContent(type="text", text="This is unstructured text content")],
37+
structuredContent=None,
38+
isError=False,
39+
)
40+
41+
client.send_request = AsyncMock(return_value=mock_result)
42+
43+
# Should raise by default when structured content is missing
44+
with pytest.raises(RuntimeError) as exc_info:
45+
await client.call_tool("test_tool", {})
46+
assert "has an output schema but did not return structured content" in str(exc_info.value)
47+
48+
@pytest.mark.anyio
49+
async def test_lenient_validation_missing_content(self, caplog):
50+
"""Test lenient validation when structured content is missing."""
51+
# Set logging level to capture warnings
52+
caplog.set_level(logging.WARNING)
53+
54+
# Create client with lenient validation
55+
validation_options = ValidationOptions(strict_output_validation=False)
56+
57+
read_stream = MagicMock()
58+
write_stream = MagicMock()
59+
60+
client = ClientSession(read_stream, write_stream, validation_options=validation_options)
61+
62+
# Set up tool with output schema
63+
client._tool_output_schemas = {
64+
"test_tool": {
65+
"type": "object",
66+
"properties": {"result": {"type": "integer"}},
67+
"required": ["result"],
68+
}
69+
}
70+
71+
# Mock send_request to return a result without structured content
72+
mock_result = CallToolResult(
73+
content=[TextContent(type="text", text="This is unstructured text content")],
74+
structuredContent=None,
75+
isError=False,
76+
)
77+
78+
client.send_request = AsyncMock(return_value=mock_result)
79+
80+
# Should not raise with lenient validation
81+
result = await client.call_tool("test_tool", {})
82+
83+
# Should have logged a warning
84+
assert "has an output schema but did not return structured content" in caplog.text
85+
assert "Continuing without structured content validation" in caplog.text
86+
87+
# Result should still be returned
88+
assert result.isError is False
89+
assert result.structuredContent is None
90+
91+
@pytest.mark.anyio
92+
async def test_lenient_validation_invalid_content(self, caplog):
93+
"""Test lenient validation when structured content is invalid."""
94+
# Set logging level to capture warnings
95+
caplog.set_level(logging.WARNING)
96+
97+
# Create client with lenient validation
98+
validation_options = ValidationOptions(strict_output_validation=False)
99+
100+
read_stream = MagicMock()
101+
write_stream = MagicMock()
102+
103+
client = ClientSession(read_stream, write_stream, validation_options=validation_options)
104+
105+
# Set up tool with output schema
106+
client._tool_output_schemas = {
107+
"test_tool": {
108+
"type": "object",
109+
"properties": {"result": {"type": "integer"}},
110+
"required": ["result"],
111+
}
112+
}
113+
114+
# Mock send_request to return a result with invalid structured content
115+
mock_result = CallToolResult(
116+
content=[TextContent(type="text", text='{"result": "not_an_integer"}')],
117+
structuredContent={"result": "not_an_integer"}, # Invalid: string instead of integer
118+
isError=False,
119+
)
120+
121+
client.send_request = AsyncMock(return_value=mock_result)
122+
123+
# Should not raise with lenient validation
124+
result = await client.call_tool("test_tool", {})
125+
126+
# Should have logged a warning
127+
assert "Invalid structured content returned by tool test_tool" in caplog.text
128+
assert "Continuing due to lenient validation mode" in caplog.text
129+
130+
# Result should still be returned with the invalid content
131+
assert result.isError is False
132+
assert result.structuredContent == {"result": "not_an_integer"}
133+
134+
@pytest.mark.anyio
135+
async def test_strict_validation_with_valid_content(self):
136+
"""Test that valid structured content passes validation."""
137+
read_stream = MagicMock()
138+
write_stream = MagicMock()
139+
140+
client = ClientSession(read_stream, write_stream)
141+
142+
# Set up tool with output schema
143+
client._tool_output_schemas = {
144+
"test_tool": {
145+
"type": "object",
146+
"properties": {"result": {"type": "integer"}},
147+
"required": ["result"],
148+
}
149+
}
150+
151+
# Mock send_request to return a result with valid structured content
152+
mock_result = CallToolResult(
153+
content=[TextContent(type="text", text='{"result": 42}')], structuredContent={"result": 42}, isError=False
154+
)
155+
156+
client.send_request = AsyncMock(return_value=mock_result)
157+
158+
# Should not raise with valid content
159+
result = await client.call_tool("test_tool", {})
160+
assert result.isError is False
161+
assert result.structuredContent == {"result": 42}
162+
163+
@pytest.mark.anyio
164+
async def test_schema_errors_always_raised(self):
165+
"""Test that schema errors are always raised regardless of validation mode."""
166+
# Create client with lenient validation
167+
validation_options = ValidationOptions(strict_output_validation=False)
168+
169+
read_stream = MagicMock()
170+
write_stream = MagicMock()
171+
172+
client = ClientSession(read_stream, write_stream, validation_options=validation_options)
173+
174+
# Set up tool with invalid output schema
175+
client._tool_output_schemas = {
176+
"test_tool": "not a valid schema" # Invalid schema
177+
}
178+
179+
# Mock send_request to return a result with structured content
180+
mock_result = CallToolResult(
181+
content=[TextContent(type="text", text='{"result": 42}')], structuredContent={"result": 42}, isError=False
182+
)
183+
184+
client.send_request = AsyncMock(return_value=mock_result)
185+
186+
# Should still raise for schema errors even in lenient mode
187+
with pytest.raises(RuntimeError) as exc_info:
188+
await client.call_tool("test_tool", {})
189+
assert "Invalid schema for tool test_tool" in str(exc_info.value)
190+
191+
@pytest.mark.anyio
192+
async def test_error_results_not_validated(self):
193+
"""Test that error results are not validated."""
194+
read_stream = MagicMock()
195+
write_stream = MagicMock()
196+
197+
client = ClientSession(read_stream, write_stream)
198+
199+
# Set up tool with output schema
200+
client._tool_output_schemas = {
201+
"test_tool": {
202+
"type": "object",
203+
"properties": {"result": {"type": "integer"}},
204+
"required": ["result"],
205+
}
206+
}
207+
208+
# Mock send_request to return an error result
209+
mock_result = CallToolResult(
210+
content=[TextContent(type="text", text="Tool execution failed")],
211+
structuredContent=None,
212+
isError=True, # Error result
213+
)
214+
215+
client.send_request = AsyncMock(return_value=mock_result)
216+
217+
# Should not validate error results
218+
result = await client.call_tool("test_tool", {})
219+
assert result.isError is True
220+
# No exception should be raised
221+
222+
@pytest.mark.anyio
223+
async def test_tool_without_output_schema(self):
224+
"""Test that tools without output schema don't trigger validation."""
225+
read_stream = MagicMock()
226+
write_stream = MagicMock()
227+
228+
client = ClientSession(read_stream, write_stream)
229+
230+
# Tool has no output schema
231+
client._tool_output_schemas = {"test_tool": None}
232+
233+
# Mock send_request to return a result without structured content
234+
mock_result = CallToolResult(
235+
content=[TextContent(type="text", text="This is unstructured text content")],
236+
structuredContent=None,
237+
isError=False,
238+
)
239+
240+
client.send_request = AsyncMock(return_value=mock_result)
241+
242+
# Should not raise when there's no output schema
243+
result = await client.call_tool("test_tool", {})
244+
assert result.isError is False
245+
assert result.structuredContent is None

0 commit comments

Comments
 (0)