Skip to content

Commit c3fc092

Browse files
fix: test + formatting
1 parent 0512742 commit c3fc092

File tree

2 files changed

+26
-24
lines changed

2 files changed

+26
-24
lines changed

src/mcp/shared/session.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,31 @@
77

88
import anyio
99
import httpx
10-
from anyio.streams.memory import (MemoryObjectReceiveStream,
11-
MemoryObjectSendStream)
10+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1211
from pydantic import BaseModel, ValidationError
1312
from typing_extensions import Self
1413

1514
from mcp.shared.exceptions import McpError
16-
from mcp.shared.message import (MessageMetadata, ServerMessageMetadata,
17-
SessionMessage)
18-
from mcp.types import (CONNECTION_CLOSED, INVALID_PARAMS,
19-
CancelledNotification, ClientNotification,
20-
ClientRequest, ClientResult, ErrorData, JSONRPCError,
21-
JSONRPCMessage, JSONRPCNotification, JSONRPCRequest,
22-
JSONRPCResponse, ProgressNotification, RequestParams,
23-
ServerNotification, ServerRequest, ServerResult)
15+
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
16+
from mcp.types import (
17+
CONNECTION_CLOSED,
18+
INVALID_PARAMS,
19+
CancelledNotification,
20+
ClientNotification,
21+
ClientRequest,
22+
ClientResult,
23+
ErrorData,
24+
JSONRPCError,
25+
JSONRPCMessage,
26+
JSONRPCNotification,
27+
JSONRPCRequest,
28+
JSONRPCResponse,
29+
ProgressNotification,
30+
RequestParams,
31+
ServerNotification,
32+
ServerRequest,
33+
ServerResult,
34+
)
2435

2536
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
2637
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)

tests/shared/test_progress_notifications.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -335,18 +335,14 @@ def mock_log_error(msg, *args):
335335
logged_errors.append(msg % args if args else msg)
336336

337337
# Create a progress callback that raises an exception
338-
async def failing_progress_callback(
339-
progress: float, total: float | None, message: str | None
340-
) -> None:
338+
async def failing_progress_callback(progress: float, total: float | None, message: str | None) -> None:
341339
raise ValueError("Progress callback failed!")
342340

343341
# Create a server with a tool that sends progress notifications
344342
server = Server(name="TestProgressServer")
345343

346344
@server.call_tool()
347-
async def handle_call_tool(
348-
name: str, arguments: dict | None
349-
) -> list[types.TextContent]:
345+
async def handle_call_tool(name: str, arguments: dict | None) -> list[types.TextContent]:
350346
if name == "progress_tool":
351347
# Send a progress notification
352348
await server.request_context.session.send_progress_notification(
@@ -376,9 +372,7 @@ async def handle_list_tools() -> list[types.Tool]:
376372
types.ClientRequest(
377373
types.CallToolRequest(
378374
method="tools/call",
379-
params=types.CallToolRequestParams(
380-
name="progress_tool", arguments={}
381-
),
375+
params=types.CallToolRequestParams(name="progress_tool", arguments={}),
382376
)
383377
),
384378
types.CallToolResult,
@@ -388,12 +382,9 @@ async def handle_list_tools() -> list[types.Tool]:
388382
# Verify the request completed successfully despite the callback failure
389383
assert len(result.content) == 1
390384
content = result.content[0]
391-
assert isinstance(content, TextContent)
385+
assert isinstance(content, types.TextContent)
392386
assert content.text == "progress_result"
393387

394388
# Check that a warning was logged for the progress callback exception
395389
assert len(logged_errors) > 0
396-
assert any(
397-
"Progress callback raised an exception" in warning
398-
for warning in logged_errors
399-
)
390+
assert any("Progress callback raised an exception" in warning for warning in logged_errors)

0 commit comments

Comments
 (0)