Skip to content

Commit 3599de8

Browse files
committed
elicitation
1 parent 58c5e72 commit 3599de8

File tree

5 files changed

+191
-6
lines changed

5 files changed

+191
-6
lines changed

src/mcp/client/session.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ async def __call__(
2222
) -> types.CreateMessageResult | types.ErrorData: ...
2323

2424

25+
class ElicitationFnT(Protocol):
26+
async def __call__(
27+
self,
28+
context: RequestContext["ClientSession", Any],
29+
params: types.ElicitRequestParams,
30+
) -> types.ElicitResult | types.ErrorData: ...
31+
32+
2533
class ListRootsFnT(Protocol):
2634
async def __call__(
2735
self, context: RequestContext["ClientSession", Any]
@@ -62,6 +70,16 @@ async def _default_sampling_callback(
6270
)
6371

6472

73+
async def _default_elicitation_callback(
74+
context: RequestContext["ClientSession", Any],
75+
params: types.ElicitRequestParams,
76+
) -> types.ElicitResult | types.ErrorData:
77+
return types.ErrorData(
78+
code=types.INVALID_REQUEST,
79+
message="Elicitation not supported",
80+
)
81+
82+
6583
async def _default_list_roots_callback(
6684
context: RequestContext["ClientSession", Any],
6785
) -> types.ListRootsResult | types.ErrorData:
@@ -97,6 +115,7 @@ def __init__(
97115
write_stream: MemoryObjectSendStream[SessionMessage],
98116
read_timeout_seconds: timedelta | None = None,
99117
sampling_callback: SamplingFnT | None = None,
118+
elicitation_callback: ElicitationFnT | None = None,
100119
list_roots_callback: ListRootsFnT | None = None,
101120
logging_callback: LoggingFnT | None = None,
102121
message_handler: MessageHandlerFnT | None = None,
@@ -111,12 +130,16 @@ def __init__(
111130
)
112131
self._client_info = client_info or DEFAULT_CLIENT_INFO
113132
self._sampling_callback = sampling_callback or _default_sampling_callback
133+
self._elicitation_callback = (
134+
elicitation_callback or _default_elicitation_callback
135+
)
114136
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
115137
self._logging_callback = logging_callback or _default_logging_callback
116138
self._message_handler = message_handler or _default_message_handler
117139

118140
async def initialize(self) -> types.InitializeResult:
119141
sampling = types.SamplingCapability()
142+
elicitation = types.ElicitationCapability()
120143
roots = types.RootsCapability(
121144
# TODO: Should this be based on whether we
122145
# _will_ send notifications, or only whether
@@ -132,6 +155,7 @@ async def initialize(self) -> types.InitializeResult:
132155
protocolVersion=types.LATEST_PROTOCOL_VERSION,
133156
capabilities=types.ClientCapabilities(
134157
sampling=sampling,
158+
elicitation=elicitation,
135159
experimental=None,
136160
roots=roots,
137161
),
@@ -355,6 +379,12 @@ async def _received_request(
355379
client_response = ClientResponse.validate_python(response)
356380
await responder.respond(client_response)
357381

382+
case types.ElicitRequest(params=params):
383+
with responder:
384+
response = await self._elicitation_callback(ctx, params)
385+
client_response = ClientResponse.validate_python(response)
386+
await responder.respond(client_response)
387+
358388
case types.ListRootsRequest():
359389
with responder:
360390
response = await self._list_roots_callback(ctx)

src/mcp/server/fastmcp/server.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,39 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent
822822
), "Context is not available outside of a request"
823823
return await self._fastmcp.read_resource(uri)
824824

825+
async def elicit(
826+
self,
827+
message: str,
828+
requestedSchema: dict[str, Any],
829+
) -> dict[str, Any]:
830+
"""Elicit information from the client/user.
831+
832+
This method can be used to interactively ask for additional information from the
833+
client within a tool's execution.
834+
The client might display the message to the user and collect a response
835+
according to the provided schema. Or in case a client is an agent, it might
836+
decide how to handle the elicitation -- either by asking the user or
837+
automatically generating a response.
838+
839+
Args:
840+
message: The message to present to the user
841+
requestedSchema: JSON Schema defining the expected response structure
842+
843+
Returns:
844+
The user's response as a dict matching the request schema structure
845+
846+
Raises:
847+
ValueError: If elicitation is not supported by the client or fails
848+
"""
849+
850+
result = await self.request_context.session.elicit(
851+
message=message,
852+
requestedSchema=requestedSchema,
853+
related_request_id=self.request_id,
854+
)
855+
856+
return result.response
857+
825858
async def log(
826859
self,
827860
level: Literal["debug", "info", "warning", "error"],

src/mcp/server/session.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
4747

4848
import mcp.types as types
4949
from mcp.server.models import InitializationOptions
50-
from mcp.shared.message import SessionMessage
50+
from mcp.shared.message import ServerMessageMetadata, SessionMessage
5151
from mcp.shared.session import (
5252
BaseSession,
5353
RequestResponder,
@@ -128,6 +128,10 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
128128
if client_caps.sampling is None:
129129
return False
130130

131+
if capability.elicitation is not None:
132+
if client_caps.elicitation is None:
133+
return False
134+
131135
if capability.experimental is not None:
132136
if client_caps.experimental is None:
133137
return False
@@ -262,6 +266,35 @@ async def list_roots(self) -> types.ListRootsResult:
262266
types.ListRootsResult,
263267
)
264268

269+
async def elicit(
270+
self,
271+
message: str,
272+
requestedSchema: dict[str, Any],
273+
related_request_id: types.RequestId | None = None,
274+
) -> types.ElicitResult:
275+
"""Send an elicitation/create request.
276+
277+
Args:
278+
message: The message to present to the user
279+
requestedSchema: JSON Schema defining the expected response structure
280+
281+
Returns:
282+
The client's response
283+
"""
284+
return await self.send_request(
285+
types.ServerRequest(
286+
types.ElicitRequest(
287+
method="elicitation/create",
288+
params=types.ElicitRequestParams(
289+
message=message,
290+
requestedSchema=requestedSchema,
291+
),
292+
)
293+
),
294+
types.ElicitResult,
295+
metadata=ServerMessageMetadata(related_request_id=related_request_id),
296+
)
297+
265298
async def send_ping(self) -> types.EmptyResult:
266299
"""Send a ping request."""
267300
return await self.send_request(

src/mcp/types.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,13 @@ class RootsCapability(BaseModel):
205205

206206

207207
class SamplingCapability(BaseModel):
208-
"""Capability for logging operations."""
208+
"""Capability for sampling operations."""
209+
210+
model_config = ConfigDict(extra="allow")
211+
212+
213+
class ElicitationCapability(BaseModel):
214+
"""Capability for elicitation operations."""
209215

210216
model_config = ConfigDict(extra="allow")
211217

@@ -217,6 +223,8 @@ class ClientCapabilities(BaseModel):
217223
"""Experimental, non-standard capabilities that the client supports."""
218224
sampling: SamplingCapability | None = None
219225
"""Present if the client supports sampling from an LLM."""
226+
elicitation: ElicitationCapability | None = None
227+
"""Present if the client supports elicitation from the user."""
220228
roots: RootsCapability | None = None
221229
"""Present if the client supports listing roots."""
222230
model_config = ConfigDict(extra="allow")
@@ -1141,11 +1149,42 @@ class ClientNotification(
11411149
pass
11421150

11431151

1144-
class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult]):
1152+
class ElicitRequestParams(RequestParams):
1153+
"""Parameters for elicitation requests."""
1154+
1155+
message: str
1156+
"""The message to present to the user."""
1157+
1158+
requestedSchema: dict[str, Any]
1159+
"""
1160+
A JSON Schema object defining the expected structure of the response.
1161+
"""
1162+
model_config = ConfigDict(extra="allow")
1163+
1164+
1165+
class ElicitRequest(Request[ElicitRequestParams, Literal["elicitation/create"]]):
1166+
"""A request from the server to elicit information from the client."""
1167+
1168+
method: Literal["elicitation/create"]
1169+
params: ElicitRequestParams
1170+
1171+
1172+
class ElicitResult(Result):
1173+
"""The client's response to an elicitation/create request from the server."""
1174+
1175+
response: dict[str, Any]
1176+
"""The response from the client, matching the structure of requestedSchema."""
1177+
1178+
1179+
class ClientResult(
1180+
RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult]
1181+
):
11451182
pass
11461183

11471184

1148-
class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest]):
1185+
class ServerRequest(
1186+
RootModel[PingRequest | CreateMessageRequest | ListRootsRequest | ElicitRequest]
1187+
):
11491188
pass
11501189

11511190

tests/server/fastmcp/test_integration.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515

1616
from mcp.client.session import ClientSession
1717
from mcp.client.sse import sse_client
18-
from mcp.server.fastmcp import FastMCP
19-
from mcp.types import InitializeResult, TextContent
18+
from mcp.server.fastmcp import Context, FastMCP
19+
from mcp.types import InitializeResult, TextContent, ElicitResult
2020

2121

2222
@pytest.fixture
@@ -45,6 +45,23 @@ def make_fastmcp_app():
4545
def echo(message: str) -> str:
4646
return f"Echo: {message}"
4747

48+
# Add a tool that uses elicitation
49+
@mcp.tool(description="A tool that uses elicitation")
50+
async def ask_user(prompt: str, ctx: Context) -> str:
51+
schema = {
52+
"type": "object",
53+
"properties": {
54+
"answer": {"type": "string"},
55+
},
56+
"required": ["answer"],
57+
}
58+
59+
response = await ctx.elicit(
60+
message=f"Tool wants to ask: {prompt}",
61+
requestedSchema=schema,
62+
)
63+
return f"User answered: {response['answer']}"
64+
4865
# Create the SSE app
4966
app: Starlette = mcp.sse_app()
5067

@@ -110,3 +127,36 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
110127
assert len(tool_result.content) == 1
111128
assert isinstance(tool_result.content[0], TextContent)
112129
assert tool_result.content[0].text == "Echo: hello"
130+
131+
132+
@pytest.mark.anyio
133+
async def test_elicitation_feature(server: None, server_url: str) -> None:
134+
"""Test the elicitation feature."""
135+
136+
# Create a custom handler for elicitation requests
137+
async def elicitation_callback(context, params):
138+
# Verify the elicitation parameters
139+
if params.message == "Tool wants to ask: What is your name?":
140+
return ElicitResult(response={"answer": "Test User"})
141+
else:
142+
raise ValueError("Unexpected elicitation message")
143+
144+
# Connect to the server with our custom elicitation handler
145+
async with sse_client(server_url + "/sse") as streams:
146+
async with ClientSession(
147+
*streams, elicitation_callback=elicitation_callback
148+
) as session:
149+
# First initialize the session
150+
result = await session.initialize()
151+
assert isinstance(result, InitializeResult)
152+
assert result.serverInfo.name == "NoAuthServer"
153+
154+
# Call the tool that uses elicitation
155+
tool_result = await session.call_tool(
156+
"ask_user", {"prompt": "What is your name?"}
157+
)
158+
# Verify the result
159+
assert len(tool_result.content) == 1
160+
assert isinstance(tool_result.content[0], TextContent)
161+
# # The test should only succeed with the successful elicitation response
162+
assert tool_result.content[0].text == "User answered: Test User"

0 commit comments

Comments
 (0)