Skip to content

SEP: Elicitation Enum Schema Improvements and Standards Compliance #1246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions src/mcp/server/elicitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import types
from collections.abc import Sequence
from typing import Generic, Literal, TypeVar, Union, get_args, get_origin

from pydantic import BaseModel
Expand Down Expand Up @@ -46,11 +47,22 @@ def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
if not _is_primitive_field(field_info):
raise TypeError(
f"Elicitation schema field '{field_name}' must be a primitive type "
f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. "
f"Complex types like lists, dicts, or nested models are not allowed."
f"{_ELICITATION_PRIMITIVE_TYPES}, a sequence of strings (list[str], etc.), "
f"or Optional of these types. Nested models and complex types are not allowed."
)


def _is_string_sequence(annotation: type) -> bool:
"""Check if annotation is a sequence of strings (list[str], Sequence[str], etc)."""
origin = get_origin(annotation)
# Check if it's a sequence-like type with str elements
if origin and issubclass(origin, Sequence):
args = get_args(annotation)
# Should have single str type arg
return len(args) == 1 and args[0] is str
return False


def _is_primitive_field(field_info: FieldInfo) -> bool:
"""Check if a field is a primitive type allowed in elicitation schemas."""
annotation = field_info.annotation
Expand All @@ -63,12 +75,21 @@ def _is_primitive_field(field_info: FieldInfo) -> bool:
if annotation in _ELICITATION_PRIMITIVE_TYPES:
return True

# Handle string sequences for multi-select enums
if annotation is not None and _is_string_sequence(annotation):
return True

# Handle Union types
origin = get_origin(annotation)
if origin is Union or origin is types.UnionType:
args = get_args(annotation)
# All args must be primitive types or None
return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args)
# All args must be primitive types, None, or string sequences
return all(
arg is types.NoneType
or arg in _ELICITATION_PRIMITIVE_TYPES
or (arg is not None and _is_string_sequence(arg))
for arg in args
)

return False

Expand Down
2 changes: 1 addition & 1 deletion src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,7 +1274,7 @@ class ElicitResult(Result):
- "cancel": User dismissed without making an explicit choice
"""

content: dict[str, str | int | float | bool | None] | None = None
content: dict[str, str | int | float | bool | list[str] | None] | None = None
"""
The submitted form data, only present when action is "accept".
Contains values matching the requested schema.
Expand Down
111 changes: 108 additions & 3 deletions tests/server/fastmcp/test_elicitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def tool(ctx: Context) -> str:

# Test cases for invalid schemas
class InvalidListSchema(BaseModel):
names: list[str] = Field(description="List of names")
numbers: list[int] = Field(description="List of numbers")

class NestedModel(BaseModel):
value: str
Expand All @@ -133,7 +133,7 @@ async def elicitation_callback(context, params):
await client_session.initialize()

# Test both invalid schemas
for tool_name, field_name in [("invalid_list", "names"), ("nested_model", "nested")]:
for tool_name, field_name in [("invalid_list", "numbers"), ("nested_model", "nested")]:
result = await client_session.call_tool(tool_name, {})
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
Expand Down Expand Up @@ -191,7 +191,7 @@ async def callback(context, params):
# Test invalid optional field
class InvalidOptionalSchema(BaseModel):
name: str = Field(description="Name")
optional_list: list[str] | None = Field(default=None, description="Invalid optional list")
optional_list: list[int] | None = Field(default=None, description="Invalid optional list")

@mcp.tool(description="Tool with invalid optional field")
async def invalid_optional_tool(ctx: Context) -> str:
Expand All @@ -208,3 +208,108 @@ async def invalid_optional_tool(ctx: Context) -> str:
{},
text_contains=["Validation failed:", "optional_list"],
)

# Test valid list[str] for multi-select enum
class ValidMultiSelectSchema(BaseModel):
name: str = Field(description="Name")
tags: list[str] = Field(description="Tags")

