Skip to content

Commit 6ddfee7

Browse files
fix: handle Optional[list[str]] in elicitation validation
Fix bug where Union types containing string sequences (like list[str] | None) were incorrectly rejected by elicitation schema validation. Changes: - Updated _is_primitive_field() to check for string sequences in Union types - Added try-except wrapper to _is_string_sequence() to handle non-class origins - Added test case for Optional[list[str]] validation This ensures that optional multi-select enum fields work correctly with the SEP-1330 implementation.
1 parent 84cfad1 commit 6ddfee7

File tree

2 files changed

+35
-6
lines changed

2 files changed

+35
-6
lines changed

src/mcp/server/elicitation.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,15 @@ def _is_string_sequence(annotation: type) -> bool:
6363
"""Check if annotation is a sequence of strings (list[str], Sequence[str], etc)."""
6464
origin = get_origin(annotation)
6565
# Check if it's a sequence-like type with str elements
66-
if origin and issubclass(origin, Sequence):
67-
args = get_args(annotation)
68-
# Should have single str type arg
69-
return len(args) == 1 and args[0] is str
66+
if origin:
67+
try:
68+
if issubclass(origin, Sequence):
69+
args = get_args(annotation)
70+
# Should have single str type arg
71+
return len(args) == 1 and args[0] is str
72+
except TypeError:
73+
# origin is not a class, so it can't be a subclass of Sequence
74+
pass
7075
return False
7176

7277

@@ -80,8 +85,10 @@ def _is_primitive_field(annotation: type) -> bool:
8085
origin = get_origin(annotation)
8186
if origin is Union or origin is types.UnionType:
8287
args = get_args(annotation)
83-
# All args must be primitive types or None
84-
return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args)
88+
# All args must be primitive types, None, or string sequences
89+
return all(
90+
arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES or _is_string_sequence(arg) for arg in args
91+
)
8592

8693
return False
8794

tests/server/fastmcp/test_elicitation.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,28 @@ async def multiselect_callback(context: RequestContext[ClientSession, Any], para
233233

234234
await call_tool_and_assert(mcp, multiselect_callback, "valid_multiselect_tool", {}, "Name: Test, Tags: tag1, tag2")
235235

236+
# Test Optional[list[str]] for optional multi-select enum
237+
class OptionalMultiSelectSchema(BaseModel):
238+
name: str = Field(description="Name")
239+
tags: list[str] | None = Field(default=None, description="Optional tags")
240+
241+
@mcp.tool(description="Tool with optional list[str] field")
242+
async def optional_multiselect_tool(ctx: Context[ServerSession, None]) -> str:
243+
result = await ctx.elicit(message="Please provide optional tags", schema=OptionalMultiSelectSchema)
244+
if result.action == "accept" and result.data:
245+
tags_str = ", ".join(result.data.tags) if result.data.tags else "none"
246+
return f"Name: {result.data.name}, Tags: {tags_str}"
247+
return f"User {result.action}"
248+
249+
async def optional_multiselect_callback(context: RequestContext[ClientSession, Any], params: ElicitRequestParams):
250+
if "Please provide optional tags" in params.message:
251+
return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]})
252+
return ElicitResult(action="decline")
253+
254+
await call_tool_and_assert(
255+
mcp, optional_multiselect_callback, "optional_multiselect_tool", {}, "Name: Test, Tags: tag1, tag2"
256+
)
257+
236258

237259
@pytest.mark.anyio
238260
async def test_elicitation_with_default_values():

0 commit comments

Comments
 (0)