Skip to content

Commit c8684d8

Browse files
Tapan Chughfelixweinberger
authored andcommitted
SEP: Elicitation Enum Schema Improvements and Standards Compliance
1 parent da4fce2 commit c8684d8

File tree

3 files changed

+134
-8
lines changed

3 files changed

+134
-8
lines changed

src/mcp/server/elicitation.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import types
6+
from collections.abc import Sequence
67
from typing import Generic, Literal, TypeVar, Union, get_args, get_origin
78

89
from pydantic import BaseModel
@@ -46,11 +47,22 @@ def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
4647
if not _is_primitive_field(field_info):
4748
raise TypeError(
4849
f"Elicitation schema field '{field_name}' must be a primitive type "
49-
f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. "
50-
f"Complex types like lists, dicts, or nested models are not allowed."
50+
f"{_ELICITATION_PRIMITIVE_TYPES}, a sequence of strings (list[str], etc.), "
51+
f"or Optional of these types. Nested models and complex types are not allowed."
5152
)
5253

5354

55+
def _is_string_sequence(annotation: type) -> bool:
56+
"""Check if annotation is a sequence of strings (list[str], Sequence[str], etc)."""
57+
origin = get_origin(annotation)
58+
# Check if it's a sequence-like type with str elements
59+
if origin and issubclass(origin, Sequence):
60+
args = get_args(annotation)
61+
# Should have single str type arg
62+
return len(args) == 1 and args[0] is str
63+
return False
64+
65+
5466
def _is_primitive_field(field_info: FieldInfo) -> bool:
5567
"""Check if a field is a primitive type allowed in elicitation schemas."""
5668
annotation = field_info.annotation
@@ -63,12 +75,21 @@ def _is_primitive_field(field_info: FieldInfo) -> bool:
6375
if annotation in _ELICITATION_PRIMITIVE_TYPES:
6476
return True
6577

78+
# Handle string sequences for multi-select enums
79+
if annotation is not None and _is_string_sequence(annotation):
80+
return True
81+
6682
# Handle Union types
6783
origin = get_origin(annotation)
6884
if origin is Union or origin is types.UnionType:
6985
args = get_args(annotation)
70-
# All args must be primitive types or None
71-
return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args)
86+
# All args must be primitive types, None, or string sequences
87+
return all(
88+
arg is types.NoneType
89+
or arg in _ELICITATION_PRIMITIVE_TYPES
90+
or (arg is not None and _is_string_sequence(arg))
91+
for arg in args
92+
)
7293

7394
return False
7495

src/mcp/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1304,7 +1304,7 @@ class ElicitResult(Result):
13041304
- "cancel": User dismissed without making an explicit choice
13051305
"""
13061306

1307-
content: dict[str, str | int | float | bool | None] | None = None
1307+
content: dict[str, str | int | float | bool | list[str] | None] | None = None
13081308
"""
13091309
The submitted form data, only present when action is "accept".
13101310
Contains values matching the requested schema.

tests/server/fastmcp/test_elicitation.py

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ async def tool(ctx: Context[ServerSession, None]) -> str:
114114

115115
# Test cases for invalid schemas
116116
class InvalidListSchema(BaseModel):
117-
names: list[str] = Field(description="List of names")
117+
numbers: list[int] = Field(description="List of numbers")
118118

119119
class NestedModel(BaseModel):
120120
value: str
@@ -135,7 +135,7 @@ async def elicitation_callback(context: RequestContext[ClientSession, None], par
135135
await client_session.initialize()
136136

137137
# Test both invalid schemas
138-
for tool_name, field_name in [("invalid_list", "names"), ("nested_model", "nested")]:
138+
for tool_name, field_name in [("invalid_list", "numbers"), ("nested_model", "nested")]:
139139
result = await client_session.call_tool(tool_name, {})
140140
assert len(result.content) == 1
141141
assert isinstance(result.content[0], TextContent)
@@ -193,7 +193,7 @@ async def callback(context: RequestContext[ClientSession, None], params: ElicitR
193193
# Test invalid optional field
194194
class InvalidOptionalSchema(BaseModel):
195195
name: str = Field(description="Name")
196-
optional_list: list[str] | None = Field(default=None, description="Invalid optional list")
196+
optional_list: list[int] | None = Field(default=None, description="Invalid optional list")
197197

198198
@mcp.tool(description="Tool with invalid optional field")
199199
async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str:
@@ -214,6 +214,25 @@ async def elicitation_callback(context: RequestContext[ClientSession, None], par
214214
text_contains=["Validation failed:", "optional_list"],
215215
)
216216

217+
# Test valid list[str] for multi-select enum
218+
class ValidMultiSelectSchema(BaseModel):
219+
name: str = Field(description="Name")
220+
tags: list[str] = Field(description="Tags")
221+
222+
@mcp.tool(description="Tool with valid list[str] field")
223+
async def valid_multiselect_tool(ctx: Context[ServerSession, None]) -> str:
224+
result = await ctx.elicit(message="Please provide tags", schema=ValidMultiSelectSchema)
225+
if result.action == "accept" and result.data:
226+
return f"Name: {result.data.name}, Tags: {', '.join(result.data.tags)}"
227+
return f"User {result.action}"
228+
229+
async def multiselect_callback(context: RequestContext[ClientSession, Any], params: ElicitRequestParams):
230+
if "Please provide tags" in params.message:
231+
return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]})
232+
return ElicitResult(action="decline")
233+
234+
await call_tool_and_assert(mcp, multiselect_callback, "valid_multiselect_tool", {}, "Name: Test, Tags: tag1, tag2")
235+
217236

