Skip to content

Commit 2b61abb

Browse files
committed
Reimplement test_client_tool_call_with_meta to goes through all the protocol phases
1 parent 701611d commit 2b61abb

File tree

1 file changed

+72
-12
lines changed

1 file changed

+72
-12
lines changed

tests/client/test_session.py

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -507,11 +507,47 @@ async def mock_server():
507507
@pytest.mark.anyio
508508
@pytest.mark.parametrize(argnames="meta", argvalues=[None, {"toolMeta": "value"}])
509509
async def test_client_tool_call_with_meta(meta: dict[str, Any] | None):
510-
"""Test that client tool call requests can include metadata."""
510+
"""Test that client tool call requests can include metadata"""
511511
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
512512
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
513513

514+
mocked_tool = types.Tool(name="sample_tool", inputSchema={})
515+
514516
async def mock_server():
517+
# Receive initialization request from client
518+
session_message = await client_to_server_receive.receive()
519+
jsonrpc_request = session_message.message
520+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
521+
request = ClientRequest.model_validate(
522+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
523+
)
524+
assert isinstance(request.root, InitializeRequest)
525+
526+
result = ServerResult(
527+
InitializeResult(
528+
protocolVersion=LATEST_PROTOCOL_VERSION,
529+
capabilities=ServerCapabilities(),
530+
serverInfo=Implementation(name="mock-server", version="0.1.0"),
531+
)
532+
)
533+
534+
# Answer initialization request
535+
await server_to_client_send.send(
536+
SessionMessage(
537+
JSONRPCMessage(
538+
JSONRPCResponse(
539+
jsonrpc="2.0",
540+
id=jsonrpc_request.root.id,
541+
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
542+
)
543+
)
544+
)
545+
)
546+
547+
# Receive initialized notification
548+
await client_to_server_receive.receive()
549+
550+
# Wait for the client to send a 'tools/call' request
515551
session_message = await client_to_server_receive.receive()
516552
jsonrpc_request = session_message.message
517553
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
@@ -527,18 +563,42 @@ async def mock_server():
527563
CallToolResult(content=[TextContent(type="text", text="Called successfully")], isError=False)
528564
)
529565

530-
async with server_to_client_send:
531-
await server_to_client_send.send(
532-
SessionMessage(
533-
JSONRPCMessage(
534-
JSONRPCResponse(
535-
jsonrpc="2.0",
536-
id=jsonrpc_request.root.id,
537-
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
538-
)
566+
# Send the tools/call result
567+
await server_to_client_send.send(
568+
SessionMessage(
569+
JSONRPCMessage(
570+
JSONRPCResponse(
571+
jsonrpc="2.0",
572+
id=jsonrpc_request.root.id,
573+
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
539574
)
540575
)
541576
)
577+
)
578+
579+
# Wait for the tools/list request from the client
580+
# The client requires this step to validate the tool output schema
581+
session_message = await client_to_server_receive.receive()
582+
jsonrpc_request = session_message.message
583+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
584+
585+
assert jsonrpc_request.root.method == "tools/list"
586+
587+
result = types.ListToolsResult(tools=[mocked_tool])
588+
589+
await server_to_client_send.send(
590+
SessionMessage(
591+
JSONRPCMessage(
592+
JSONRPCResponse(
593+
jsonrpc="2.0",
594+
id=jsonrpc_request.root.id,
595+
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
596+
)
597+
)
598+
)
599+
)
600+
601+
server_to_client_send.close()
542602

543603
async with (
544604
ClientSession(
@@ -553,6 +613,6 @@ async def mock_server():
553613
):
554614
tg.start_soon(mock_server)
555615

556-
session._tool_output_schemas["sample_tool"] = None
616+
await session.initialize()
557617

558-
await session.call_tool(name="sample_tool", arguments={"foo": "bar"}, meta=meta)
618+
await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta)

0 commit comments

Comments
 (0)