Skip to content

Commit 4d77773

Browse files
authored
Merge branch 'main' into client_notif_support
2 parents e851e47 + b19fa6f commit 4d77773

File tree

4 files changed

+222
-16
lines changed

4 files changed

+222
-16
lines changed

examples/servers/everything-server/mcp_everything_server/server.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,65 @@ async def test_elicitation_sep1034_defaults(ctx: Context[ServerSession, None]) -
198198
return f"Elicitation not supported or error: {str(e)}"
199199

200200

201+
class EnumSchemasTestSchema(BaseModel):
202+
"""Schema for testing enum schema variations (SEP-1330)"""
203+
204+
untitledSingle: str = Field(
205+
description="Simple enum without titles", json_schema_extra={"enum": ["active", "inactive", "pending"]}
206+
)
207+
titledSingle: str = Field(
208+
description="Enum with titled options (oneOf)",
209+
json_schema_extra={
210+
"oneOf": [
211+
{"const": "low", "title": "Low Priority"},
212+
{"const": "medium", "title": "Medium Priority"},
213+
{"const": "high", "title": "High Priority"},
214+
]
215+
},
216+
)
217+
untitledMulti: list[str] = Field(
218+
description="Multi-select without titles",
219+
json_schema_extra={"items": {"type": "string", "enum": ["read", "write", "execute"]}},
220+
)
221+
titledMulti: list[str] = Field(
222+
description="Multi-select with titled options",
223+
json_schema_extra={
224+
"items": {
225+
"anyOf": [
226+
{"const": "feature", "title": "New Feature"},
227+
{"const": "bug", "title": "Bug Fix"},
228+
{"const": "docs", "title": "Documentation"},
229+
]
230+
}
231+
},
232+
)
233+
legacyEnum: str = Field(
234+
description="Legacy enum with enumNames",
235+
json_schema_extra={
236+
"enum": ["small", "medium", "large"],
237+
"enumNames": ["Small Size", "Medium Size", "Large Size"],
238+
},
239+
)
240+
241+
242+
@mcp.tool()
243+
async def test_elicitation_sep1330_enums(ctx: Context[ServerSession, None]) -> str:
244+
"""Tests elicitation with enum schema variations per SEP-1330"""
245+
try:
246+
result = await ctx.elicit(
247+
message="Please select values using different enum schema types", schema=EnumSchemasTestSchema
248+
)
249+
250+
if result.action == "accept":
251+
content = result.data.model_dump_json()
252+
else:
253+
content = "{}"
254+
255+
return f"Elicitation completed: action={result.action}, content={content}"
256+
except Exception as e:
257+
return f"Elicitation not supported or error: {str(e)}"
258+
259+
201260
@mcp.tool()
202261
def test_error_handling() -> str:
203262
"""Tests error response handling"""

src/mcp/server/elicitation.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from __future__ import annotations
44

55
import types
6+
from collections.abc import Sequence
67
from typing import Generic, Literal, TypeVar, Union, get_args, get_origin
78

89
from pydantic import BaseModel
9-
from pydantic.fields import FieldInfo
1010

1111
from mcp.server.session import ServerSession
1212
from mcp.types import RequestId
@@ -43,22 +43,40 @@ class CancelledElicitation(BaseModel):
4343
def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
4444
"""Validate that a Pydantic model only contains primitive field types."""
4545
for field_name, field_info in schema.model_fields.items():
46-
if not _is_primitive_field(field_info):
46+
annotation = field_info.annotation
47+
48+
if annotation is None or annotation is types.NoneType: # pragma: no cover
49+
continue
50+
elif _is_primitive_field(annotation):
51+
continue
52+
elif _is_string_sequence(annotation):
53+
continue
54+
else:
4755
raise TypeError(
4856
f"Elicitation schema field '{field_name}' must be a primitive type "
49-
f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. "
50-
f"Complex types like lists, dicts, or nested models are not allowed."
57+
f"{_ELICITATION_PRIMITIVE_TYPES}, a sequence of strings (list[str], etc.), "
58+
f"or Optional of these types. Nested models and complex types are not allowed."
5159
)
5260

5361

54-
def _is_primitive_field(field_info: FieldInfo) -> bool:
55-
"""Check if a field is a primitive type allowed in elicitation schemas."""
56-
annotation = field_info.annotation
62+
def _is_string_sequence(annotation: type) -> bool:
63+
"""Check if annotation is a sequence of strings (list[str], Sequence[str], etc)."""
64+
origin = get_origin(annotation)
65+
# Check if it's a sequence-like type with str elements
66+
if origin:
67+
try:
68+
if issubclass(origin, Sequence):
69+
args = get_args(annotation)
70+
# Should have single str type arg
71+
return len(args) == 1 and args[0] is str
72+
except TypeError: # pragma: no cover
73+
# origin is not a class, so it can't be a subclass of Sequence
74+
pass
75+
return False
5776

