Skip to content

Commit a7189d3

Browse files
committed
Add missing server to client task tests
1 parent c1966fe commit a7189d3

File tree

1 file changed

+261
-0
lines changed

1 file changed

+261
-0
lines changed
Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
"""Tests for server-to-client task requests.
2+
3+
These tests verify that the client can properly handle task-related requests
4+
initiated by the server, which is part of bidirectional task-based execution.
5+
"""
6+
7+
from typing import Any
8+
9+
import pytest
10+
11+
import mcp.types as types
12+
from examples.shared.in_memory_task_store import InMemoryTaskStore
13+
from mcp.server import Server
14+
from mcp.shared.memory import create_connected_server_and_client_session
15+
16+
# Mark all tests in this module to ignore memory stream cleanup warnings
17+
pytestmark = pytest.mark.filterwarnings(
18+
"ignore:Exception ignored.*MemoryObject.*Stream:pytest.PytestUnraisableExceptionWarning"
19+
)
20+
21+
22+
@pytest.mark.anyio
23+
async def test_server_queries_client_get_task_success():
24+
"""Test server using session.get_task() to query client's task."""
25+
server_task_store = InMemoryTaskStore()
26+
client_task_store = InMemoryTaskStore()
27+
server = Server("test", task_store=server_task_store)
28+
29+
# Create a task in the CLIENT's store
30+
task_id = "client-task-123"
31+
task_meta = types.TaskMetadata(taskId=task_id, keepAlive=60000)
32+
request = types.ClientRequest(types.PingRequest())
33+
await client_task_store.create_task(task_meta, "req-1", request.root)
34+
35+
# Register list_tools to advertise our tool
36+
@server.list_tools()
37+
async def list_tools() -> list[types.Tool]:
38+
return [
39+
types.Tool(
40+
name="query_client_task",
41+
description="Query a task from the client",
42+
inputSchema={"type": "object", "properties": {}, "required": []},
43+
)
44+
]
45+
46+
# Create a tool that will use the server session to query the client
47+
@server.call_tool()
48+
async def query_client_task(name: str, arguments: dict[str, Any]) -> list[types.TextContent]:
49+
# Server session queries client's task via request_context
50+
session = server.request_context.session
51+
result = await session.get_task(task_id)
52+
return [types.TextContent(type="text", text=f"Task {result.taskId} status: {result.status}")]
53+
54+
async with create_connected_server_and_client_session(
55+
server,
56+
# Client needs task store to handle the GetTaskRequest
57+
) as client_session:
58+
# Set the client's task store
59+
client_session._task_store = client_task_store
60+
61+
await client_session.initialize()
62+
63+
# Call the tool which will trigger server->client task query
64+
result = await client_session.call_tool("query_client_task", {})
65+
66+
assert len(result.content) == 1
67+
assert isinstance(result.content[0], types.TextContent)
68+
assert "client-task-123" in result.content[0].text
69+
assert "submitted" in result.content[0].text
70+
71+
72+
@pytest.mark.anyio
73+
async def test_server_queries_client_get_task_not_found():
74+
"""Test server querying client for non-existent task."""
75+
server_task_store = InMemoryTaskStore()
76+
client_task_store = InMemoryTaskStore()
77+
server = Server("test", task_store=server_task_store)
78+
79+
@server.list_tools()
80+
async def list_tools() -> list[types.Tool]:
81+
return [
82+
types.Tool(
83+
name="query_nonexistent_task",
84+
description="Query nonexistent task",
85+
inputSchema={"type": "object", "properties": {}, "required": []},
86+
)
87+
]
88+
89+
@server.call_tool()
90+
async def query_nonexistent_task(name: str, arguments: dict[str, Any]) -> list[types.TextContent]:
91+
try:
92+
session = server.request_context.session
93+
await session.get_task("nonexistent-task")
94+
return [types.TextContent(type="text", text="Should have failed")]
95+
except Exception as e:
96+
return [types.TextContent(type="text", text=f"Error: {str(e)}")]
97+
98+
async with create_connected_server_and_client_session(server) as client_session:
99+
client_session._task_store = client_task_store
100+
await client_session.initialize()
101+
102+
result = await client_session.call_tool("query_nonexistent_task", {})
103+
104+
assert len(result.content) == 1
105+
assert isinstance(result.content[0], types.TextContent)
106+
assert "Error" in result.content[0].text
107+
assert "Task not found" in result.content[0].text or "INVALID_PARAMS" in result.content[0].text
108+
109+
110+
@pytest.mark.anyio
111+
async def test_server_queries_client_without_task_store():
112+
"""Test server querying client without task store configured."""
113+
server = Server("test", task_store=InMemoryTaskStore())
114+
115+
@server.list_tools()
116+
async def list_tools() -> list[types.Tool]:
117+
return [
118+
types.Tool(
119+
name="query_without_store",
120+
description="Query without task store",
121+
inputSchema={"type": "object", "properties": {}, "required": []},
122+
)
123+
]
124+
125+
@server.call_tool()
126+
async def query_without_store(name: str, arguments: dict[str, Any]) -> list[types.TextContent]:
127+
try:
128+
session = server.request_context.session
129+
await session.get_task("some-task")
130+
return [types.TextContent(type="text", text="Should have failed")]
131+
except Exception as e:
132+
return [types.TextContent(type="text", text=f"Error: {str(e)}")]
133+
134+
async with create_connected_server_and_client_session(server) as client_session:
135+
# No task store on client
136+
await client_session.initialize()
137+
138+
result = await client_session.call_tool("query_without_store", {})
139+
140+
assert len(result.content) == 1
141+
assert isinstance(result.content[0], types.TextContent)
142+
assert "Error" in result.content[0].text
143+
assert "Task store not configured" in result.content[0].text or "INVALID_REQUEST" in result.content[0].text
144+
145+
146+
@pytest.mark.anyio
147+
async def test_server_queries_client_get_task_result_success():
148+
"""Test server querying client for completed task result."""
149+
server_task_store = InMemoryTaskStore()
150+
client_task_store = InMemoryTaskStore()
151+
server = Server("test", task_store=server_task_store)
152+
153+
# Create a completed task with result in CLIENT's store
154+
# This is a server-to-client request (like sampling), so result is ClientResult
155+
task_id = "client-completed-task"
156+
task_meta = types.TaskMetadata(taskId=task_id)
157+
request = types.ServerRequest(types.PingRequest())
158+
await client_task_store.create_task(task_meta, "req-1", request.root)
159+
result = types.ClientResult(types.EmptyResult())
160+
await client_task_store.store_task_result(task_id, result.root)
161+
await client_task_store.update_task_status(task_id, "completed")
162+
163+
@server.list_tools()
164+
async def list_tools() -> list[types.Tool]:
165+
return [
166+
types.Tool(
167+
name="query_client_result",
168+
description="Query client task result",
169+
inputSchema={"type": "object", "properties": {}, "required": []},
170+
)
171+
]
172+
173+
@server.call_tool()
174+
async def query_client_result(name: str, arguments: dict[str, Any]) -> list[types.TextContent]:
175+
session = server.request_context.session
176+
result = await session.get_task_result(task_id, types.ClientResult)
177+
return [types.TextContent(type="text", text=f"Got result: {type(result.root).__name__}")] # type: ignore[attr-defined]
178+
179+
async with create_connected_server_and_client_session(server) as client_session:
180+
client_session._task_store = client_task_store
181+
await client_session.initialize()
182+
183+
result = await client_session.call_tool("query_client_result", {})
184+
185+
assert len(result.content) == 1
186+
assert isinstance(result.content[0], types.TextContent)
187+
assert "EmptyResult" in result.content[0].text
188+
189+
190+
@pytest.mark.anyio
191+
async def test_server_queries_client_list_tasks():
192+
"""Test server querying client's task list."""
193+
server_task_store = InMemoryTaskStore()
194+
client_task_store = InMemoryTaskStore()
195+
server = Server("test", task_store=server_task_store)
196+
197+
# Create multiple tasks in CLIENT's store
198+
for i in range(3):
199+
task_id = f"client-task-{i}"
200+
task_meta = types.TaskMetadata(taskId=task_id, keepAlive=60000)
201+
request = types.ClientRequest(types.PingRequest())
202+
await client_task_store.create_task(task_meta, f"req-{i}", request.root)
203+
204+
@server.list_tools()
205+
async def list_tools() -> list[types.Tool]:
206+
return [
207+
types.Tool(
208+
name="list_client_tasks",
209+
description="List client tasks",
210+
inputSchema={"type": "object", "properties": {}, "required": []},
211+
)
212+
]
213+
214+
@server.call_tool()
215+
async def list_client_tasks(name: str, arguments: dict[str, Any]) -> list[types.TextContent]:
216+
session = server.request_context.session
217+
result = await session.list_tasks()
218+
return [types.TextContent(type="text", text=f"Found {len(result.tasks)} tasks")]
219+
220+
async with create_connected_server_and_client_session(server) as client_session:
221+
client_session._task_store = client_task_store
222+
await client_session.initialize()
223+
224+
result = await client_session.call_tool("list_client_tasks", {})
225+
226+
assert len(result.content) == 1
227+
assert isinstance(result.content[0], types.TextContent)
228+
assert "Found 3 tasks" in result.content[0].text
229+
230+
231+
@pytest.mark.anyio
232+
async def test_server_list_tasks_empty():
233+
"""Test server querying client with empty task list."""
234+
server = Server("test", task_store=InMemoryTaskStore())
235+
client_task_store = InMemoryTaskStore()
236+
237+
@server.list_tools()
238+
async def list_tools() -> list[types.Tool]:
239+
return [
240+
types.Tool(
241+
name="list_empty_tasks",
242+
description="List empty tasks",
243+
inputSchema={"type": "object", "properties": {}, "required": []},
244+
)
245+
]
246+
247+
@server.call_tool()
248+
async def list_empty_tasks(name: str, arguments: dict[str, Any]) -> list[types.TextContent]:
249+
session = server.request_context.session
250+
result = await session.list_tasks()
251+
return [types.TextContent(type="text", text=f"Found {len(result.tasks)} tasks")]
252+
253+
async with create_connected_server_and_client_session(server) as client_session:
254+
client_session._task_store = client_task_store
255+
await client_session.initialize()
256+
257+
result = await client_session.call_tool("list_empty_tasks", {})
258+
259+
assert len(result.content) == 1
260+
assert isinstance(result.content[0], types.TextContent)
261+
assert "Found 0 tasks" in result.content[0].text

0 commit comments

Comments
 (0)