Skip to content

Commit ba0f9ee

Browse files
committed
formatting and other misc tidy up highlighted by ruff check
1 parent 9802041 commit ba0f9ee

File tree

3 files changed

+104
-88
lines changed

3 files changed

+104
-88
lines changed

src/mcp/server/lowlevel/result_cache.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@ class InProgress:
2626
user: AuthenticatedUser | None = None
2727
future: Future[types.CallToolResult] | None = None
2828
sessions: dict[int, ServerSession] = field(default_factory=lambda: {})
29-
session_progress: dict[int, types.ProgressToken | None] = field(default_factory=lambda: {})
29+
session_progress: dict[int, types.ProgressToken | None] = field(
30+
default_factory=lambda: {}
31+
)
32+
3033

3134
class ResultCache:
3235
"""
@@ -100,9 +103,14 @@ async def call_tool():
100103
return result.root
101104

102105
in_progress.user = user_context.get()
103-
in_progress.sessions[id(ctx.session)] = ctx.session
104-
in_progress.session_progress[id(ctx.session)] = None if req.params.meta is None else req.params.meta.progressToken
105-
self._session_lookup[id(ctx.session)] = in_progress.token
106+
session_id = id(ctx.session)
107+
in_progress.sessions[session_id] = ctx.session
108+
if req.params.meta is not None:
109+
progress_token = req.params.meta.progressToken
110+
else:
111+
progress_token = None
112+
in_progress.session_progress[session_id] = progress_token
113+
self._session_lookup[session_id] = in_progress.token
106114
in_progress.future = self._portal.start_task_soon(call_tool)
107115
result = types.CallToolAsyncResult(
108116
token=in_progress.token,
@@ -127,10 +135,15 @@ async def join_call(
127135
else:
128136
# TODO consider adding authorisation layer to make this decision
129137
if in_progress.user == user_context.get():
130-
logger.debug(f"Received join from {id(ctx.session)}")
131-
self._session_lookup[id(ctx.session)] = req.params.token
132-
in_progress.sessions[id(ctx.session)] = ctx.session
133-
in_progress.session_progress[id(ctx.session)] = None if req.params.meta is None else req.params.meta.progressToken
138+
session_id = id(ctx.session)
139+
logger.debug(f"Received join from {session_id}")
140+
self._session_lookup[session_id] = req.params.token
141+
in_progress.sessions[session_id] = ctx.session
142+
if req.params.meta is not None:
143+
progress_token = req.params.meta.progressToken
144+
else:
145+
progress_token = None
146+
in_progress.session_progress[session_id] = progress_token
134147
return types.CallToolAsyncResult(token=req.params.token, accepted=True)
135148
else:
136149
# TODO consider sending error via get result
@@ -196,25 +209,22 @@ async def notification_hook(
196209
logger.debug("Discarding progress notification from unknown session")
197210
else:
198211
in_progress = self._in_progress.get(async_token)
199-
if in_progress is None:
200-
# this should not happen
201-
logger.error("Discarding progress notification, not async")
202-
else:
203-
for session_id, other_session in in_progress.sessions.items():
204-
logger.debug(f"Checking {session_id} == {id(session)}")
205-
if not session_id == id(session):
206-
logger.debug(f"Sending progress to {id(other_session)}")
207-
progress_token = in_progress.session_progress.get(id(other_session))
208-
assert progress_token is not None
209-
await other_session.send_progress_notification(
210-
# TODO this token is incorrect
211-
# it needs to be collected from original request
212-
progress_token=progress_token,
213-
progress=notification.root.params.progress,
214-
total=notification.root.params.total,
215-
message=notification.root.params.message,
216-
resource_uri=notification.root.params.resourceUri,
217-
)
212+
assert in_progress is not None
213+
for other_id, other_session in in_progress.sessions.items():
214+
logger.debug(f"Checking {other_id} == {id(session)}")
215+
if not other_id == id(session):
216+
logger.debug(f"Sending progress to {other_id}")
217+
progress_token = in_progress.session_progress.get(other_id)
218+
assert progress_token is not None
219+
await other_session.send_progress_notification(
220+
# TODO this token is incorrect
221+
# it needs to be collected from original request
222+
progress_token=progress_token,
223+
progress=notification.root.params.progress,
224+
total=notification.root.params.total,
225+
message=notification.root.params.message,
226+
resource_uri=notification.root.params.resourceUri,
227+
)
218228

219229
async def session_close_hook(self, session: ServerSession):
220230
logger.debug(f"Closing {id(session)}")
Lines changed: 55 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,23 @@
1+
from contextlib import AsyncExitStack
2+
from unittest.mock import AsyncMock, Mock
3+
14
import pytest
5+
26
from mcp import types
37
from mcp.server.lowlevel.result_cache import ResultCache
4-
from unittest.mock import AsyncMock, Mock, patch
5-
from contextlib import AsyncExitStack
8+
69

710
@pytest.mark.anyio
811
async def test_async_call():
912
"""Tests basic async call"""
13+
1014
async def test_call(call: types.CallToolRequest) -> types.ServerResult:
11-
return types.ServerResult(types.CallToolResult(
12-
content=[types.TextContent(
13-
type="text",
14-
text="test"
15-
)]
16-
))
17-
async_call = types.CallToolAsyncRequest(
18-
method="tools/async/call",
19-
params=types.CallToolAsyncRequestParams(
20-
name="test"
15+
return types.ServerResult(
16+
types.CallToolResult(content=[types.TextContent(type="text", text="test")])
2117
)
18+
19+
async_call = types.CallToolAsyncRequest(
20+
method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test")
2221
)
2322

2423
mock_session = AsyncMock()
@@ -27,37 +26,38 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult:
2726
result_cache = ResultCache(max_size=1, max_keep_alive=1)
2827
async with AsyncExitStack() as stack:
2928
await stack.enter_async_context(result_cache)
30-
async_call_ref = await result_cache.start_call(test_call, async_call, mock_context)
29+
async_call_ref = await result_cache.start_call(
30+
test_call, async_call, mock_context
31+
)
3132
assert async_call_ref.token is not None
3233

33-
result = await result_cache.get_result(types.GetToolAsyncResultRequest(
34-
method="tools/async/get",
35-
params=types.GetToolAsyncResultRequestParams(
36-
token = async_call_ref.token
34+
result = await result_cache.get_result(
35+
types.GetToolAsyncResultRequest(
36+
method="tools/async/get",
37+
params=types.GetToolAsyncResultRequestParams(
38+
token=async_call_ref.token
39+
),
3740
)
38-
))
41+
)
3942

4043
assert not result.isError
4144
assert not result.isPending
4245
assert len(result.content) == 1
4346
assert type(result.content[0]) is types.TextContent
4447
assert result.content[0].text == "test"
4548

49+
4650
@pytest.mark.anyio
4751
async def test_async_join_call_progress():
4852
"""Tests basic async call"""
53+
4954
async def test_call(call: types.CallToolRequest) -> types.ServerResult:
50-
return types.ServerResult(types.CallToolResult(
51-
content=[types.TextContent(
52-
type="text",
53-
text="test"
54-
)]
55-
))
56-
async_call = types.CallToolAsyncRequest(
57-
method="tools/async/call",
58-
params=types.CallToolAsyncRequestParams(
59-
name="test"
55+
return types.ServerResult(
56+
types.CallToolResult(content=[types.TextContent(type="text", text="test")])
6057
)
58+
59+
async_call = types.CallToolAsyncRequest(
60+
method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test")
6161
)
6262

6363
mock_session_1 = AsyncMock()
@@ -73,38 +73,42 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult:
7373
result_cache = ResultCache(max_size=1, max_keep_alive=1)
7474
async with AsyncExitStack() as stack:
7575
await stack.enter_async_context(result_cache)
76-
async_call_ref = await result_cache.start_call(test_call, async_call, mock_context_1)
76+
async_call_ref = await result_cache.start_call(
77+
test_call, async_call, mock_context_1
78+
)
7779
assert async_call_ref.token is not None
7880

7981
await result_cache.join_call(
8082
req=types.JoinCallToolAsyncRequest(
8183
method="tools/async/join",
8284
params=types.JoinCallToolRequestParams(
8385
token=async_call_ref.token,
84-
_meta = types.RequestParams.Meta(
85-
progressToken="test"
86-
)
87-
)
86+
_meta=types.RequestParams.Meta(progressToken="test"),
87+
),
8888
),
89-
ctx=mock_context_2
89+
ctx=mock_context_2,
9090
)
9191
assert async_call_ref.token is not None
9292
await result_cache.notification_hook(
93-
session=mock_session_1,
94-
notification=types.ServerNotification(types.ProgressNotification(
95-
method="notifications/progress",
96-
params=types.ProgressNotificationParams(
97-
progressToken="test",
98-
progress=1
93+
session=mock_session_1,
94+
notification=types.ServerNotification(
95+
types.ProgressNotification(
96+
method="notifications/progress",
97+
params=types.ProgressNotificationParams(
98+
progressToken="test", progress=1
99+
),
99100
)
100-
)))
101+
),
102+
)
101103

102-
result = await result_cache.get_result(types.GetToolAsyncResultRequest(
103-
method="tools/async/get",
104-
params=types.GetToolAsyncResultRequestParams(
105-
token = async_call_ref.token
104+
result = await result_cache.get_result(
105+
types.GetToolAsyncResultRequest(
106+
method="tools/async/get",
107+
params=types.GetToolAsyncResultRequestParams(
108+
token=async_call_ref.token
109+
),
106110
)
107-
))
111+
)
108112

109113
assert not result.isError
110114
assert not result.isPending
@@ -113,9 +117,9 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult:
113117
assert result.content[0].text == "test"
114118
mock_context_1.send_progress_notification.assert_not_called()
115119
mock_session_2.send_progress_notification.assert_called_with(
116-
progress_token="test",
117-
progress=1.0,
118-
total=None,
119-
message=None,
120-
resource_uri = None
120+
progress_token="test",
121+
progress=1.0,
122+
total=None,
123+
message=None,
124+
resource_uri=None,
121125
)

tests/shared/test_session.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import AsyncGenerator
2+
from logging import getLogger
23

34
import anyio
45
import pytest
@@ -19,6 +20,8 @@
1920
EmptyResult,
2021
)
2122

23+
logger = getLogger(__name__)
24+
2225

2326
@pytest.fixture
2427
def mcp_server() -> Server:
@@ -211,15 +214,14 @@ async def get_result(client_session: ClientSession, async_token: types.AsyncToke
211214
assert type(result.content[0]) is types.TextContent
212215
assert result.content[0].text == "test"
213216

214-
from logging import getLogger
215-
216-
logger = getLogger(__name__)
217217

218218
@pytest.mark.anyio
219-
@pytest.mark.skip(reason="This test does not work, there is a subtle "
220-
"bug with event.wait, lower level test_result_cache "
221-
"tests underlying behaviour, revisit with feedback " \
222-
"from someone who cah help debug")
219+
@pytest.mark.skip(
220+
reason="This test does not work, there is a subtle "
221+
"bug with event.wait, lower level test_result_cache "
222+
"tests underlying behaviour, revisit with feedback "
223+
"from someone who cah help debug"
224+
)
223225
async def test_request_async_join():
224226
"""Test that requests can be joined from external sessions."""
225227
# The tool is already registered in the fixture
@@ -236,7 +238,6 @@ async def test_request_async_join():
236238
ev_client_2_progressed_1 = anyio.Event()
237239
ev_done = anyio.Event()
238240

239-
240241
# Start the request in a separate task so we can cancel it
241242
def make_server() -> Server:
242243
server = Server(name="TestSessionServer")
@@ -328,8 +329,7 @@ async def client_2_progress_callback(
328329
logger.info(f"client2: progress done: {progress}/{total}")
329330

330331
async def join_request(
331-
client_session: ClientSession,
332-
async_token: types.AsyncToken
332+
client_session: ClientSession, async_token: types.AsyncToken
333333
):
334334
return await client_session.send_request(
335335
ClientRequest(
@@ -365,6 +365,7 @@ async def get_result(client_session: ClientSession, async_token: types.AsyncToke
365365
token = None
366366

367367
async with anyio.create_task_group() as tg:
368+
368369
async def client_1_submit():
369370
async with create_connected_server_and_client_session(
370371
server
@@ -430,6 +431,7 @@ async def client_2_join():
430431
assert ev_client_1_progressed_2.is_set()
431432
assert ev_client_2_progressed_1.is_set()
432433

434+
433435
@pytest.mark.anyio
434436
async def test_connection_closed():
435437
"""

0 commit comments

Comments
 (0)