From b469c5759786ad6e929319c5afec2aadf5077018 Mon Sep 17 00:00:00 2001 From: Tapan Chugh Date: Thu, 7 Aug 2025 12:21:38 -0700 Subject: [PATCH] SEP: Elicitation Enum Schema Improvements and Standards Compliance --- src/mcp/server/elicitation.py | 29 +++++- src/mcp/types.py | 2 +- tests/server/fastmcp/test_elicitation.py | 111 ++++++++++++++++++++++- 3 files changed, 134 insertions(+), 8 deletions(-) diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 1e48738c8..8708124c9 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -3,6 +3,7 @@ from __future__ import annotations import types +from collections.abc import Sequence from typing import Generic, Literal, TypeVar, Union, get_args, get_origin from pydantic import BaseModel @@ -46,11 +47,22 @@ def _validate_elicitation_schema(schema: type[BaseModel]) -> None: if not _is_primitive_field(field_info): raise TypeError( f"Elicitation schema field '{field_name}' must be a primitive type " - f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. " - f"Complex types like lists, dicts, or nested models are not allowed." + f"{_ELICITATION_PRIMITIVE_TYPES}, a sequence of strings (list[str], etc.), " + f"or Optional of these types. Nested models and complex types are not allowed." ) +def _is_string_sequence(annotation: type) -> bool: + """Check if annotation is a sequence of strings (list[str], Sequence[str], etc).""" + origin = get_origin(annotation) + # Check if it's a sequence-like type with str elements + if origin and issubclass(origin, Sequence): + args = get_args(annotation) + # Should have single str type arg + return len(args) == 1 and args[0] is str + return False + + def _is_primitive_field(field_info: FieldInfo) -> bool: """Check if a field is a primitive type allowed in elicitation schemas.""" annotation = field_info.annotation @@ -63,12 +75,21 @@ def _is_primitive_field(field_info: FieldInfo) -> bool: if annotation in _ELICITATION_PRIMITIVE_TYPES: return True + # Handle string sequences for multi-select enums + if annotation is not None and _is_string_sequence(annotation): + return True + # Handle Union types origin = get_origin(annotation) if origin is Union or origin is types.UnionType: args = get_args(annotation) - # All args must be primitive types or None - return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args) + # All args must be primitive types, None, or string sequences + return all( + arg is types.NoneType + or arg in _ELICITATION_PRIMITIVE_TYPES + or (arg is not None and _is_string_sequence(arg)) + for arg in args + ) return False diff --git a/src/mcp/types.py b/src/mcp/types.py index 98fefa080..d435b8fe3 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -1274,7 +1274,7 @@ class ElicitResult(Result): - "cancel": User dismissed without making an explicit choice """ - content: dict[str, str | int | float | bool | None] | None = None + content: dict[str, str | int | float | bool | list[str] | None] | None = None """ The submitted form data, only present when action is "accept". Contains values matching the requested schema. diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index 20937d91d..2ea18c119 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -112,7 +112,7 @@ async def tool(ctx: Context) -> str: # Test cases for invalid schemas class InvalidListSchema(BaseModel): - names: list[str] = Field(description="List of names") + numbers: list[int] = Field(description="List of numbers") class NestedModel(BaseModel): value: str @@ -133,7 +133,7 @@ async def elicitation_callback(context, params): await client_session.initialize() # Test both invalid schemas - for tool_name, field_name in [("invalid_list", "names"), ("nested_model", "nested")]: + for tool_name, field_name in [("invalid_list", "numbers"), ("nested_model", "nested")]: result = await client_session.call_tool(tool_name, {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -191,7 +191,7 @@ async def callback(context, params): # Test invalid optional field class InvalidOptionalSchema(BaseModel): name: str = Field(description="Name") - optional_list: list[str] | None = Field(default=None, description="Invalid optional list") + optional_list: list[int] | None = Field(default=None, description="Invalid optional list") @mcp.tool(description="Tool with invalid optional field") async def invalid_optional_tool(ctx: Context) -> str: @@ -208,3 +208,108 @@ async def invalid_optional_tool(ctx: Context) -> str: {}, text_contains=["Validation failed:", "optional_list"], ) + + # Test valid list[str] for multi-select enum + class ValidMultiSelectSchema(BaseModel): + name: str = Field(description="Name") + tags: list[str] = Field(description="Tags") + + @mcp.tool(description="Tool with valid list[str] field") + async def valid_multiselect_tool(ctx: Context) -> str: + result = await ctx.elicit(message="Please provide tags", schema=ValidMultiSelectSchema) + if result.action == "accept" and result.data: + return f"Name: {result.data.name}, Tags: {', '.join(result.data.tags)}" + return f"User {result.action}" + + async def multiselect_callback(context, params): + if "Please provide tags" in params.message: + return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]}) + return ElicitResult(action="decline") + + await call_tool_and_assert(mcp, multiselect_callback, "valid_multiselect_tool", {}, "Name: Test, Tags: tag1, tag2") + + +@pytest.mark.anyio +async def test_elicitation_with_enum_titles(): + """Test elicitation with enum schemas using oneOf/anyOf for titles.""" + mcp = FastMCP(name="ColorPreferencesApp") + + # Test single-select with titles using oneOf + class FavoriteColorSchema(BaseModel): + user_name: str = Field(description="Your name") + favorite_color: str = Field( + description="Select your favorite color", + json_schema_extra={ + "oneOf": [ + {"const": "red", "title": "Red"}, + {"const": "green", "title": "Green"}, + {"const": "blue", "title": "Blue"}, + {"const": "yellow", "title": "Yellow"}, + ] + }, + ) + + @mcp.tool(description="Single color selection") + async def select_favorite_color(ctx: Context) -> str: + result = await ctx.elicit(message="Select your favorite color", schema=FavoriteColorSchema) + if result.action == "accept" and result.data: + return f"User: {result.data.user_name}, Favorite: {result.data.favorite_color}" + return f"User {result.action}" + + # Test multi-select with titles using anyOf + class FavoriteColorsSchema(BaseModel): + user_name: str = Field(description="Your name") + favorite_colors: list[str] = Field( + description="Select your favorite colors", + json_schema_extra={ + "items": { + "anyOf": [ + {"const": "red", "title": "Red"}, + {"const": "green", "title": "Green"}, + {"const": "blue", "title": "Blue"}, + {"const": "yellow", "title": "Yellow"}, + ] + } + }, + ) + + @mcp.tool(description="Multiple color selection") + async def select_favorite_colors(ctx: Context) -> str: + result = await ctx.elicit(message="Select your favorite colors", schema=FavoriteColorsSchema) + if result.action == "accept" and result.data: + return f"User: {result.data.user_name}, Colors: {', '.join(result.data.favorite_colors)}" + return f"User {result.action}" + + # Test deprecated enumNames format + class DeprecatedColorSchema(BaseModel): + user_name: str = Field(description="Your name") + color: str = Field( + description="Select a color", + json_schema_extra={"enum": ["red", "green", "blue"], "enumNames": ["Red", "Green", "Blue"]}, + ) + + @mcp.tool(description="Deprecated enum format") + async def select_color_deprecated(ctx: Context) -> str: + result = await ctx.elicit(message="Select a color (deprecated format)", schema=DeprecatedColorSchema) + if result.action == "accept" and result.data: + return f"User: {result.data.user_name}, Color: {result.data.color}" + return f"User {result.action}" + + async def enum_callback(context, params): + if "colors" in params.message and "deprecated" not in params.message: + return ElicitResult(action="accept", content={"user_name": "Bob", "favorite_colors": ["red", "green"]}) + elif "color" in params.message: + if "deprecated" in params.message: + return ElicitResult(action="accept", content={"user_name": "Charlie", "color": "green"}) + else: + return ElicitResult(action="accept", content={"user_name": "Alice", "favorite_color": "blue"}) + return ElicitResult(action="decline") + + # Test single-select with titles + await call_tool_and_assert(mcp, enum_callback, "select_favorite_color", {}, "User: Alice, Favorite: blue") + + # Test multi-select with titles + await call_tool_and_assert(mcp, enum_callback, "select_favorite_colors", {}, "User: Bob, Colors: red, green") + + # Test deprecated enumNames format + await call_tool_and_assert(mcp, enum_callback, "select_color_deprecated", {}, "User: Charlie, Color: green")