|
10 | 10 | from mcp.types import ElicitResult, TextContent |
11 | 11 |
|
12 | 12 |
|
13 | | -@pytest.mark.anyio |
14 | | -async def test_stdio_elicitation(): |
15 | | - """Test the elicitation feature using stdio transport.""" |
| 13 | +# Shared schema for basic tests |
| 14 | +class AnswerSchema(BaseModel): |
| 15 | + answer: str = Field(description="The user's answer to the question") |
16 | 16 |
|
17 | | - # Create a FastMCP server with a tool that uses elicitation |
18 | | - mcp = FastMCP(name="StdioElicitationServer") |
| 17 | + |
| 18 | +def create_ask_user_tool(mcp: FastMCP): |
| 19 | + """Create a standard ask_user tool that handles all elicitation responses.""" |
19 | 20 |
|
20 | 21 | @mcp.tool(description="A tool that uses elicitation") |
21 | 22 | async def ask_user(prompt: str, ctx: Context) -> str: |
22 | | - class AnswerSchema(BaseModel): |
23 | | - answer: str = Field(description="The user's answer to the question") |
24 | | - |
25 | 23 | result = await ctx.elicit( |
26 | 24 | message=f"Tool wants to ask: {prompt}", |
27 | 25 | schema=AnswerSchema, |
28 | 26 | ) |
29 | | - |
| 27 | + |
30 | 28 | if result.action == "accept" and result.data: |
31 | 29 | return f"User answered: {result.data.answer}" |
32 | 30 | elif result.action == "decline": |
33 | 31 | return "User declined to answer" |
34 | 32 | else: |
35 | 33 | return "User cancelled" |
36 | 34 |
|
| 35 | + return ask_user |
| 36 | + |
| 37 | + |
| 38 | +async def call_tool_and_assert( |
| 39 | + mcp: FastMCP, |
| 40 | + elicitation_callback, |
| 41 | + tool_name: str, |
| 42 | + args: dict, |
| 43 | + expected_text: str | None = None, |
| 44 | + text_contains: list[str] | None = None, |
| 45 | +): |
| 46 | + """Helper to create session, call tool, and assert result.""" |
| 47 | + async with create_connected_server_and_client_session( |
| 48 | + mcp._mcp_server, elicitation_callback=elicitation_callback |
| 49 | + ) as client_session: |
| 50 | + await client_session.initialize() |
| 51 | + |
| 52 | + result = await client_session.call_tool(tool_name, args) |
| 53 | + assert len(result.content) == 1 |
| 54 | + assert isinstance(result.content[0], TextContent) |
| 55 | + |
| 56 | + if expected_text is not None: |
| 57 | + assert result.content[0].text == expected_text |
| 58 | + elif text_contains is not None: |
| 59 | + for substring in text_contains: |
| 60 | + assert substring in result.content[0].text |
| 61 | + |
| 62 | + return result |
| 63 | + |
| 64 | + |
| 65 | +@pytest.mark.anyio |
| 66 | +async def test_stdio_elicitation(): |
| 67 | + """Test the elicitation feature using stdio transport.""" |
| 68 | + mcp = FastMCP(name="StdioElicitationServer") |
| 69 | + create_ask_user_tool(mcp) |
| 70 | + |
37 | 71 | # Create a custom handler for elicitation requests |
38 | 72 | async def elicitation_callback(context, params): |
39 | | - # Verify the elicitation parameters |
40 | 73 | if params.message == "Tool wants to ask: What is your name?": |
41 | 74 | return ElicitResult(action="accept", content={"answer": "Test User"}) |
42 | 75 | else: |
43 | 76 | raise ValueError(f"Unexpected elicitation message: {params.message}") |
44 | 77 |
|
45 | | - # Use memory-based session to test with stdio transport |
46 | | - async with create_connected_server_and_client_session( |
47 | | - mcp._mcp_server, elicitation_callback=elicitation_callback |
48 | | - ) as client_session: |
49 | | - # First initialize the session |
50 | | - result = await client_session.initialize() |
51 | | - assert result.serverInfo.name == "StdioElicitationServer" |
52 | | - |
53 | | - # Call the tool that uses elicitation |
54 | | - tool_result = await client_session.call_tool("ask_user", {"prompt": "What is your name?"}) |
55 | | - |
56 | | - # Verify the result |
57 | | - assert len(tool_result.content) == 1 |
58 | | - assert isinstance(tool_result.content[0], TextContent) |
59 | | - assert tool_result.content[0].text == "User answered: Test User" |
| 78 | + await call_tool_and_assert( |
| 79 | + mcp, elicitation_callback, "ask_user", {"prompt": "What is your name?"}, "User answered: Test User" |
| 80 | + ) |
60 | 81 |
|
61 | 82 |
|
62 | 83 | @pytest.mark.anyio |
63 | 84 | async def test_stdio_elicitation_decline(): |
64 | 85 | """Test elicitation with user declining.""" |
65 | | - |
66 | 86 | mcp = FastMCP(name="StdioElicitationDeclineServer") |
67 | | - |
68 | | - @mcp.tool(description="A tool that uses elicitation") |
69 | | - async def ask_user(prompt: str, ctx: Context) -> str: |
70 | | - class AnswerSchema(BaseModel): |
71 | | - answer: str = Field(description="The user's answer to the question") |
72 | | - |
73 | | - result = await ctx.elicit( |
74 | | - message=f"Tool wants to ask: {prompt}", |
75 | | - schema=AnswerSchema, |
76 | | - ) |
77 | | - |
78 | | - if result.action == "accept" and result.data: |
79 | | - return f"User answered: {result.data.answer}" |
80 | | - elif result.action == "decline": |
81 | | - return "User declined to answer" |
82 | | - else: |
83 | | - return "User cancelled" |
84 | | - |
85 | | - # Create a custom handler that declines |
| 87 | + create_ask_user_tool(mcp) |
| 88 | + |
86 | 89 | async def elicitation_callback(context, params): |
87 | 90 | return ElicitResult(action="decline") |
88 | | - |
| 91 | + |
| 92 | + await call_tool_and_assert( |
| 93 | + mcp, elicitation_callback, "ask_user", {"prompt": "What is your name?"}, "User declined to answer" |
| 94 | + ) |
| 95 | + |
| 96 | + |
| 97 | +@pytest.mark.anyio |
| 98 | +async def test_elicitation_schema_validation(): |
| 99 | + """Test that elicitation schemas must only contain primitive types.""" |
| 100 | + mcp = FastMCP(name="ValidationTestServer") |
| 101 | + |
| 102 | + def create_validation_tool(name: str, schema_class: type[BaseModel]): |
| 103 | + @mcp.tool(name=name, description=f"Tool testing {name}") |
| 104 | + async def tool(ctx: Context) -> str: |
| 105 | + try: |
| 106 | + await ctx.elicit(message="This should fail validation", schema=schema_class) |
| 107 | + return "Should not reach here" |
| 108 | + except TypeError as e: |
| 109 | + return f"Validation failed as expected: {str(e)}" |
| 110 | + |
| 111 | + return tool |
| 112 | + |
| 113 | + # Test cases for invalid schemas |
| 114 | + class InvalidListSchema(BaseModel): |
| 115 | + names: list[str] = Field(description="List of names") |
| 116 | + |
| 117 | + class NestedModel(BaseModel): |
| 118 | + value: str |
| 119 | + |
| 120 | + class InvalidNestedSchema(BaseModel): |
| 121 | + nested: NestedModel = Field(description="Nested model") |
| 122 | + |
| 123 | + create_validation_tool("invalid_list", InvalidListSchema) |
| 124 | + create_validation_tool("nested_model", InvalidNestedSchema) |
| 125 | + |
| 126 | + # Dummy callback (won't be called due to validation failure) |
| 127 | + async def elicitation_callback(context, params): |
| 128 | + return ElicitResult(action="accept", content={}) |
| 129 | + |
89 | 130 | async with create_connected_server_and_client_session( |
90 | 131 | mcp._mcp_server, elicitation_callback=elicitation_callback |
91 | 132 | ) as client_session: |
92 | | - # Initialize |
93 | 133 | await client_session.initialize() |
94 | | - |
95 | | - # Call the tool |
96 | | - tool_result = await client_session.call_tool("ask_user", {"prompt": "What is your name?"}) |
97 | | - |
98 | | - # Verify the result |
99 | | - assert len(tool_result.content) == 1 |
100 | | - assert isinstance(tool_result.content[0], TextContent) |
101 | | - assert tool_result.content[0].text == "User declined to answer" |
| 134 | + |
| 135 | + # Test both invalid schemas |
| 136 | + for tool_name, field_name in [("invalid_list", "names"), ("nested_model", "nested")]: |
| 137 | + result = await client_session.call_tool(tool_name, {}) |
| 138 | + assert len(result.content) == 1 |
| 139 | + assert isinstance(result.content[0], TextContent) |
| 140 | + assert "Validation failed as expected" in result.content[0].text |
| 141 | + assert field_name in result.content[0].text |
| 142 | + |
| 143 | + |
| 144 | +@pytest.mark.anyio |
| 145 | +async def test_elicitation_with_optional_fields(): |
| 146 | + """Test that Optional fields work correctly in elicitation schemas.""" |
| 147 | + mcp = FastMCP(name="OptionalFieldServer") |
| 148 | + |
| 149 | + class OptionalSchema(BaseModel): |
| 150 | + required_name: str = Field(description="Your name (required)") |
| 151 | + optional_age: int | None = Field(default=None, description="Your age (optional)") |
| 152 | + optional_email: str | None = Field(default=None, description="Your email (optional)") |
| 153 | + subscribe: bool | None = Field(default=False, description="Subscribe to newsletter?") |
| 154 | + |
| 155 | + @mcp.tool(description="Tool with optional fields") |
| 156 | + async def optional_tool(ctx: Context) -> str: |
| 157 | + result = await ctx.elicit(message="Please provide your information", schema=OptionalSchema) |
| 158 | + |
| 159 | + if result.action == "accept" and result.data: |
| 160 | + info = [f"Name: {result.data.required_name}"] |
| 161 | + if result.data.optional_age is not None: |
| 162 | + info.append(f"Age: {result.data.optional_age}") |
| 163 | + if result.data.optional_email is not None: |
| 164 | + info.append(f"Email: {result.data.optional_email}") |
| 165 | + info.append(f"Subscribe: {result.data.subscribe}") |
| 166 | + return ", ".join(info) |
| 167 | + else: |
| 168 | + return f"User {result.action}" |
| 169 | + |
| 170 | + # Test cases with different field combinations |
| 171 | + test_cases = [ |
| 172 | + ( |
| 173 | + # All fields provided |
| 174 | + { "required_name": "John Doe", "optional_age": 30, "optional_email": "[email protected]", "subscribe": True}, |
| 175 | + "Name: John Doe, Age: 30, Email: [email protected], Subscribe: True", |
| 176 | + ), |
| 177 | + ( |
| 178 | + # Only required fields |
| 179 | + {"required_name": "Jane Smith"}, |
| 180 | + "Name: Jane Smith, Subscribe: False", |
| 181 | + ), |
| 182 | + ] |
| 183 | + |
| 184 | + for content, expected in test_cases: |
| 185 | + |
| 186 | + async def callback(context, params): |
| 187 | + return ElicitResult(action="accept", content=content) |
| 188 | + |
| 189 | + await call_tool_and_assert(mcp, callback, "optional_tool", {}, expected) |
| 190 | + |
| 191 | + # Test invalid optional field |
| 192 | + class InvalidOptionalSchema(BaseModel): |
| 193 | + name: str = Field(description="Name") |
| 194 | + optional_list: list[str] | None = Field(default=None, description="Invalid optional list") |
| 195 | + |
| 196 | + @mcp.tool(description="Tool with invalid optional field") |
| 197 | + async def invalid_optional_tool(ctx: Context) -> str: |
| 198 | + try: |
| 199 | + await ctx.elicit(message="This should fail", schema=InvalidOptionalSchema) |
| 200 | + return "Should not reach here" |
| 201 | + except TypeError as e: |
| 202 | + return f"Validation failed: {str(e)}" |
| 203 | + |
| 204 | + await call_tool_and_assert( |
| 205 | + mcp, |
| 206 | + lambda c, p: ElicitResult(action="accept", content={}), |
| 207 | + "invalid_optional_tool", |
| 208 | + {}, |
| 209 | + text_contains=["Validation failed:", "optional_list"], |
| 210 | + ) |
0 commit comments