Skip to content

Commit 1fb5497

Browse files
committed
add validation for primitive types
1 parent 349701f commit 1fb5497

File tree

3 files changed

+212
-61
lines changed

3 files changed

+212
-61
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,13 +327,13 @@ mcp = FastMCP("Booking System")
327327
async def book_table(date: str, party_size: int, ctx: Context) -> str:
328328
"""Book a table with confirmation"""
329329

330+
# Schema must only contain primitive types (str, int, float, bool)
330331
class ConfirmBooking(BaseModel):
331332
confirm: bool = Field(description="Confirm booking?")
332333
notes: str = Field(default="", description="Special requests")
333334

334335
result = await ctx.elicit(
335-
message=f"Confirm booking for {party_size} on {date}?",
336-
schema=ConfirmBooking
336+
message=f"Confirm booking for {party_size} on {date}?", schema=ConfirmBooking
337337
)
338338

339339
if result.action == "accept" and result.data:

src/mcp/server/fastmcp/server.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,19 @@
44

55
import inspect
66
import re
7+
import types
78
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
89
from contextlib import (
910
AbstractAsyncContextManager,
1011
asynccontextmanager,
1112
)
1213
from itertools import chain
13-
from typing import Any, Generic, Literal, TypeVar
14+
from typing import Any, Generic, Literal, TypeVar, Union, get_args, get_origin
1415

1516
import anyio
1617
import pydantic_core
1718
from pydantic import BaseModel, Field, ValidationError
19+
from pydantic.fields import FieldInfo
1820
from pydantic.networks import AnyUrl
1921
from pydantic_settings import BaseSettings, SettingsConfigDict
2022
from starlette.applications import Starlette
@@ -70,13 +72,13 @@
7072

7173
class ElicitationResult(BaseModel, Generic[ElicitSchemaModelT]):
7274
"""Result of an elicitation request."""
73-
75+
7476
action: Literal["accept", "decline", "cancel"]
7577
"""The user's action in response to the elicitation."""
76-
78+
7779
data: ElicitSchemaModelT | None = None
7880
"""The validated data if action is 'accept', None otherwise."""
79-
81+
8082
validation_error: str | None = None
8183
"""Validation error message if data failed to validate."""
8284

@@ -873,6 +875,43 @@ def _convert_to_content(
873875
return [TextContent(type="text", text=result)]
874876

875877

878+
def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
879+
"""Validate that a Pydantic model only contains primitive field types."""
880+
for field_name, field_info in schema.model_fields.items():
881+
if not _is_primitive_field(field_info):
882+
raise TypeError(
883+
f"Elicitation schema field '{field_name}' must be a primitive type "
884+
f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. "
885+
f"Complex types like lists, dicts, or nested models are not allowed."
886+
)
887+
888+
889+
# Primitive types allowed in elicitation schemas
890+
_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool)
891+
892+
893+
def _is_primitive_field(field_info: FieldInfo) -> bool:
894+
"""Check if a field is a primitive type allowed in elicitation schemas."""
895+
annotation = field_info.annotation
896+
897+
# Handle None type
898+
if annotation is type(None):
899+
return True
900+
901+
# Handle basic primitive types
902+
if annotation in _ELICITATION_PRIMITIVE_TYPES:
903+
return True
904+
905+
# Handle Union types (including Optional and Python 3.10+ union syntax)
906+
origin = get_origin(annotation)
907+
if origin is Union or (hasattr(types, 'UnionType') and isinstance(annotation, types.UnionType)):
908+
args = get_args(annotation)
909+
# All args must be primitive types or None
910+
return all(arg is type(None) or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args)
911+
912+
return False
913+
914+
876915
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
877916
"""Context object providing access to MCP capabilities.
878917
@@ -996,6 +1035,9 @@ async def elicit(
9961035
The result.data will only be populated if action is "accept" and validation succeeded.
9971036
"""
9981037

1038+
# Validate that schema only contains primitive types and fail loudly if not
1039+
_validate_elicitation_schema(schema)
1040+
9991041
json_schema = schema.model_json_schema()
10001042

