Skip to content

Commit 953d3eb

Browse files
authored
Fix elicitation enums (#1632)
1 parent 33545ab commit 953d3eb

File tree

2 files changed

+138
-1
lines changed

2 files changed

+138
-1
lines changed

src/fastmcp/server/elicitation.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
DeclinedElicitation,
99
)
1010
from pydantic import BaseModel
11+
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
12+
from pydantic_core import core_schema
1113

1214
from fastmcp.utilities.json_schema import compress_schema
1315
from fastmcp.utilities.logging import get_logger
@@ -26,6 +28,60 @@
2628
T = TypeVar("T")
2729

2830

31+
class ElicitationJsonSchema(GenerateJsonSchema):
32+
"""Custom JSON schema generator for MCP elicitation that always inlines enums.
33+
34+
MCP elicitation requires inline enum schemas without $ref/$defs references.
35+
This generator ensures enums are always generated inline for compatibility.
36+
Optionally adds enumNames for better UI display when available.
37+
"""
38+
39+
def generate_inner(self, schema: core_schema.CoreSchema) -> JsonSchemaValue:
40+
"""Override to prevent ref generation for enums."""
41+
# For enum schemas, bypass the ref mechanism entirely
42+
if schema["type"] == "enum":
43+
# Directly call our custom enum_schema without going through handler
44+
# This prevents the ref/defs mechanism from being invoked
45+
return self.enum_schema(schema)
46+
# For all other types, use the default implementation
47+
return super().generate_inner(schema)
48+
49+
def enum_schema(self, schema: core_schema.EnumSchema) -> JsonSchemaValue:
50+
"""Generate inline enum schema with optional enumNames for better UI.
51+
52+
If enum members have a _display_name_ attribute or custom __str__,
53+
we'll include enumNames for better UI representation.
54+
"""
55+
# Get the base schema from parent
56+
result = super().enum_schema(schema)
57+
58+
# Try to add enumNames if the enum has display-friendly names
59+
enum_cls = schema.get("cls")
60+
if enum_cls:
61+
members = schema.get("members", [])
62+
enum_names = []
63+
has_custom_names = False
64+
65+
for member in members:
66+
# Check if member has a custom display name attribute
67+
if hasattr(member, "_display_name_"):
68+
enum_names.append(member._display_name_)
69+
has_custom_names = True
70+
# Or use the member name with better formatting
71+
else:
72+
# Convert SNAKE_CASE to Title Case for display
73+
display_name = member.name.replace("_", " ").title()
74+
enum_names.append(display_name)
75+
if display_name != member.value:
76+
has_custom_names = True
77+
78+
# Only add enumNames if they differ from the values
79+
if has_custom_names:
80+
result["enumNames"] = enum_names
81+
82+
return result
83+
84+
2985
# we can't use the low-level AcceptedElicitation because it only works with BaseModels
3086
class AcceptedElicitation(BaseModel, Generic[T]):
3187
"""Result when user accepts the elicitation."""
@@ -46,7 +102,10 @@ def get_elicitation_schema(response_type: type[T]) -> dict[str, Any]:
46102
response_type: The type of the response
47103
"""
48104

49-
schema = get_cached_typeadapter(response_type).json_schema()
105+
# Use custom schema generator that inlines enums for MCP compatibility
106+
schema = get_cached_typeadapter(response_type).json_schema(
107+
schema_generator=ElicitationJsonSchema
108+
)
50109
schema = compress_schema(schema)
51110

52111
# Validate the schema to ensure it follows MCP elicitation requirements

tests/client/test_elicitation.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
AcceptedElicitation,
1616
CancelledElicitation,
1717
DeclinedElicitation,
18+
get_elicitation_schema,
1819
validate_elicitation_json_schema,
1920
)
2021
from fastmcp.utilities.types import TypeAdapter
@@ -639,3 +640,80 @@ async def elicitation_handler(message, response_type, params, ctx):
639640
match="Elicitation responses must be serializable as a JSON object",
640641
):
641642
await client.call_tool("ask_for_name")
643+
644+
645+
def test_enum_elicitation_schema_inline():
646+
"""Test that enum schemas are generated inline without $ref/$defs for MCP compatibility."""
647+
648+
class Priority(Enum):
649+
LOW = "low"
650+
MEDIUM = "medium"
651+
HIGH = "high"
652+
653+
@dataclass
654+
class TaskRequest:
655+
title: str
656+
priority: Priority
657+
658+
# Generate elicitation schema
659+
schema = get_elicitation_schema(TaskRequest)
660+
661+
# Verify no $defs section exists (enums should be inlined)
662+
assert "$defs" not in schema, (
663+
"Schema should not contain $defs - enums must be inline"
664+
)
665+
666+
# Verify no $ref in properties
667+
for prop_name, prop_schema in schema.get("properties", {}).items():
668+
assert "$ref" not in prop_schema, (
669+
f"Property {prop_name} contains $ref - should be inline"
670+
)
671+
672+
# Verify the priority field has inline enum values
673+
priority_schema = schema["properties"]["priority"]
674+
assert "enum" in priority_schema, "Priority should have enum values inline"
675+
assert priority_schema["enum"] == ["low", "medium", "high"]
676+
assert priority_schema.get("type") == "string"
677+
678+
# Verify title field is a simple string
679+
assert schema["properties"]["title"]["type"] == "string"
680+
681+
682+
def test_enum_elicitation_schema_with_enum_names():
683+
"""Test that enum schemas can include enumNames for better UI display."""
684+
685+
class TaskStatus(Enum):
686+
NOT_STARTED = "not_started"
687+
IN_PROGRESS = "in_progress"
688+
COMPLETED = "completed"
689+
ON_HOLD = "on_hold"
690+
691+
@dataclass
692+
class TaskUpdate:
693+
task_id: str
694+
status: TaskStatus
695+
696+
# Generate elicitation schema
697+
schema = get_elicitation_schema(TaskUpdate)
698+
699+
# Verify enum is inline
700+
assert "$defs" not in schema
701+
assert "$ref" not in str(schema)
702+
703+
status_schema = schema["properties"]["status"]
704+
assert "enum" in status_schema
705+
assert status_schema["enum"] == [
706+
"not_started",
707+
"in_progress",
708+
"completed",
709+
"on_hold",
710+
]
711+
712+
# Check if enumNames were added for display
713+
assert "enumNames" in status_schema
714+
assert status_schema["enumNames"] == [
715+
"Not Started",
716+
"In Progress",
717+
"Completed",
718+
"On Hold",
719+
]

0 commit comments

Comments
 (0)