58-
# Handle None type
59-
if annotation is types.NoneType: # pragma: no cover
60-
return True
6177

78+
def _is_primitive_field(annotation: type) -> bool:
79+
"""Check if a field is a primitive type allowed in elicitation schemas."""
6280
# Handle basic primitive types
6381
if annotation in _ELICITATION_PRIMITIVE_TYPES:
6482
return True
@@ -67,8 +85,10 @@ def _is_primitive_field(field_info: FieldInfo) -> bool:
6785
origin = get_origin(annotation)
6886
if origin is Union or origin is types.UnionType:
6987
args = get_args(annotation)
70-
# All args must be primitive types or None
71-
return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args)
88+
# All args must be primitive types, None, or string sequences
89+
return all(
90+
arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES or _is_string_sequence(arg) for arg in args
91+
)
7292

7393
return False
7494

src/mcp/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1468,7 +1468,7 @@ class ElicitResult(Result):
14681468
- "cancel": User dismissed without making an explicit choice
14691469
"""
14701470

1471-
content: dict[str, str | int | float | bool | None] | None = None
1471+
content: dict[str, str | int | float | bool | list[str] | None] | None = None
14721472
"""
14731473
The submitted form data, only present when action is "accept".
14741474
Contains values matching the requested schema.

tests/server/fastmcp/test_elicitation.py

Lines changed: 130 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ async def tool(ctx: Context[ServerSession, None]) -> str: # pragma: no cover
116116

117117
# Test cases for invalid schemas
118118
class InvalidListSchema(BaseModel):
119-
names: list[str] = Field(description="List of names")
119+
numbers: list[int] = Field(description="List of numbers")
120120

121121
class NestedModel(BaseModel):
122122
value: str
@@ -139,7 +139,7 @@ async def elicitation_callback(
139139
await client_session.initialize()
140140

141141
# Test both invalid schemas
142-
for tool_name, field_name in [("invalid_list", "names"), ("nested_model", "nested")]:
142+
for tool_name, field_name in [("invalid_list", "numbers"), ("nested_model", "nested")]:
143143
result = await client_session.call_tool(tool_name, {})
144144
assert len(result.content) == 1
145145
assert isinstance(result.content[0], TextContent)
@@ -197,7 +197,7 @@ async def callback(context: RequestContext[ClientSession, None], params: ElicitR
197197
# Test invalid optional field
198198
class InvalidOptionalSchema(BaseModel):
199199
name: str = Field(description="Name")
200-
optional_list: list[str] | None = Field(default=None, description="Invalid optional list")
200+
optional_list: list[int] | None = Field(default=None, description="Invalid optional list")
201201

202202
@mcp.tool(description="Tool with invalid optional field")
203203
async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str: # pragma: no cover
@@ -220,6 +220,47 @@ async def elicitation_callback(
220220
text_contains=["Validation failed:", "optional_list"],
221221
)
222222

223+
# Test valid list[str] for multi-select enum
224+
class ValidMultiSelectSchema(BaseModel):
225+
name: str = Field(description="Name")
226+
tags: list[str] = Field(description="Tags")
227+
228+
@mcp.tool(description="Tool with valid list[str] field")
229+
async def valid_multiselect_tool(ctx: Context[ServerSession, None]) -> str:
230+
result = await ctx.elicit(message="Please provide tags", schema=ValidMultiSelectSchema)
231+
if result.action == "accept" and result.data:
232+
return f"Name: {result.data.name}, Tags: {', '.join(result.data.tags)}"
233+
return f"User {result.action}" # pragma: no cover
234+
235+
async def multiselect_callback(context: RequestContext[ClientSession, Any], params: ElicitRequestParams):
236+
if "Please provide tags" in params.message:
237+
return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]})
238+
return ElicitResult(action="decline") # pragma: no cover
239+
240+
await call_tool_and_assert(mcp, multiselect_callback, "valid_multiselect_tool", {}, "Name: Test, Tags: tag1, tag2")
241+
242+
# Test Optional[list[str]] for optional multi-select enum
243+
class OptionalMultiSelectSchema(BaseModel):
244+
name: str = Field(description="Name")
245+
tags: list[str] | None = Field(default=None, description="Optional tags")
246+
247+
@mcp.tool(description="Tool with optional list[str] field")
248+
async def optional_multiselect_tool(ctx: Context[ServerSession, None]) -> str:
249+
result = await ctx.elicit(message="Please provide optional tags", schema=OptionalMultiSelectSchema)
250+
if result.action == "accept" and result.data:
251+
tags_str = ", ".join(result.data.tags) if result.data.tags else "none"
252+
return f"Name: {result.data.name}, Tags: {tags_str}"
253+
return f"User {result.action}" # pragma: no cover
254+
255+
async def optional_multiselect_callback(context: RequestContext[ClientSession, Any], params: ElicitRequestParams):
256+
if "Please provide optional tags" in params.message:
257+
return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]})
258+
return ElicitResult(action="decline") # pragma: no cover
259+
260+
await call_tool_and_assert(
261+
mcp, optional_multiselect_callback, "optional_multiselect_tool", {}, "Name: Test, Tags: tag1, tag2"
262+
)
263+
223264

