Skip to content

Commit 7f010a0

Browse files
committed
feat: Allow any value type in get_prompt arguments
The type hints for the `arguments` parameter in `ClientSession.get_prompt` and the corresponding `GetPromptRequestParams` were previously restricted to `dict[str, str]`. This was overly restrictive and prevented passing arguments with non-string values, such as numbers or booleans. This commit changes the type to `dict[str, Any]`, allowing for more flexible prompt argument structures. A corresponding test case has been added to verify that calling `get_prompt` with an integer argument now works as expected. Fixes #749
1 parent 6f43d1f commit 7f010a0

File tree

3 files changed

+71
-2
lines changed

3 files changed

+71
-2
lines changed

src/mcp/client/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResu
344344
types.ListPromptsResult,
345345
)
346346

347-
async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
347+
async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> types.GetPromptResult:
348348
"""Send a prompts/get request."""
349349
return await self.send_request(
350350
types.ClientRequest(

src/mcp/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,7 @@ class GetPromptRequestParams(RequestParams):
639639

640640
name: str
641641
"""The name of the prompt or prompt template."""
642-
arguments: dict[str, str] | None = None
642+
arguments: dict[str, Any] | None = None
643643
"""Arguments to use for templating the prompt."""
644644
model_config = ConfigDict(extra="allow")
645645

tests/client/test_session.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,3 +495,72 @@ async def mock_server():
495495
assert received_capabilities.roots is not None # Custom list_roots callback provided
496496
assert isinstance(received_capabilities.roots, types.RootsCapability)
497497
assert received_capabilities.roots.listChanged is True # Should be True for custom callback
498+
499+
500+
@pytest.mark.anyio
501+
async def test_get_prompt_with_non_string_argument():
502+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
503+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
504+
505+
async def mock_server():
506+
await client_to_server_receive.receive()
507+
508+
await server_to_client_send.send(
509+
SessionMessage(
510+
JSONRPCMessage(
511+
JSONRPCResponse(
512+
jsonrpc="2.0",
513+
id=0,
514+
result=InitializeResult(
515+
protocolVersion=LATEST_PROTOCOL_VERSION,
516+
capabilities=ServerCapabilities(),
517+
serverInfo=Implementation(name="mock-server", version="0.1.0"),
518+
).model_dump(by_alias=True, mode="json", exclude_none=True),
519+
)
520+
)
521+
)
522+
)
523+
524+
await client_to_server_receive.receive()
525+
526+
# Receive get_prompt request
527+
session_message = await client_to_server_receive.receive()
528+
jsonrpc_request = session_message.message
529+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
530+
request = ClientRequest.model_validate(
531+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
532+
)
533+
assert isinstance(request.root, types.GetPromptRequest)
534+
assert request.root.params.arguments == {"employee_id": 77}
535+
536+
# Send get_prompt result
537+
await server_to_client_send.send(
538+
SessionMessage(
539+
JSONRPCMessage(
540+
JSONRPCResponse(
541+
jsonrpc="2.0",
542+
id=jsonrpc_request.root.id,
543+
result=types.GetPromptResult(
544+
messages=[
545+
types.PromptMessage(role="user", content=types.TextContent(type="text", text="..."))
546+
]
547+
).model_dump(by_alias=True, mode="json", exclude_none=True),
548+
)
549+
)
550+
)
551+
)
552+
553+
async with (
554+
ClientSession(
555+
server_to_client_receive,
556+
client_to_server_send,
557+
) as session,
558+
anyio.create_task_group() as tg,
559+
client_to_server_send,
560+
client_to_server_receive,
561+
server_to_client_send,
562+
server_to_client_receive,
563+
):
564+
tg.start_soon(mock_server)
565+
await session.initialize()
566+
await session.get_prompt("get_employee_profile", {"employee_id": 77})

0 commit comments

Comments
 (0)