@mcp.tool(description="Tool with valid list[str] field")
async def valid_multiselect_tool(ctx: Context) -> str:
result = await ctx.elicit(message="Please provide tags", schema=ValidMultiSelectSchema)
if result.action == "accept" and result.data:
return f"Name: {result.data.name}, Tags: {', '.join(result.data.tags)}"
return f"User {result.action}"

async def multiselect_callback(context, params):
if "Please provide tags" in params.message:
return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]})
return ElicitResult(action="decline")

await call_tool_and_assert(mcp, multiselect_callback, "valid_multiselect_tool", {}, "Name: Test, Tags: tag1, tag2")


@pytest.mark.anyio
async def test_elicitation_with_enum_titles():
"""Test elicitation with enum schemas using oneOf/anyOf for titles."""
mcp = FastMCP(name="ColorPreferencesApp")

# Test single-select with titles using oneOf
class FavoriteColorSchema(BaseModel):
user_name: str = Field(description="Your name")
favorite_color: str = Field(
description="Select your favorite color",
json_schema_extra={
"oneOf": [
{"const": "red", "title": "Red"},
{"const": "green", "title": "Green"},
{"const": "blue", "title": "Blue"},
{"const": "yellow", "title": "Yellow"},
]
},
)

@mcp.tool(description="Single color selection")
async def select_favorite_color(ctx: Context) -> str:
result = await ctx.elicit(message="Select your favorite color", schema=FavoriteColorSchema)
if result.action == "accept" and result.data:
return f"User: {result.data.user_name}, Favorite: {result.data.favorite_color}"
return f"User {result.action}"

# Test multi-select with titles using anyOf
class FavoriteColorsSchema(BaseModel):
user_name: str = Field(description="Your name")
favorite_colors: list[str] = Field(
description="Select your favorite colors",
json_schema_extra={
"items": {
"anyOf": [
{"const": "red", "title": "Red"},
{"const": "green", "title": "Green"},
{"const": "blue", "title": "Blue"},
{"const": "yellow", "title": "Yellow"},
]
}
},
)

@mcp.tool(description="Multiple color selection")
async def select_favorite_colors(ctx: Context) -> str:
result = await ctx.elicit(message="Select your favorite colors", schema=FavoriteColorsSchema)
if result.action == "accept" and result.data:
return f"User: {result.data.user_name}, Colors: {', '.join(result.data.favorite_colors)}"
return f"User {result.action}"

# Test deprecated enumNames format
class DeprecatedColorSchema(BaseModel):
user_name: str = Field(description="Your name")
color: str = Field(
description="Select a color",
json_schema_extra={"enum": ["red", "green", "blue"], "enumNames": ["Red", "Green", "Blue"]},
)

@mcp.tool(description="Deprecated enum format")
async def select_color_deprecated(ctx: Context) -> str:
result = await ctx.elicit(message="Select a color (deprecated format)", schema=DeprecatedColorSchema)
if result.action == "accept" and result.data:
return f"User: {result.data.user_name}, Color: {result.data.color}"
return f"User {result.action}"

async def enum_callback(context, params):
if "colors" in params.message and "deprecated" not in params.message:
return ElicitResult(action="accept", content={"user_name": "Bob", "favorite_colors": ["red", "green"]})
elif "color" in params.message:
if "deprecated" in params.message:
return ElicitResult(action="accept", content={"user_name": "Charlie", "color": "green"})
else:
return ElicitResult(action="accept", content={"user_name": "Alice", "favorite_color": "blue"})
return ElicitResult(action="decline")

# Test single-select with titles
await call_tool_and_assert(mcp, enum_callback, "select_favorite_color", {}, "User: Alice, Favorite: blue")

# Test multi-select with titles
await call_tool_and_assert(mcp, enum_callback, "select_favorite_colors", {}, "User: Bob, Colors: red, green")

# Test deprecated enumNames format
await call_tool_and_assert(mcp, enum_callback, "select_color_deprecated", {}, "User: Charlie, Color: green")
Loading