218237
@pytest.mark.anyio
219238
async def test_elicitation_with_default_values():
@@ -268,3 +287,89 @@ async def callback_override(context: RequestContext[ClientSession, None], params
268287
await call_tool_and_assert(
269288
mcp, callback_override, "defaults_tool", {}, "Name: John, Age: 25, Subscribe: False, Email: [email protected]"
270289
)
290+
291+
292+
@pytest.mark.anyio
293+
async def test_elicitation_with_enum_titles():
294+
"""Test elicitation with enum schemas using oneOf/anyOf for titles."""
295+
mcp = FastMCP(name="ColorPreferencesApp")
296+
297+
# Test single-select with titles using oneOf
298+
class FavoriteColorSchema(BaseModel):
299+
user_name: str = Field(description="Your name")
300+
favorite_color: str = Field(
301+
description="Select your favorite color",
302+
json_schema_extra={
303+
"oneOf": [
304+
{"const": "red", "title": "Red"},
305+
{"const": "green", "title": "Green"},
306+
{"const": "blue", "title": "Blue"},
307+
{"const": "yellow", "title": "Yellow"},
308+
]
309+
},
310+
)
311+
312+
@mcp.tool(description="Single color selection")
313+
async def select_favorite_color(ctx: Context) -> str:
314+
result = await ctx.elicit(message="Select your favorite color", schema=FavoriteColorSchema)
315+
if result.action == "accept" and result.data:
316+
return f"User: {result.data.user_name}, Favorite: {result.data.favorite_color}"
317+
return f"User {result.action}"
318+
319+
# Test multi-select with titles using anyOf
320+
class FavoriteColorsSchema(BaseModel):
321+
user_name: str = Field(description="Your name")
322+
favorite_colors: list[str] = Field(
323+
description="Select your favorite colors",
324+
json_schema_extra={
325+
"items": {
326+
"anyOf": [
327+
{"const": "red", "title": "Red"},
328+
{"const": "green", "title": "Green"},
329+
{"const": "blue", "title": "Blue"},
330+
{"const": "yellow", "title": "Yellow"},
331+
]
332+
}
333+
},
334+
)
335+
336+
@mcp.tool(description="Multiple color selection")
337+
async def select_favorite_colors(ctx: Context) -> str:
338+
result = await ctx.elicit(message="Select your favorite colors", schema=FavoriteColorsSchema)
339+
if result.action == "accept" and result.data:
340+
return f"User: {result.data.user_name}, Colors: {', '.join(result.data.favorite_colors)}"
341+
return f"User {result.action}"
342+
343+
# Test deprecated enumNames format
344+
class DeprecatedColorSchema(BaseModel):
345+
user_name: str = Field(description="Your name")
346+
color: str = Field(
347+
description="Select a color",
348+
json_schema_extra={"enum": ["red", "green", "blue"], "enumNames": ["Red", "Green", "Blue"]},
349+
)
350+
351+
@mcp.tool(description="Deprecated enum format")
352+
async def select_color_deprecated(ctx: Context) -> str:
353+
result = await ctx.elicit(message="Select a color (deprecated format)", schema=DeprecatedColorSchema)
354+
if result.action == "accept" and result.data:
355+
return f"User: {result.data.user_name}, Color: {result.data.color}"
356+
return f"User {result.action}"
357+
358+
async def enum_callback(context, params):
359+
if "colors" in params.message and "deprecated" not in params.message:
360+
return ElicitResult(action="accept", content={"user_name": "Bob", "favorite_colors": ["red", "green"]})
361+
elif "color" in params.message:
362+
if "deprecated" in params.message:
363+
return ElicitResult(action="accept", content={"user_name": "Charlie", "color": "green"})
364+
else:
365+
return ElicitResult(action="accept", content={"user_name": "Alice", "favorite_color": "blue"})
366+
return ElicitResult(action="decline")
367+
368+
# Test single-select with titles
369+
await call_tool_and_assert(mcp, enum_callback, "select_favorite_color", {}, "User: Alice, Favorite: blue")
370+
371+
# Test multi-select with titles
372+
await call_tool_and_assert(mcp, enum_callback, "select_favorite_colors", {}, "User: Bob, Colors: red, green")
373+
374+
# Test deprecated enumNames format
375+
await call_tool_and_assert(mcp, enum_callback, "select_color_deprecated", {}, "User: Charlie, Color: green")

0 commit comments

Comments
 (0)