|
11 | 11 | from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS |
12 | 12 | from mcp.types import ( |
13 | 13 | LATEST_PROTOCOL_VERSION, |
| 14 | + CallToolResult, |
14 | 15 | ClientNotification, |
15 | 16 | ClientRequest, |
16 | 17 | Implementation, |
|
23 | 24 | JSONRPCResponse, |
24 | 25 | ServerCapabilities, |
25 | 26 | ServerResult, |
| 27 | + TextContent, |
26 | 28 | ) |
27 | 29 |
|
28 | 30 |
|
@@ -492,8 +494,125 @@ async def mock_server(): |
492 | 494 |
|
493 | 495 | # Assert that capabilities are properly set with custom callbacks |
494 | 496 | assert received_capabilities is not None |
495 | | - assert received_capabilities.sampling is not None # Custom sampling callback provided |
| 497 | + # Custom sampling callback provided |
| 498 | + assert received_capabilities.sampling is not None |
496 | 499 | assert isinstance(received_capabilities.sampling, types.SamplingCapability) |
497 | | - assert received_capabilities.roots is not None # Custom list_roots callback provided |
| 500 | + # Custom list_roots callback provided |
| 501 | + assert received_capabilities.roots is not None |
498 | 502 | assert isinstance(received_capabilities.roots, types.RootsCapability) |
499 | | - assert received_capabilities.roots.listChanged is True # Should be True for custom callback |
| 503 | + # Should be True for custom callback |
| 504 | + assert received_capabilities.roots.listChanged is True |
| 505 | + |
| 506 | + |
| 507 | +@pytest.mark.anyio |
| 508 | +@pytest.mark.parametrize(argnames="meta", argvalues=[None, {"toolMeta": "value"}]) |
| 509 | +async def test_client_tool_call_with_meta(meta: dict[str, Any] | None): |
| 510 | + """Test that client tool call requests can include metadata""" |
| 511 | + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) |
| 512 | + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) |
| 513 | + |
| 514 | + mocked_tool = types.Tool(name="sample_tool", inputSchema={}) |
| 515 | + |
| 516 | + async def mock_server(): |
| 517 | + # Receive initialization request from client |
| 518 | + session_message = await client_to_server_receive.receive() |
| 519 | + jsonrpc_request = session_message.message |
| 520 | + assert isinstance(jsonrpc_request.root, JSONRPCRequest) |
| 521 | + request = ClientRequest.model_validate( |
| 522 | + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) |
| 523 | + ) |
| 524 | + assert isinstance(request.root, InitializeRequest) |
| 525 | + |
| 526 | + result = ServerResult( |
| 527 | + InitializeResult( |
| 528 | + protocolVersion=LATEST_PROTOCOL_VERSION, |
| 529 | + capabilities=ServerCapabilities(), |
| 530 | + serverInfo=Implementation(name="mock-server", version="0.1.0"), |
| 531 | + ) |
| 532 | + ) |
| 533 | + |
| 534 | + # Answer initialization request |
| 535 | + await server_to_client_send.send( |
| 536 | + SessionMessage( |
| 537 | + JSONRPCMessage( |
| 538 | + JSONRPCResponse( |
| 539 | + jsonrpc="2.0", |
| 540 | + id=jsonrpc_request.root.id, |
| 541 | + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), |
| 542 | + ) |
| 543 | + ) |
| 544 | + ) |
| 545 | + ) |
| 546 | + |
| 547 | + # Receive initialized notification |
| 548 | + await client_to_server_receive.receive() |
| 549 | + |
| 550 | + # Wait for the client to send a 'tools/call' request |
| 551 | + session_message = await client_to_server_receive.receive() |
| 552 | + jsonrpc_request = session_message.message |
| 553 | + assert isinstance(jsonrpc_request.root, JSONRPCRequest) |
| 554 | + |
| 555 | + assert jsonrpc_request.root.method == "tools/call" |
| 556 | + |
| 557 | + if meta is not None: |
| 558 | + assert jsonrpc_request.root.params |
| 559 | + assert "_meta" in jsonrpc_request.root.params |
| 560 | + assert jsonrpc_request.root.params["_meta"] == meta |
| 561 | + |
| 562 | + result = ServerResult( |
| 563 | + CallToolResult(content=[TextContent(type="text", text="Called successfully")], isError=False) |
| 564 | + ) |
| 565 | + |
| 566 | + # Send the tools/call result |
| 567 | + await server_to_client_send.send( |
| 568 | + SessionMessage( |
| 569 | + JSONRPCMessage( |
| 570 | + JSONRPCResponse( |
| 571 | + jsonrpc="2.0", |
| 572 | + id=jsonrpc_request.root.id, |
| 573 | + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), |
| 574 | + ) |
| 575 | + ) |
| 576 | + ) |
| 577 | + ) |
| 578 | + |
| 579 | + # Wait for the tools/list request from the client |
| 580 | + # The client requires this step to validate the tool output schema |
| 581 | + session_message = await client_to_server_receive.receive() |
| 582 | + jsonrpc_request = session_message.message |
| 583 | + assert isinstance(jsonrpc_request.root, JSONRPCRequest) |
| 584 | + |
| 585 | + assert jsonrpc_request.root.method == "tools/list" |
| 586 | + |
| 587 | + result = types.ListToolsResult(tools=[mocked_tool]) |
| 588 | + |
| 589 | + await server_to_client_send.send( |
| 590 | + SessionMessage( |
| 591 | + JSONRPCMessage( |
| 592 | + JSONRPCResponse( |
| 593 | + jsonrpc="2.0", |
| 594 | + id=jsonrpc_request.root.id, |
| 595 | + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), |
| 596 | + ) |
| 597 | + ) |
| 598 | + ) |
| 599 | + ) |
| 600 | + |
| 601 | + server_to_client_send.close() |
| 602 | + |
| 603 | + async with ( |
| 604 | + ClientSession( |
| 605 | + server_to_client_receive, |
| 606 | + client_to_server_send, |
| 607 | + ) as session, |
| 608 | + anyio.create_task_group() as tg, |
| 609 | + client_to_server_send, |
| 610 | + client_to_server_receive, |
| 611 | + server_to_client_send, |
| 612 | + server_to_client_receive, |
| 613 | + ): |
| 614 | + tg.start_soon(mock_server) |
| 615 | + |
| 616 | + await session.initialize() |
| 617 | + |
| 618 | + await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta) |
0 commit comments