224265
@pytest.mark.anyio
225266
async def test_elicitation_with_default_values():
@@ -274,3 +315,89 @@ async def callback_override(context: RequestContext[ClientSession, None], params
274315
await call_tool_and_assert(
275316
mcp, callback_override, "defaults_tool", {}, "Name: John, Age: 25, Subscribe: False, Email: [email protected]"
276317
)
318+
319+
320+
@pytest.mark.anyio
321+
async def test_elicitation_with_enum_titles():
322+
"""Test elicitation with enum schemas using oneOf/anyOf for titles."""
323+
mcp = FastMCP(name="ColorPreferencesApp")
324+
325+
# Test single-select with titles using oneOf
326+
class FavoriteColorSchema(BaseModel):
327+
user_name: str = Field(description="Your name")
328+
favorite_color: str = Field(
329+
description="Select your favorite color",
330+
json_schema_extra={
331+
"oneOf": [
332+
{"const": "red", "title": "Red"},
333+
{"const": "green", "title": "Green"},
334+
{"const": "blue", "title": "Blue"},
335+
{"const": "yellow", "title": "Yellow"},
336+
]
337+
},
338+
)
339+
340+
@mcp.tool(description="Single color selection")
341+
async def select_favorite_color(ctx: Context[ServerSession, None]) -> str:
342+
result = await ctx.elicit(message="Select your favorite color", schema=FavoriteColorSchema)
343+
if result.action == "accept" and result.data:
344+
return f"User: {result.data.user_name}, Favorite: {result.data.favorite_color}"
345+
return f"User {result.action}" # pragma: no cover
346+
347+
# Test multi-select with titles using anyOf
348+
class FavoriteColorsSchema(BaseModel):
349+
user_name: str = Field(description="Your name")
350+
favorite_colors: list[str] = Field(
351+
description="Select your favorite colors",
352+
json_schema_extra={
353+
"items": {
354+
"anyOf": [
355+
{"const": "red", "title": "Red"},
356+
{"const": "green", "title": "Green"},
357+
{"const": "blue", "title": "Blue"},
358+
{"const": "yellow", "title": "Yellow"},
359+
]
360+
}
361+
},
362+
)
363+
364+
@mcp.tool(description="Multiple color selection")
365+
async def select_favorite_colors(ctx: Context[ServerSession, None]) -> str:
366+
result = await ctx.elicit(message="Select your favorite colors", schema=FavoriteColorsSchema)
367+
if result.action == "accept" and result.data:
368+
return f"User: {result.data.user_name}, Colors: {', '.join(result.data.favorite_colors)}"
369+
return f"User {result.action}" # pragma: no cover
370+
371+
# Test legacy enumNames format
372+
class LegacyColorSchema(BaseModel):
373+
user_name: str = Field(description="Your name")
374+
color: str = Field(
375+
description="Select a color",
376+
json_schema_extra={"enum": ["red", "green", "blue"], "enumNames": ["Red", "Green", "Blue"]},
377+
)
378+
379+
@mcp.tool(description="Legacy enum format")
380+
async def select_color_legacy(ctx: Context[ServerSession, None]) -> str:
381+
result = await ctx.elicit(message="Select a color (legacy format)", schema=LegacyColorSchema)
382+
if result.action == "accept" and result.data:
383+
return f"User: {result.data.user_name}, Color: {result.data.color}"
384+
return f"User {result.action}" # pragma: no cover
385+
386+
async def enum_callback(context: RequestContext[ClientSession, Any], params: ElicitRequestParams):
387+
if "colors" in params.message and "legacy" not in params.message:
388+
return ElicitResult(action="accept", content={"user_name": "Bob", "favorite_colors": ["red", "green"]})
389+
elif "color" in params.message:
390+
if "legacy" in params.message:
391+
return ElicitResult(action="accept", content={"user_name": "Charlie", "color": "green"})
392+
else:
393+
return ElicitResult(action="accept", content={"user_name": "Alice", "favorite_color": "blue"})
394+
return ElicitResult(action="decline") # pragma: no cover
395+
396+
# Test single-select with titles
397+
await call_tool_and_assert(mcp, enum_callback, "select_favorite_color", {}, "User: Alice, Favorite: blue")
398+
399+
# Test multi-select with titles
400+
await call_tool_and_assert(mcp, enum_callback, "select_favorite_colors", {}, "User: Bob, Colors: red, green")
401+
402+
# Test legacy enumNames format
403+
await call_tool_and_assert(mcp, enum_callback, "select_color_legacy", {}, "User: Charlie, Color: green")

0 commit comments

Comments
 (0)