10011043
result = await self.request_context.session.elicit(

tests/server/fastmcp/test_elicitation.py

Lines changed: 164 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -10,92 +10,201 @@
1010
from mcp.types import ElicitResult, TextContent
1111

1212

13-
@pytest.mark.anyio
14-
async def test_stdio_elicitation():
15-
"""Test the elicitation feature using stdio transport."""
13+
# Shared schema for basic tests
14+
class AnswerSchema(BaseModel):
15+
answer: str = Field(description="The user's answer to the question")
1616

17-
# Create a FastMCP server with a tool that uses elicitation
18-
mcp = FastMCP(name="StdioElicitationServer")
17+
18+
def create_ask_user_tool(mcp: FastMCP):
19+
"""Create a standard ask_user tool that handles all elicitation responses."""
1920

2021
@mcp.tool(description="A tool that uses elicitation")
2122
async def ask_user(prompt: str, ctx: Context) -> str:
22-
class AnswerSchema(BaseModel):
23-
answer: str = Field(description="The user's answer to the question")
24-
2523
result = await ctx.elicit(
2624
message=f"Tool wants to ask: {prompt}",
2725
schema=AnswerSchema,
2826
)
29-
27+
3028
if result.action == "accept" and result.data:
3129
return f"User answered: {result.data.answer}"
3230
elif result.action == "decline":
3331
return "User declined to answer"
3432
else:
3533
return "User cancelled"
3634

35+
return ask_user
36+
37+
38+
async def call_tool_and_assert(
39+
mcp: FastMCP,
40+
elicitation_callback,
41+
tool_name: str,
42+
args: dict,
43+
expected_text: str | None = None,
44+
text_contains: list[str] | None = None,
45+
):
46+
"""Helper to create session, call tool, and assert result."""
47+
async with create_connected_server_and_client_session(
48+
mcp._mcp_server, elicitation_callback=elicitation_callback
49+
) as client_session:
50+
await client_session.initialize()
51+
52+
result = await client_session.call_tool(tool_name, args)
53+
assert len(result.content) == 1
54+
assert isinstance(result.content[0], TextContent)
55+
56+
if expected_text is not None:
57+
assert result.content[0].text == expected_text
58+
elif text_contains is not None:
59+
for substring in text_contains:
60+
assert substring in result.content[0].text
61+
62+
return result
63+
64+
65+
@pytest.mark.anyio
66+
async def test_stdio_elicitation():
67+
"""Test the elicitation feature using stdio transport."""
68+
mcp = FastMCP(name="StdioElicitationServer")
69+
create_ask_user_tool(mcp)
70+
3771
# Create a custom handler for elicitation requests
3872
async def elicitation_callback(context, params):
39-
# Verify the elicitation parameters
4073
if params.message == "Tool wants to ask: What is your name?":
4174
return ElicitResult(action="accept", content={"answer": "Test User"})
4275
else:
4376
raise ValueError(f"Unexpected elicitation message: {params.message}")
4477

45-
# Use memory-based session to test with stdio transport
46-
async with create_connected_server_and_client_session(
47-
mcp._mcp_server, elicitation_callback=elicitation_callback
48-
) as client_session:
49-
# First initialize the session
50-
result = await client_session.initialize()
51-
assert result.serverInfo.name == "StdioElicitationServer"
52-
53-
# Call the tool that uses elicitation
54-
tool_result = await client_session.call_tool("ask_user", {"prompt": "What is your name?"})
55-
56-
# Verify the result
57-
assert len(tool_result.content) == 1
58-
assert isinstance(tool_result.content[0], TextContent)
59-
assert tool_result.content[0].text == "User answered: Test User"
78+
await call_tool_and_assert(
79+
mcp, elicitation_callback, "ask_user", {"prompt": "What is your name?"}, "User answered: Test User"
80+
)
6081

6182

6283
@pytest.mark.anyio
6384
async def test_stdio_elicitation_decline():
6485
"""Test elicitation with user declining."""
65-
6686
mcp = FastMCP(name="StdioElicitationDeclineServer")
67-
68-
@mcp.tool(description="A tool that uses elicitation")
69-
async def ask_user(prompt: str, ctx: Context) -> str:
70-
class AnswerSchema(BaseModel):
71-
answer: str = Field(description="The user's answer to the question")
72-
73-
result = await ctx.elicit(
74-
message=f"Tool wants to ask: {prompt}",
75-
schema=AnswerSchema,
76-
)
77-
78-
if result.action == "accept" and result.data:
79-
return f"User answered: {result.data.answer}"
80-
elif result.action == "decline":
81-
return "User declined to answer"
82-
else:
83-
return "User cancelled"
84-
85-
# Create a custom handler that declines
87+
create_ask_user_tool(mcp)
88+
8689
async def elicitation_callback(context, params):
8790
return ElicitResult(action="decline")
88-
91+
92+
await call_tool_and_assert(
93+
mcp, elicitation_callback, "ask_user", {"prompt": "What is your name?"}, "User declined to answer"
94+
)
95+
96+
97+
@pytest.mark.anyio
98+
async def test_elicitation_schema_validation():
99+
"""Test that elicitation schemas must only contain primitive types."""
100+
mcp = FastMCP(name="ValidationTestServer")
101+
102+
def create_validation_tool(name: str, schema_class: type[BaseModel]):
103+
@mcp.tool(name=name, description=f"Tool testing {name}")
104+
async def tool(ctx: Context) -> str:
105+
try:
106+
await ctx.elicit(message="This should fail validation", schema=schema_class)
107+
return "Should not reach here"
108+
except TypeError as e:
109+
return f"Validation failed as expected: {str(e)}"
110+
111+
return tool
112+
113+
# Test cases for invalid schemas
114+
class InvalidListSchema(BaseModel):
115+
names: list[str] = Field(description="List of names")
116+
117+
class NestedModel(BaseModel):
118+
value: str
119+
120+
class InvalidNestedSchema(BaseModel):
121+
nested: NestedModel = Field(description="Nested model")
122+
123+
create_validation_tool("invalid_list", InvalidListSchema)
124+
create_validation_tool("nested_model", InvalidNestedSchema)
125+
126+
# Dummy callback (won't be called due to validation failure)
127+
async def elicitation_callback(context, params):
128+
return ElicitResult(action="accept", content={})
129+
89130
async with create_connected_server_and_client_session(
90131
mcp._mcp_server, elicitation_callback=elicitation_callback
91132
) as client_session:
92-
# Initialize
93133
await client_session.initialize()
94-
95-
# Call the tool
96-
tool_result = await client_session.call_tool("ask_user", {"prompt": "What is your name?"})
97-
98-
# Verify the result
99-
assert len(tool_result.content) == 1
100-
assert isinstance(tool_result.content[0], TextContent)
101-
assert tool_result.content[0].text == "User declined to answer"
134+
135+
# Test both invalid schemas
136+
for tool_name, field_name in [("invalid_list", "names"), ("nested_model", "nested")]:
137+
result = await client_session.call_tool(tool_name, {})
138+
assert len(result.content) == 1
139+
assert isinstance(result.content[0], TextContent)
140+
assert "Validation failed as expected" in result.content[0].text
141+
assert field_name in result.content[0].text
142+
143+
144+
@pytest.mark.anyio
145+
async def test_elicitation_with_optional_fields():
146+
"""Test that Optional fields work correctly in elicitation schemas."""
147+
mcp = FastMCP(name="OptionalFieldServer")
148+
149+
class OptionalSchema(BaseModel):
150+
required_name: str = Field(description="Your name (required)")
151+
optional_age: int | None = Field(default=None, description="Your age (optional)")
152+
optional_email: str | None = Field(default=None, description="Your email (optional)")
153+
subscribe: bool | None = Field(default=False, description="Subscribe to newsletter?")
154+
155+
@mcp.tool(description="Tool with optional fields")
156+
async def optional_tool(ctx: Context) -> str:
157+
result = await ctx.elicit(message="Please provide your information", schema=OptionalSchema)
158+
159+
if result.action == "accept" and result.data:
160+
info = [f"Name: {result.data.required_name}"]
161+
if result.data.optional_age is not None:
162+
info.append(f"Age: {result.data.optional_age}")
163+
if result.data.optional_email is not None:
164+
info.append(f"Email: {result.data.optional_email}")
165+
info.append(f"Subscribe: {result.data.subscribe}")
166+
return ", ".join(info)
167+
else:
168+
return f"User {result.action}"
169+
170+
# Test cases with different field combinations
171+
test_cases = [
172+
(
173+
# All fields provided
174+
{"required_name": "John Doe", "optional_age": 30, "optional_email": "[email protected]", "subscribe": True},
175+
"Name: John Doe, Age: 30, Email: [email protected], Subscribe: True",
176+
),
177+
(
178+
# Only required fields
179+
{"required_name": "Jane Smith"},
180+
"Name: Jane Smith, Subscribe: False",
181+
),
182+
]
183+
184+
for content, expected in test_cases:
185+
186+
async def callback(context, params):
187+
return ElicitResult(action="accept", content=content)
188+
189+
await call_tool_and_assert(mcp, callback, "optional_tool", {}, expected)
190+
191+
# Test invalid optional field
192+
class InvalidOptionalSchema(BaseModel):
193+
name: str = Field(description="Name")
194+
optional_list: list[str] | None = Field(default=None, description="Invalid optional list")
195+
196+
@mcp.tool(description="Tool with invalid optional field")
197+
async def invalid_optional_tool(ctx: Context) -> str:
198+
try:
199+
await ctx.elicit(message="This should fail", schema=InvalidOptionalSchema)
200+
return "Should not reach here"
201+
except TypeError as e:
202+
return f"Validation failed: {str(e)}"
203+
204+
await call_tool_and_assert(
205+
mcp,
206+
lambda c, p: ElicitResult(action="accept", content={}),
207+
"invalid_optional_tool",
208+
{},
209+
text_contains=["Validation failed:", "optional_list"],
210+
)

0 commit comments

Comments
 (0)