Skip to content

Commit 83d23ad

Browse files
committed
Fixed uses of async generators and removed the pytest warning ignore
1 parent 5fec84c commit 83d23ad

File tree

4 files changed

+37
-37
lines changed

4 files changed

+37
-37
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,5 +126,4 @@ filterwarnings = [
126126
"ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel",
127127
# This is to avoid test failures on Trio due to httpx's failure to explicitly close
128128
# async generators
129-
"ignore::pytest.PytestUnraisableExceptionWarning"
130129
]

src/mcp/client/streamable_http.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from contextlib import aclosing, asynccontextmanager
1212
from dataclasses import dataclass
1313
from datetime import timedelta
14-
from typing import cast
1514

1615
import anyio
1716
import httpx
@@ -241,15 +240,16 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
241240
event_source.response.raise_for_status()
242241
logger.debug("Resumption GET SSE connection established")
243242

244-
async for sse in event_source.aiter_sse():
245-
is_complete = await self._handle_sse_event(
246-
sse,
247-
ctx.read_stream_writer,
248-
original_request_id,
249-
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
250-
)
251-
if is_complete:
252-
break
243+
async with aclosing(event_source.aiter_sse()) as iterator:
244+
async for sse in iterator:
245+
is_complete = await self._handle_sse_event(
246+
sse,
247+
ctx.read_stream_writer,
248+
original_request_id,
249+
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
250+
)
251+
if is_complete:
252+
break
253253

254254
async def _handle_post_request(self, ctx: RequestContext) -> None:
255255
"""Handle a POST request with response processing."""
@@ -320,9 +320,7 @@ async def _handle_sse_response(
320320
) -> None:
321321
"""Handle SSE response from the server."""
322322
try:
323-
event_source = EventSource(response)
324-
sse_iter = cast(AsyncGenerator[ServerSentEvent], event_source.aiter_sse())
325-
async with aclosing(sse_iter) as items:
323+
async with aclosing(EventSource(response).aiter_sse()) as items:
326324
async for sse in items:
327325
is_complete = await self._handle_sse_event(
328326
sse,

tests/client/test_auth.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import base64
66
import hashlib
77
import time
8+
from contextlib import aclosing
89
from unittest.mock import AsyncMock, Mock, patch
910
from urllib.parse import parse_qs, urlparse
1011

@@ -654,17 +655,17 @@ async def test_async_auth_flow_401_response(self, oauth_provider, oauth_token):
654655
mock_response = Mock()
655656
mock_response.status_code = 401
656657

657-
auth_flow = oauth_provider.async_auth_flow(request)
658-
await auth_flow.__anext__()
658+
async with aclosing(oauth_provider.async_auth_flow(request)) as auth_flow:
659+
await auth_flow.__anext__()
659660

660-
# Send 401 response
661-
try:
662-
await auth_flow.asend(mock_response)
663-
except StopAsyncIteration:
664-
pass
661+
# Send 401 response
662+
try:
663+
await auth_flow.asend(mock_response)
664+
except StopAsyncIteration:
665+
pass
665666

666-
# Should clear current tokens
667-
assert oauth_provider._current_tokens is None
667+
# Should clear current tokens
668+
assert oauth_provider._current_tokens is None
668669

669670
@pytest.mark.anyio
670671
async def test_async_auth_flow_no_token(self, oauth_provider):
@@ -675,14 +676,14 @@ async def test_async_auth_flow_no_token(self, oauth_provider):
675676
patch.object(oauth_provider, "initialize") as mock_init,
676677
patch.object(oauth_provider, "ensure_token") as mock_ensure,
677678
):
678-
auth_flow = oauth_provider.async_auth_flow(request)
679-
updated_request = await auth_flow.__anext__()
679+
async with aclosing(oauth_provider.async_auth_flow(request)) as auth_flow:
680+
updated_request = await auth_flow.__anext__()
680681

681-
mock_init.assert_called_once()
682-
mock_ensure.assert_called_once()
682+
mock_init.assert_called_once()
683+
mock_ensure.assert_called_once()
683684

684-
# No Authorization header should be added if no token
685-
assert "Authorization" not in updated_request.headers
685+
# No Authorization header should be added if no token
686+
assert "Authorization" not in updated_request.headers
686687

687688
@pytest.mark.anyio
688689
async def test_scope_priority_client_metadata_first(self, oauth_provider, oauth_client_info):

tests/shared/test_sse.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import socket
44
import time
55
from collections.abc import AsyncGenerator, Generator
6+
from contextlib import aclosing
67

78
import anyio
89
import httpx
@@ -160,14 +161,15 @@ async def connection_test() -> None:
160161
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
161162

162163
line_number = 0
163-
async for line in response.aiter_lines():
164-
if line_number == 0:
165-
assert line == "event: endpoint"
166-
elif line_number == 1:
167-
assert line.startswith("data: /messages/?session_id=")
168-
else:
169-
return
170-
line_number += 1
164+
async with aclosing(response.aiter_lines()) as lines:
165+
async for line in lines:
166+
if line_number == 0:
167+
assert line == "event: endpoint"
168+
elif line_number == 1:
169+
assert line.startswith("data: /messages/?session_id=")
170+
else:
171+
return
172+
line_number += 1
171173

172174
# Add timeout to prevent test from hanging if it fails
173175
with anyio.fail_after(3):

0 commit comments

Comments
 (0)