From f164291483eed8bb36415c080e0f3e2bd7efee28 Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 09:22:05 -0500 Subject: [PATCH 01/16] trying to test SSE --- tests/client/test_sse_attempt.py | 82 ++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 tests/client/test_sse_attempt.py diff --git a/tests/client/test_sse_attempt.py b/tests/client/test_sse_attempt.py new file mode 100644 index 000000000..8df152b5e --- /dev/null +++ b/tests/client/test_sse_attempt.py @@ -0,0 +1,82 @@ +import pytest +import anyio +from starlette.applications import Starlette +from starlette.routing import Mount, Route +import uvicorn +from mcp.client.sse import sse_client +from exceptiongroup import ExceptionGroup +import asyncio +import httpx +from httpx import ReadTimeout + +from mcp.server.sse import SseServerTransport + +@pytest.fixture +async def sse_server(): + + # Create an SSE transport at an endpoint + sse = SseServerTransport("/messages/") + + # Create Starlette routes for SSE and message handling + routes = [ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse.handle_post_message), + ] + # + # Create and run Starlette app + app = Starlette(routes=routes) + + # Define handler functions + async def handle_sse(request): + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: + await app.run( + streams[0], streams[1], app.create_initialization_options() + ) + + uvicorn.run(app, host="127.0.0.1", port=34891) + + async def sse_handler(request): + response = httpx.Response(200, content_type="text/event-stream") + response.send_headers() + response.write("data: test\n\n") + await response.aclose() + + async with httpx.AsyncServer(sse_handler) as server: + yield server.url + + +@pytest.fixture +async def sse_client(): + async with sse_client("http://test/sse") as (read_stream, write_stream): + async with read_stream: + async for message in read_stream: + if isinstance(message, Exception): + raise message + + return read_stream, write_stream + +@pytest.mark.anyio +async def test_sse_happy_path(monkeypatch): + # Mock httpx.AsyncClient to return our mock response + monkeypatch.setattr(httpx, "AsyncClient", MockClient) + + with pytest.raises(ReadTimeout) as exc_info: + async with sse_client( + "http://test/sse", + timeout=5, # Connection timeout - make this longer + sse_read_timeout=1 # Read timeout - this should trigger + ) as (read_stream, write_stream): + async with read_stream: + async for message in read_stream: + if isinstance(message, Exception): + raise message + + error = exc_info.value + assert isinstance(error, ReadTimeout) + assert str(error) == "Read timeout" + +@pytest.mark.anyio +async def test_sse_read_timeouts(monkeypatch): + """Test that the SSE client properly handles read timeouts between SSE messages.""" From b0a6aafaf698ff19d9b73b7477c20c1fcac5bff0 Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 09:24:44 -0500 Subject: [PATCH 02/16] WIP --- src/mcp/client/sse.py | 19 +++- tests/client/test_sse_attempt.py | 151 ++++++++++++++++++++----------- 2 files changed, 116 insertions(+), 54 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index abafacb96..e09f6c5bf 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -24,12 +24,20 @@ async def sse_client( headers: dict[str, Any] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5, + client: httpx.AsyncClient | None = None, ): """ Client transport for SSE. `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. + + Args: + url: The URL to connect to + headers: Optional headers to send with the request + timeout: Connection timeout in seconds + sse_read_timeout: Read timeout in seconds + client: Optional httpx.AsyncClient instance to use for requests """ read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] @@ -43,7 +51,13 @@ async def sse_client( async with anyio.create_task_group() as tg: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with httpx.AsyncClient(headers=headers) as client: + if client is None: + client = httpx.AsyncClient(headers=headers) + should_close_client = True + else: + should_close_client = False + + try: async with aconnect_sse( client, "GET", @@ -137,6 +151,9 @@ async def post_writer(endpoint_url: str): yield read_stream, write_stream finally: tg.cancel_scope.cancel() + finally: + if should_close_client: + await client.aclose() finally: await read_stream_writer.aclose() await write_stream.aclose() diff --git a/tests/client/test_sse_attempt.py b/tests/client/test_sse_attempt.py index 8df152b5e..400546701 100644 --- a/tests/client/test_sse_attempt.py +++ b/tests/client/test_sse_attempt.py @@ -1,82 +1,127 @@ -import pytest import anyio +import pytest from starlette.applications import Starlette from starlette.routing import Mount, Route -import uvicorn -from mcp.client.sse import sse_client -from exceptiongroup import ExceptionGroup -import asyncio import httpx -from httpx import ReadTimeout +from httpx import ReadTimeout, ASGITransport +from mcp.client.sse import sse_client from mcp.server.sse import SseServerTransport +from mcp.types import JSONRPCMessage + @pytest.fixture -async def sse_server(): +async def sse_transport(): + """Fixture that creates an SSE transport instance.""" + return SseServerTransport("/messages/") - # Create an SSE transport at an endpoint - sse = SseServerTransport("/messages/") - # Create Starlette routes for SSE and message handling +@pytest.fixture +async def sse_app(sse_transport): + """Fixture that creates a Starlette app with SSE endpoints.""" + async def handle_sse(request): + """Handler for SSE connections.""" + async with sse_transport.connect_sse( + request.scope, request.receive, request._send + ) as streams: + client_to_server, server_to_client = streams + async for message in client_to_server: + # Echo messages back for testing + await server_to_client.send(message) + routes = [ Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), + Mount("/messages", app=sse_transport.handle_post_message), ] - # - # Create and run Starlette app - app = Starlette(routes=routes) - # Define handler functions - async def handle_sse(request): - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: - await app.run( - streams[0], streams[1], app.create_initialization_options() - ) + return Starlette(routes=routes) - uvicorn.run(app, host="127.0.0.1", port=34891) - async def sse_handler(request): - response = httpx.Response(200, content_type="text/event-stream") - response.send_headers() - response.write("data: test\n\n") - await response.aclose() +@pytest.fixture +async def test_client(sse_app): + """Create a test client with ASGI transport.""" + async with httpx.AsyncClient( + transport=ASGITransport(app=sse_app), + base_url="http://testserver", + ) as client: + yield client - async with httpx.AsyncServer(sse_handler) as server: - yield server.url +@pytest.mark.anyio +async def test_sse_connection(test_client): + """Test basic SSE connection and message exchange.""" + async with sse_client( + "http://testserver/sse", + headers={"Host": "testserver"}, + timeout=5, + client=test_client, + ) as (read_stream, write_stream): + # Send a test message + test_message = JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": "test"}) + await write_stream.send(test_message) -@pytest.fixture -async def sse_client(): - async with sse_client("http://test/sse") as (read_stream, write_stream): + # Receive echoed message async with read_stream: - async for message in read_stream: - if isinstance(message, Exception): - raise message + message = await read_stream.__anext__() + assert isinstance(message, JSONRPCMessage) + assert message.model_dump() == test_message.model_dump() - return read_stream, write_stream @pytest.mark.anyio -async def test_sse_happy_path(monkeypatch): - # Mock httpx.AsyncClient to return our mock response - monkeypatch.setattr(httpx, "AsyncClient", MockClient) - - with pytest.raises(ReadTimeout) as exc_info: +async def test_sse_read_timeout(test_client): + """Test that SSE client properly handles read timeouts.""" + with pytest.raises(ReadTimeout): async with sse_client( - "http://test/sse", - timeout=5, # Connection timeout - make this longer - sse_read_timeout=1 # Read timeout - this should trigger + "http://testserver/sse", + headers={"Host": "testserver"}, + timeout=5, + sse_read_timeout=1, + client=test_client, ) as (read_stream, write_stream): async with read_stream: - async for message in read_stream: - if isinstance(message, Exception): - raise message + # This should timeout since no messages are being sent + await read_stream.__anext__() + + +@pytest.mark.anyio +async def test_sse_connection_error(test_client): + """Test SSE client behavior with connection errors.""" + with pytest.raises(httpx.HTTPError): + async with sse_client( + "http://testserver/nonexistent", + headers={"Host": "testserver"}, + timeout=5, + client=test_client, + ): + pass # Should not reach here - error = exc_info.value - assert isinstance(error, ReadTimeout) - assert str(error) == "Read timeout" @pytest.mark.anyio -async def test_sse_read_timeouts(monkeypatch): - """Test that the SSE client properly handles read timeouts between SSE messages.""" +async def test_sse_multiple_messages(test_client): + """Test sending and receiving multiple SSE messages.""" + async with sse_client( + "http://testserver/sse", + headers={"Host": "testserver"}, + timeout=5, + client=test_client, + ) as (read_stream, write_stream): + # Send multiple test messages + messages = [ + JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": f"test{i}"}) + for i in range(3) + ] + + for msg in messages: + await write_stream.send(msg) + + # Receive all echoed messages + received = [] + async with read_stream: + for _ in range(len(messages)): + message = await read_stream.__anext__() + assert isinstance(message, JSONRPCMessage) + received.append(message) + + # Verify all messages were received in order + for sent, received in zip(messages, received): + assert sent.model_dump() == received.model_dump() \ No newline at end of file From 3f9f7c83110b6dc1fc2916aa7dcc0205e1af0746 Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 09:27:42 -0500 Subject: [PATCH 03/16] WIP --- tests/client/test_sse_attempt.py | 48 +++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/tests/client/test_sse_attempt.py b/tests/client/test_sse_attempt.py index 400546701..1f856bbc8 100644 --- a/tests/client/test_sse_attempt.py +++ b/tests/client/test_sse_attempt.py @@ -4,6 +4,8 @@ from starlette.routing import Mount, Route import httpx from httpx import ReadTimeout, ASGITransport +from starlette.responses import Response +from sse_starlette.sse import EventSourceResponse from mcp.client.sse import sse_client from mcp.server.sse import SseServerTransport @@ -21,17 +23,33 @@ async def sse_app(sse_transport): """Fixture that creates a Starlette app with SSE endpoints.""" async def handle_sse(request): """Handler for SSE connections.""" - async with sse_transport.connect_sse( - request.scope, request.receive, request._send - ) as streams: - client_to_server, server_to_client = streams - async for message in client_to_server: - # Echo messages back for testing - await server_to_client.send(message) + async def event_generator(): + # Send initial connection event + yield { + "event": "endpoint", + "data": "/messages", + } + + # Keep connection alive + async with sse_transport.connect_sse( + request.scope, request.receive, request._send + ) as streams: + client_to_server, server_to_client = streams + async for message in client_to_server: + yield { + "event": "message", + "data": message.model_dump_json(), + } + + return EventSourceResponse(event_generator()) + + async def handle_post(request): + """Handler for POST messages.""" + return Response(status_code=200) routes = [ Route("/sse", endpoint=handle_sse), - Mount("/messages", app=sse_transport.handle_post_message), + Route("/messages", endpoint=handle_post, methods=["POST"]), ] return Starlette(routes=routes) @@ -40,9 +58,11 @@ async def handle_sse(request): @pytest.fixture async def test_client(sse_app): """Create a test client with ASGI transport.""" + transport = ASGITransport(app=sse_app) async with httpx.AsyncClient( - transport=ASGITransport(app=sse_app), + transport=transport, base_url="http://testserver", + timeout=5.0, ) as client: yield client @@ -53,7 +73,8 @@ async def test_sse_connection(test_client): async with sse_client( "http://testserver/sse", headers={"Host": "testserver"}, - timeout=5, + timeout=2, + sse_read_timeout=1, client=test_client, ) as (read_stream, write_stream): # Send a test message @@ -74,7 +95,7 @@ async def test_sse_read_timeout(test_client): async with sse_client( "http://testserver/sse", headers={"Host": "testserver"}, - timeout=5, + timeout=2, sse_read_timeout=1, client=test_client, ) as (read_stream, write_stream): @@ -90,7 +111,7 @@ async def test_sse_connection_error(test_client): async with sse_client( "http://testserver/nonexistent", headers={"Host": "testserver"}, - timeout=5, + timeout=2, client=test_client, ): pass # Should not reach here @@ -102,7 +123,8 @@ async def test_sse_multiple_messages(test_client): async with sse_client( "http://testserver/sse", headers={"Host": "testserver"}, - timeout=5, + timeout=2, + sse_read_timeout=1, client=test_client, ) as (read_stream, write_stream): # Send multiple test messages From a0e2f7fab793fff3bba8c14b4e30f94aa42d985f Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 09:38:06 -0500 Subject: [PATCH 04/16] WIP --- tests/client/test_sse_attempt.py | 218 +++++++++++++++++-------------- 1 file changed, 121 insertions(+), 97 deletions(-) diff --git a/tests/client/test_sse_attempt.py b/tests/client/test_sse_attempt.py index 1f856bbc8..7d7329145 100644 --- a/tests/client/test_sse_attempt.py +++ b/tests/client/test_sse_attempt.py @@ -1,4 +1,5 @@ import anyio +import asyncio import pytest from starlette.applications import Starlette from starlette.routing import Mount, Route @@ -24,32 +25,42 @@ async def sse_app(sse_transport): async def handle_sse(request): """Handler for SSE connections.""" async def event_generator(): - # Send initial connection event - yield { - "event": "endpoint", - "data": "/messages", - } - - # Keep connection alive - async with sse_transport.connect_sse( - request.scope, request.receive, request._send - ) as streams: - client_to_server, server_to_client = streams - async for message in client_to_server: + try: + async with sse_transport.connect_sse( + request.scope, request.receive, request._send + ) as streams: + client_to_server, server_to_client = streams + # Send initial connection event yield { - "event": "message", - "data": message.model_dump_json(), + "event": "endpoint", + "data": "/messages", } - return EventSourceResponse(event_generator()) + # Process messages + async with anyio.create_task_group() as tg: + try: + async for message in client_to_server: + if isinstance(message, Exception): + break + yield { + "event": "message", + "data": message.model_dump_json(), + } + except (asyncio.CancelledError, GeneratorExit): + print('cancelled') + return + except Exception as e: + print("unhandled exception:", e) + return + except Exception: + # Log any unexpected errors but allow connection to close gracefully + pass - async def handle_post(request): - """Handler for POST messages.""" - return Response(status_code=200) + return EventSourceResponse(event_generator()) routes = [ Route("/sse", endpoint=handle_sse), - Route("/messages", endpoint=handle_post, methods=["POST"]), + Mount("/messages", app=sse_transport.handle_post_message), ] return Starlette(routes=routes) @@ -62,7 +73,7 @@ async def test_client(sse_app): async with httpx.AsyncClient( transport=transport, base_url="http://testserver", - timeout=5.0, + timeout=10.0, ) as client: yield client @@ -70,80 +81,93 @@ async def test_client(sse_app): @pytest.mark.anyio async def test_sse_connection(test_client): """Test basic SSE connection and message exchange.""" - async with sse_client( - "http://testserver/sse", - headers={"Host": "testserver"}, - timeout=2, - sse_read_timeout=1, - client=test_client, - ) as (read_stream, write_stream): - # Send a test message - test_message = JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": "test"}) - await write_stream.send(test_message) - - # Receive echoed message - async with read_stream: - message = await read_stream.__anext__() - assert isinstance(message, JSONRPCMessage) - assert message.model_dump() == test_message.model_dump() - - -@pytest.mark.anyio -async def test_sse_read_timeout(test_client): - """Test that SSE client properly handles read timeouts.""" - with pytest.raises(ReadTimeout): - async with sse_client( - "http://testserver/sse", - headers={"Host": "testserver"}, - timeout=2, - sse_read_timeout=1, - client=test_client, - ) as (read_stream, write_stream): - async with read_stream: - # This should timeout since no messages are being sent - await read_stream.__anext__() - - -@pytest.mark.anyio -async def test_sse_connection_error(test_client): - """Test SSE client behavior with connection errors.""" - with pytest.raises(httpx.HTTPError): - async with sse_client( - "http://testserver/nonexistent", - headers={"Host": "testserver"}, - timeout=2, - client=test_client, - ): - pass # Should not reach here - - -@pytest.mark.anyio -async def test_sse_multiple_messages(test_client): - """Test sending and receiving multiple SSE messages.""" - async with sse_client( - "http://testserver/sse", - headers={"Host": "testserver"}, - timeout=2, - sse_read_timeout=1, - client=test_client, - ) as (read_stream, write_stream): - # Send multiple test messages - messages = [ - JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": f"test{i}"}) - for i in range(3) - ] - - for msg in messages: - await write_stream.send(msg) - - # Receive all echoed messages - received = [] - async with read_stream: - for _ in range(len(messages)): - message = await read_stream.__anext__() - assert isinstance(message, JSONRPCMessage) - received.append(message) - - # Verify all messages were received in order - for sent, received in zip(messages, received): - assert sent.model_dump() == received.model_dump() \ No newline at end of file + async with anyio.create_task_group() as tg: + try: + async with sse_client( + "http://testserver/sse", + headers={"Host": "testserver"}, + timeout=5, + sse_read_timeout=5, + client=test_client, + ) as (read_stream, write_stream): + # First get the initial endpoint message + async with read_stream: + init_message = await read_stream.__anext__() + assert isinstance(init_message, JSONRPCMessage) + + # Send a test message + test_message = JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": "test"}) + await write_stream.send(test_message) + + # Receive echoed message + async with read_stream: + message = await read_stream.__anext__() + assert isinstance(message, JSONRPCMessage) + assert message.model_dump() == test_message.model_dump() + + # Explicitly close streams + await write_stream.aclose() + await read_stream.aclose() + except Exception as e: + pytest.fail(f"Test failed with error: {str(e)}") + + +# @pytest.mark.anyio +# async def test_sse_read_timeout(test_client): +# """Test that SSE client properly handles read timeouts.""" +# with pytest.raises(ReadTimeout): +# async with sse_client( +# "http://testserver/sse", +# headers={"Host": "testserver"}, +# timeout=5, +# sse_read_timeout=2, +# client=test_client, +# ) as (read_stream, write_stream): +# async with read_stream: +# # This should timeout since no messages are being sent +# await read_stream.__anext__() + + +# @pytest.mark.anyio +# async def test_sse_connection_error(test_client): +# """Test SSE client behavior with connection errors.""" +# with pytest.raises(httpx.HTTPError): +# async with sse_client( +# "http://testserver/nonexistent", +# headers={"Host": "testserver"}, +# timeout=5, +# client=test_client, +# ): +# pass # Should not reach here + + +# @pytest.mark.anyio +# async def test_sse_multiple_messages(test_client): +# """Test sending and receiving multiple SSE messages.""" +# async with sse_client( +# "http://testserver/sse", +# headers={"Host": "testserver"}, +# timeout=5, +# sse_read_timeout=5, +# client=test_client, +# ) as (read_stream, write_stream): +# # Send multiple test messages +# messages = [ +# JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": f"test{i}"}) +# for i in range(3) +# ] + +# for msg in messages: +# await write_stream.send(msg) + +# # Receive all echoed messages +# received = [] +# async with read_stream: +# for _ in range(len(messages)): +# message = await read_stream.__anext__() +# assert isinstance(message, JSONRPCMessage) +# received.append(message) + +# # Verify all messages were received in order +# for sent, received in zip(messages, received): +# assert sent.model_dump() == received.model_dump() From 66ccd1c515814ff4631b69c2bf0d1916aada91e8 Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 10:18:11 -0500 Subject: [PATCH 05/16] test_sse_connection is passing --- tests/client/test_sse_attempt.py | 173 ---------------------------- tests/shared/test_sse.py | 188 +++++++++++++++++++++++++++++++ 2 files changed, 188 insertions(+), 173 deletions(-) delete mode 100644 tests/client/test_sse_attempt.py create mode 100644 tests/shared/test_sse.py diff --git a/tests/client/test_sse_attempt.py b/tests/client/test_sse_attempt.py deleted file mode 100644 index 7d7329145..000000000 --- a/tests/client/test_sse_attempt.py +++ /dev/null @@ -1,173 +0,0 @@ -import anyio -import asyncio -import pytest -from starlette.applications import Starlette -from starlette.routing import Mount, Route -import httpx -from httpx import ReadTimeout, ASGITransport -from starlette.responses import Response -from sse_starlette.sse import EventSourceResponse - -from mcp.client.sse import sse_client -from mcp.server.sse import SseServerTransport -from mcp.types import JSONRPCMessage - - -@pytest.fixture -async def sse_transport(): - """Fixture that creates an SSE transport instance.""" - return SseServerTransport("/messages/") - - -@pytest.fixture -async def sse_app(sse_transport): - """Fixture that creates a Starlette app with SSE endpoints.""" - async def handle_sse(request): - """Handler for SSE connections.""" - async def event_generator(): - try: - async with sse_transport.connect_sse( - request.scope, request.receive, request._send - ) as streams: - client_to_server, server_to_client = streams - # Send initial connection event - yield { - "event": "endpoint", - "data": "/messages", - } - - # Process messages - async with anyio.create_task_group() as tg: - try: - async for message in client_to_server: - if isinstance(message, Exception): - break - yield { - "event": "message", - "data": message.model_dump_json(), - } - except (asyncio.CancelledError, GeneratorExit): - print('cancelled') - return - except Exception as e: - print("unhandled exception:", e) - return - except Exception: - # Log any unexpected errors but allow connection to close gracefully - pass - - return EventSourceResponse(event_generator()) - - routes = [ - Route("/sse", endpoint=handle_sse), - Mount("/messages", app=sse_transport.handle_post_message), - ] - - return Starlette(routes=routes) - - -@pytest.fixture -async def test_client(sse_app): - """Create a test client with ASGI transport.""" - transport = ASGITransport(app=sse_app) - async with httpx.AsyncClient( - transport=transport, - base_url="http://testserver", - timeout=10.0, - ) as client: - yield client - - -@pytest.mark.anyio -async def test_sse_connection(test_client): - """Test basic SSE connection and message exchange.""" - async with anyio.create_task_group() as tg: - try: - async with sse_client( - "http://testserver/sse", - headers={"Host": "testserver"}, - timeout=5, - sse_read_timeout=5, - client=test_client, - ) as (read_stream, write_stream): - # First get the initial endpoint message - async with read_stream: - init_message = await read_stream.__anext__() - assert isinstance(init_message, JSONRPCMessage) - - # Send a test message - test_message = JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": "test"}) - await write_stream.send(test_message) - - # Receive echoed message - async with read_stream: - message = await read_stream.__anext__() - assert isinstance(message, JSONRPCMessage) - assert message.model_dump() == test_message.model_dump() - - # Explicitly close streams - await write_stream.aclose() - await read_stream.aclose() - except Exception as e: - pytest.fail(f"Test failed with error: {str(e)}") - - -# @pytest.mark.anyio -# async def test_sse_read_timeout(test_client): -# """Test that SSE client properly handles read timeouts.""" -# with pytest.raises(ReadTimeout): -# async with sse_client( -# "http://testserver/sse", -# headers={"Host": "testserver"}, -# timeout=5, -# sse_read_timeout=2, -# client=test_client, -# ) as (read_stream, write_stream): -# async with read_stream: -# # This should timeout since no messages are being sent -# await read_stream.__anext__() - - -# @pytest.mark.anyio -# async def test_sse_connection_error(test_client): -# """Test SSE client behavior with connection errors.""" -# with pytest.raises(httpx.HTTPError): -# async with sse_client( -# "http://testserver/nonexistent", -# headers={"Host": "testserver"}, -# timeout=5, -# client=test_client, -# ): -# pass # Should not reach here - - -# @pytest.mark.anyio -# async def test_sse_multiple_messages(test_client): -# """Test sending and receiving multiple SSE messages.""" -# async with sse_client( -# "http://testserver/sse", -# headers={"Host": "testserver"}, -# timeout=5, -# sse_read_timeout=5, -# client=test_client, -# ) as (read_stream, write_stream): -# # Send multiple test messages -# messages = [ -# JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": f"test{i}"}) -# for i in range(3) -# ] - -# for msg in messages: -# await write_stream.send(msg) - -# # Receive all echoed messages -# received = [] -# async with read_stream: -# for _ in range(len(messages)): -# message = await read_stream.__anext__() -# assert isinstance(message, JSONRPCMessage) -# received.append(message) - -# # Verify all messages were received in order -# for sent, received in zip(messages, received): -# assert sent.model_dump() == received.model_dump() diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py new file mode 100644 index 000000000..07a859f59 --- /dev/null +++ b/tests/shared/test_sse.py @@ -0,0 +1,188 @@ +# test_sse.py +import re +import time +import json +import anyio +import pytest +import httpx +from typing import AsyncGenerator +from starlette.applications import Starlette +from starlette.routing import Mount, Route + +from mcp.server import Server +from mcp.server.sse import SseServerTransport +from mcp.types import TextContent, Tool + +# Test server implementation +class TestServer(Server): + def __init__(self): + super().__init__("test_server") + + @self.list_tools() + async def handle_list_tools(): + return [ + Tool( + name="test_tool", + description="A test tool", + inputSchema={"type": "object", "properties": {}} + ) + ] + + @self.call_tool() + async def handle_call_tool(name: str, args: dict): + return [TextContent(type="text", text=f"Called {name}")] + +import threading +import uvicorn +import pytest + + +# Test fixtures +@pytest.fixture +async def server_app()-> Starlette: + """Create test Starlette app with SSE transport""" + sse = SseServerTransport("/messages/") + server = TestServer() + + async def handle_sse(request): + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: + await server.run( + streams[0], + streams[1], + server.create_initialization_options() + ) + + app = Starlette(routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse.handle_post_message), + ]) + + return app + +@pytest.fixture() +def server(server_app: Starlette): + server = uvicorn.Server(config=uvicorn.Config(app=server_app, host="127.0.0.1", port=8765, log_level="error")) + server_thread = threading.Thread( target=server.run, daemon=True ) + print('starting server') + server_thread.start() + # Give server time to start + while not server.started: + print('waiting for server to start') + time.sleep(0.5) + yield + print('killing server') + server_thread.join(timeout=0.1) + +@pytest.fixture() +async def client(server) -> AsyncGenerator[httpx.AsyncClient, None]: + """Create test client""" + async with httpx.AsyncClient(base_url="http://127.0.0.1:8765") as client: + yield client + +# Tests +@pytest.mark.anyio +async def test_sse_connection(client: httpx.AsyncClient): + """Test SSE connection establishment""" + async with anyio.create_task_group() as tg: + async def connection_test(): + async with client.stream("GET", "/sse") as response: + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + line_number = 0 + async for line in response.aiter_lines(): + if line_number == 0: + assert line == "event: endpoint" + elif line_number == 1: + assert line.startswith("data: /messages/?session_id=") + else: + return + line_number += 1 + + # Add timeout to prevent test from hanging if it fails + with anyio.fail_after(3): + await connection_test() + +@pytest.mark.anyio +async def test_message_exchange(client: httpx.AsyncClient): + """Test full message exchange flow""" + # Connect to SSE endpoint + session_id = None + endpoint_url = None + + async with client.stream("GET", "/sse") as sse_response: + assert sse_response.status_code == 200 + + # Get endpoint URL and session ID + async for line in sse_response.aiter_lines(): + if line.startswith("data: "): + endpoint_url = json.loads(line[6:]) + session_id = endpoint_url.split("session_id=")[1] + break + + assert endpoint_url and session_id + + # Send initialize request + init_request = { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { + "name": "test_client", + "version": "1.0" + } + } + } + + response = await client.post( + endpoint_url, + json=init_request + ) + assert response.status_code == 202 + + # Get initialize response from SSE stream + async for line in sse_response.aiter_lines(): + if line.startswith("event: message"): + data_line = next(sse_response.aiter_lines()) + response = json.loads(data_line[6:]) # Strip "data: " prefix + assert response["jsonrpc"] == "2.0" + assert response["id"] == 1 + assert "result" in response + break + +@pytest.mark.anyio +async def test_invalid_session(client: httpx.AsyncClient): + """Test sending message with invalid session ID""" + response = await client.post( + "/messages/?session_id=invalid", + json={"jsonrpc": "2.0", "method": "ping"} + ) + assert response.status_code == 400 + +@pytest.mark.anyio +async def test_connection_cleanup(server_app): + """Test that resources are cleaned up when client disconnects""" + sse = next( + route.app for route in server_app.routes + if isinstance(route, Mount) and route.path == "/messages/" + ).transport + + async with httpx.AsyncClient(app=server_app, base_url="http://test") as client: + # Connect and get session ID + async with client.stream("GET", "/sse") as response: + for line in response.iter_lines(): + if line.startswith("data: "): + endpoint_url = json.loads(line[6:]) + session_id = endpoint_url.split("session_id=")[1] + break + + assert len(sse._read_stream_writers) == 1 + + # After connection closes, writer should be cleaned up + await anyio.sleep(0.1) # Give cleanup a moment + assert len(sse._read_stream_writers) == 0 From e79a56435a9f6bfb84ffb5317501aceeda8ca48a Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 10:34:32 -0500 Subject: [PATCH 06/16] passing SSE client test --- tests/shared/test_sse.py | 133 +++++++++++++++------------------------ 1 file changed, 49 insertions(+), 84 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 07a859f59..ee3edb972 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -3,20 +3,36 @@ import time import json import anyio +from pydantic import AnyUrl +from pydantic_core import Url import pytest import httpx from typing import AsyncGenerator from starlette.applications import Starlette from starlette.routing import Mount, Route +from mcp.client.session import ClientSession +from mcp.client.sse import sse_client from mcp.server import Server from mcp.server.sse import SseServerTransport -from mcp.types import TextContent, Tool +from mcp.types import EmptyResult, InitializeResult, TextContent, TextResourceContents, Tool + +SERVER_URL = "http://127.0.0.1:8765" +SERVER_SSE_URL = f"{SERVER_URL}/sse" + +SERVER_NAME = "test_server_for_SSE" # Test server implementation class TestServer(Server): def __init__(self): - super().__init__("test_server") + super().__init__(SERVER_NAME) + + @self.read_resource() + async def handle_read_resource(uri: AnyUrl) -> str | bytes: + if uri.scheme == "foobar": + return f"Read {uri.host}" + # TODO: make this an error + return "NOT FOUND" @self.list_tools() async def handle_list_tools(): @@ -76,18 +92,18 @@ def server(server_app: Starlette): server_thread.join(timeout=0.1) @pytest.fixture() -async def client(server) -> AsyncGenerator[httpx.AsyncClient, None]: +async def http_client(server) -> AsyncGenerator[httpx.AsyncClient, None]: """Create test client""" - async with httpx.AsyncClient(base_url="http://127.0.0.1:8765") as client: + async with httpx.AsyncClient(base_url=SERVER_URL) as client: yield client # Tests @pytest.mark.anyio -async def test_sse_connection(client: httpx.AsyncClient): - """Test SSE connection establishment""" +async def test_raw_sse_connection(http_client: httpx.AsyncClient): + """Test the SSE connection establishment simply with an HTTP client.""" async with anyio.create_task_group() as tg: async def connection_test(): - async with client.stream("GET", "/sse") as response: + async with http_client.stream("GET", "/sse") as response: assert response.status_code == 200 assert response.headers["content-type"] == "text/event-stream; charset=utf-8" @@ -105,84 +121,33 @@ async def connection_test(): with anyio.fail_after(3): await connection_test() -@pytest.mark.anyio -async def test_message_exchange(client: httpx.AsyncClient): - """Test full message exchange flow""" - # Connect to SSE endpoint - session_id = None - endpoint_url = None - - async with client.stream("GET", "/sse") as sse_response: - assert sse_response.status_code == 200 - - # Get endpoint URL and session ID - async for line in sse_response.aiter_lines(): - if line.startswith("data: "): - endpoint_url = json.loads(line[6:]) - session_id = endpoint_url.split("session_id=")[1] - break - - assert endpoint_url and session_id - - # Send initialize request - init_request = { - "jsonrpc": "2.0", - "id": 1, - "method": "initialize", - "params": { - "protocolVersion": "2024-11-05", - "capabilities": {}, - "clientInfo": { - "name": "test_client", - "version": "1.0" - } - } - } - - response = await client.post( - endpoint_url, - json=init_request - ) - assert response.status_code == 202 - - # Get initialize response from SSE stream - async for line in sse_response.aiter_lines(): - if line.startswith("event: message"): - data_line = next(sse_response.aiter_lines()) - response = json.loads(data_line[6:]) # Strip "data: " prefix - assert response["jsonrpc"] == "2.0" - assert response["id"] == 1 - assert "result" in response - break @pytest.mark.anyio -async def test_invalid_session(client: httpx.AsyncClient): - """Test sending message with invalid session ID""" - response = await client.post( - "/messages/?session_id=invalid", - json={"jsonrpc": "2.0", "method": "ping"} - ) - assert response.status_code == 400 +async def test_sse_client_basic_connection(server): + async with sse_client(SERVER_SSE_URL) as streams: + async with ClientSession(*streams) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME + + # Test ping + ping_result = await session.send_ping() + assert isinstance(ping_result, EmptyResult) + +@pytest.fixture +async def initialized_sse_client_session(server) -> AsyncGenerator[ClientSession, None]: + async with sse_client(SERVER_SSE_URL) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + yield session @pytest.mark.anyio -async def test_connection_cleanup(server_app): - """Test that resources are cleaned up when client disconnects""" - sse = next( - route.app for route in server_app.routes - if isinstance(route, Mount) and route.path == "/messages/" - ).transport - - async with httpx.AsyncClient(app=server_app, base_url="http://test") as client: - # Connect and get session ID - async with client.stream("GET", "/sse") as response: - for line in response.iter_lines(): - if line.startswith("data: "): - endpoint_url = json.loads(line[6:]) - session_id = endpoint_url.split("session_id=")[1] - break - - assert len(sse._read_stream_writers) == 1 - - # After connection closes, writer should be cleaned up - await anyio.sleep(0.1) # Give cleanup a moment - assert len(sse._read_stream_writers) == 0 +async def test_sse_client_request_and_response(initialized_sse_client_session: ClientSession): + session = initialized_sse_client_session + # TODO: expect raise + await session.read_resource(uri=AnyUrl("xxx://will-not-work")) + response = await session.read_resource(uri=AnyUrl("foobar://should-work")) + assert len(response.contents) == 1 + assert isinstance(response.contents[0], TextResourceContents) + assert response.contents[0].text == "Read should-work" From 8f81a85abe3e4ca3c15c8911c356e9875dd2b23c Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 10:45:55 -0500 Subject: [PATCH 07/16] all tests passing with custom port, but not passing all together --- tests/shared/test_sse.py | 56 ++++++++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index ee3edb972..96c97581c 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -3,6 +3,9 @@ import time import json import anyio +import threading +import uvicorn +import pytest from pydantic import AnyUrl from pydantic_core import Url import pytest @@ -11,17 +14,29 @@ from starlette.applications import Starlette from starlette.routing import Mount, Route +from mcp.shared.exceptions import McpError from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.server import Server from mcp.server.sse import SseServerTransport -from mcp.types import EmptyResult, InitializeResult, TextContent, TextResourceContents, Tool - -SERVER_URL = "http://127.0.0.1:8765" -SERVER_SSE_URL = f"{SERVER_URL}/sse" +from mcp.types import EmptyResult, ErrorData, InitializeResult, TextContent, TextResourceContents, Tool SERVER_NAME = "test_server_for_SSE" +@pytest.fixture +def server_port() -> int: + import socket + + s = socket.socket() + s.bind(('', 0)) + port = s.getsockname()[1] + s.close() + return port + +@pytest.fixture +def server_url(server_port: int) -> str: + return f"http://127.0.0.1:{server_port}" + # Test server implementation class TestServer(Server): def __init__(self): @@ -32,7 +47,7 @@ async def handle_read_resource(uri: AnyUrl) -> str | bytes: if uri.scheme == "foobar": return f"Read {uri.host}" # TODO: make this an error - return "NOT FOUND" + raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found")) @self.list_tools() async def handle_list_tools(): @@ -48,9 +63,6 @@ async def handle_list_tools(): async def handle_call_tool(name: str, args: dict): return [TextContent(type="text", text=f"Called {name}")] -import threading -import uvicorn -import pytest # Test fixtures @@ -78,10 +90,10 @@ async def handle_sse(request): return app @pytest.fixture() -def server(server_app: Starlette): - server = uvicorn.Server(config=uvicorn.Config(app=server_app, host="127.0.0.1", port=8765, log_level="error")) +def server(server_app: Starlette, server_port: int): + server = uvicorn.Server(config=uvicorn.Config(app=server_app, host="127.0.0.1", port=server_port, log_level="error")) server_thread = threading.Thread( target=server.run, daemon=True ) - print('starting server') + print(f'starting server on {server_port}') server_thread.start() # Give server time to start while not server.started: @@ -92,9 +104,9 @@ def server(server_app: Starlette): server_thread.join(timeout=0.1) @pytest.fixture() -async def http_client(server) -> AsyncGenerator[httpx.AsyncClient, None]: +async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, None]: """Create test client""" - async with httpx.AsyncClient(base_url=SERVER_URL) as client: + async with httpx.AsyncClient(base_url=server_url) as client: yield client # Tests @@ -123,8 +135,8 @@ async def connection_test(): @pytest.mark.anyio -async def test_sse_client_basic_connection(server): - async with sse_client(SERVER_SSE_URL) as streams: +async def test_sse_client_basic_connection(server, server_url): + async with sse_client(server_url + "/sse") as streams: async with ClientSession(*streams) as session: # Test initialization result = await session.initialize() @@ -136,18 +148,22 @@ async def test_sse_client_basic_connection(server): assert isinstance(ping_result, EmptyResult) @pytest.fixture -async def initialized_sse_client_session(server) -> AsyncGenerator[ClientSession, None]: - async with sse_client(SERVER_SSE_URL) as streams: +async def initialized_sse_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]: + async with sse_client(server_url + "/sse") as streams: async with ClientSession(*streams) as session: await session.initialize() yield session @pytest.mark.anyio -async def test_sse_client_request_and_response(initialized_sse_client_session: ClientSession): +async def test_sse_client_happy_request_and_response(initialized_sse_client_session: ClientSession): session = initialized_sse_client_session - # TODO: expect raise - await session.read_resource(uri=AnyUrl("xxx://will-not-work")) response = await session.read_resource(uri=AnyUrl("foobar://should-work")) assert len(response.contents) == 1 assert isinstance(response.contents[0], TextResourceContents) assert response.contents[0].text == "Read should-work" + +@pytest.mark.anyio +async def test_sse_client_exception_handling(initialized_sse_client_session: ClientSession): + session = initialized_sse_client_session + with pytest.raises(McpError, match="OOPS! no resource with that URI was found"): + await session.read_resource(uri=AnyUrl("xxx://will-not-work")) From 7ab1fc71aa9447bf63fa837e9fad049d00d9a488 Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 11:01:43 -0500 Subject: [PATCH 08/16] attempt to get server to shut down --- tests/shared/test_sse.py | 44 +++++++++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 96c97581c..bab41a8b9 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,5 +1,6 @@ # test_sse.py import re +import socket import time import json import anyio @@ -25,13 +26,9 @@ @pytest.fixture def server_port() -> int: - import socket - - s = socket.socket() - s.bind(('', 0)) - port = s.getsockname()[1] - s.close() - return port + with socket.socket() as s: + s.bind(('127.0.0.1', 0)) + return s.getsockname()[1] @pytest.fixture def server_url(server_port: int) -> str: @@ -89,6 +86,12 @@ async def handle_sse(request): return app +@pytest.fixture(autouse=True) +def space_around_test(): + time.sleep(0.1) + yield + time.sleep(0.1) + @pytest.fixture() def server(server_app: Starlette, server_port: int): server = uvicorn.Server(config=uvicorn.Config(app=server_app, host="127.0.0.1", port=server_port, log_level="error")) @@ -99,9 +102,27 @@ def server(server_app: Starlette, server_port: int): while not server.started: print('waiting for server to start') time.sleep(0.5) - yield - print('killing server') - server_thread.join(timeout=0.1) + + try: + yield + finally: + print('killing server') + # Signal the server to stop + server.should_exit = True + + # Force close the server's main socket + if hasattr(server.servers, "servers"): + for s in server.servers: + print(f'closing {s}') + s.close() + + # Wait for thread to finish + server_thread.join(timeout=2) + if server_thread.is_alive(): + print("Warning: Server thread did not exit cleanly") + # Optionally, you could add more aggressive cleanup here + import _thread + _thread.interrupt_main() @pytest.fixture() async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, None]: @@ -167,3 +188,6 @@ async def test_sse_client_exception_handling(initialized_sse_client_session: Cli session = initialized_sse_client_session with pytest.raises(McpError, match="OOPS! no resource with that URI was found"): await session.read_resource(uri=AnyUrl("xxx://will-not-work")) + + +# TODO: test that timeouts are respected and that the error comes back From 8d90a3afa3fa6bbcb3310dfef82aba9f23e3e8e5 Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 11:15:05 -0500 Subject: [PATCH 09/16] attempt at server in process --- tests/shared/test_sse.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index bab41a8b9..03396f67e 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,5 +1,6 @@ # test_sse.py import re +import multiprocessing import socket import time import json @@ -94,10 +95,15 @@ def space_around_test(): @pytest.fixture() def server(server_app: Starlette, server_port: int): - server = uvicorn.Server(config=uvicorn.Config(app=server_app, host="127.0.0.1", port=server_port, log_level="error")) - server_thread = threading.Thread( target=server.run, daemon=True ) + proc = multiprocessing.Process(target=uvicorn.run, daemon=True, kwargs={ + "app": server_app, + "host": "127.0.0.1", + "port": server_port, + "log_level": "error" + }) print(f'starting server on {server_port}') - server_thread.start() + proc.start() + # Give server time to start while not server.started: print('waiting for server to start') @@ -117,8 +123,9 @@ def server(server_app: Starlette, server_port: int): s.close() # Wait for thread to finish - server_thread.join(timeout=2) - if server_thread.is_alive(): + proc.terminate() + proc.join(timeout=2) + if proc.is_alive(): print("Warning: Server thread did not exit cleanly") # Optionally, you could add more aggressive cleanup here import _thread From 7b35ab844aea95b8529262aa13ef03ab7456822f Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 11:44:03 -0500 Subject: [PATCH 10/16] run server in separate process --- tests/shared/test_sse.py | 66 +++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 32 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 03396f67e..30c15ac9c 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -64,8 +64,7 @@ async def handle_call_tool(name: str, args: dict): # Test fixtures -@pytest.fixture -async def server_app()-> Starlette: +def make_server_app()-> Starlette: """Create test Starlette app with SSE transport""" sse = SseServerTransport("/messages/") server = TestServer() @@ -93,43 +92,46 @@ def space_around_test(): yield time.sleep(0.1) -@pytest.fixture() -def server(server_app: Starlette, server_port: int): - proc = multiprocessing.Process(target=uvicorn.run, daemon=True, kwargs={ - "app": server_app, - "host": "127.0.0.1", - "port": server_port, - "log_level": "error" - }) +def run_server(server_port: int): + app = make_server_app() + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f'starting server on {server_port}') - proc.start() + server.run() # Give server time to start while not server.started: print('waiting for server to start') time.sleep(0.5) - try: - yield - finally: - print('killing server') - # Signal the server to stop - server.should_exit = True - - # Force close the server's main socket - if hasattr(server.servers, "servers"): - for s in server.servers: - print(f'closing {s}') - s.close() - - # Wait for thread to finish - proc.terminate() - proc.join(timeout=2) - if proc.is_alive(): - print("Warning: Server thread did not exit cleanly") - # Optionally, you could add more aggressive cleanup here - import _thread - _thread.interrupt_main() +@pytest.fixture() +def server(server_port: int): + proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) + print('starting process') + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print('waiting for server to start') + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(('127.0.0.1', server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError("Server failed to start after {} attempts".format(max_attempts)) + + yield + + print('killing server') + # Signal the server to stop + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("server process failed to terminate") @pytest.fixture() async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, None]: From 5097bb7ef8892a73d559f3e6de76bb9593985de8 Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 11:47:00 -0500 Subject: [PATCH 11/16] revert unintended changes --- src/mcp/client/sse.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index e09f6c5bf..abafacb96 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -24,20 +24,12 @@ async def sse_client( headers: dict[str, Any] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5, - client: httpx.AsyncClient | None = None, ): """ Client transport for SSE. `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. - - Args: - url: The URL to connect to - headers: Optional headers to send with the request - timeout: Connection timeout in seconds - sse_read_timeout: Read timeout in seconds - client: Optional httpx.AsyncClient instance to use for requests """ read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] @@ -51,13 +43,7 @@ async def sse_client( async with anyio.create_task_group() as tg: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") - if client is None: - client = httpx.AsyncClient(headers=headers) - should_close_client = True - else: - should_close_client = False - - try: + async with httpx.AsyncClient(headers=headers) as client: async with aconnect_sse( client, "GET", @@ -151,9 +137,6 @@ async def post_writer(endpoint_url: str): yield read_stream, write_stream finally: tg.cancel_scope.cancel() - finally: - if should_close_client: - await client.aclose() finally: await read_stream_writer.aclose() await write_stream.aclose() From 07e721f63f0f45b40e3fa9cdbf631b380ac815fb Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 11:47:41 -0500 Subject: [PATCH 12/16] formatting --- tests/shared/test_sse.py | 92 ++++++++++++++++++++++++++++------------ 1 file changed, 65 insertions(+), 27 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 30c15ac9c..8f9221a8e 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,4 +1,3 @@ -# test_sse.py import re import multiprocessing import socket @@ -21,20 +20,30 @@ from mcp.client.sse import sse_client from mcp.server import Server from mcp.server.sse import SseServerTransport -from mcp.types import EmptyResult, ErrorData, InitializeResult, TextContent, TextResourceContents, Tool +from mcp.types import ( + EmptyResult, + ErrorData, + InitializeResult, + TextContent, + TextResourceContents, + Tool, +) SERVER_NAME = "test_server_for_SSE" + @pytest.fixture def server_port() -> int: with socket.socket() as s: - s.bind(('127.0.0.1', 0)) + s.bind(("127.0.0.1", 0)) return s.getsockname()[1] + @pytest.fixture def server_url(server_port: int) -> str: return f"http://127.0.0.1:{server_port}" + # Test server implementation class TestServer(Server): def __init__(self): @@ -45,7 +54,11 @@ async def handle_read_resource(uri: AnyUrl) -> str | bytes: if uri.scheme == "foobar": return f"Read {uri.host}" # TODO: make this an error - raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found")) + raise McpError( + error=ErrorData( + code=404, message="OOPS! no resource with that URI was found" + ) + ) @self.list_tools() async def handle_list_tools(): @@ -53,7 +66,7 @@ async def handle_list_tools(): Tool( name="test_tool", description="A test tool", - inputSchema={"type": "object", "properties": {}} + inputSchema={"type": "object", "properties": {}}, ) ] @@ -62,9 +75,8 @@ async def handle_call_tool(name: str, args: dict): return [TextContent(type="text", text=f"Called {name}")] - # Test fixtures -def make_server_app()-> Starlette: +def make_server_app() -> Starlette: """Create test Starlette app with SSE transport""" sse = SseServerTransport("/messages/") server = TestServer() @@ -74,80 +86,97 @@ async def handle_sse(request): request.scope, request.receive, request._send ) as streams: await server.run( - streams[0], - streams[1], - server.create_initialization_options() + streams[0], streams[1], server.create_initialization_options() ) - app = Starlette(routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ]) + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse.handle_post_message), + ] + ) return app + @pytest.fixture(autouse=True) def space_around_test(): time.sleep(0.1) yield time.sleep(0.1) + def run_server(server_port: int): app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f'starting server on {server_port}') + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"starting server on {server_port}") server.run() # Give server time to start while not server.started: - print('waiting for server to start') + print("waiting for server to start") time.sleep(0.5) + @pytest.fixture() def server(server_port: int): - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) - print('starting process') + proc = multiprocessing.Process( + target=run_server, kwargs={"server_port": server_port}, daemon=True + ) + print("starting process") proc.start() # Wait for server to be running max_attempts = 20 attempt = 0 - print('waiting for server to start') + print("waiting for server to start") while attempt < max_attempts: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(('127.0.0.1', server_port)) + s.connect(("127.0.0.1", server_port)) break except ConnectionRefusedError: time.sleep(0.1) attempt += 1 else: - raise RuntimeError("Server failed to start after {} attempts".format(max_attempts)) + raise RuntimeError( + "Server failed to start after {} attempts".format(max_attempts) + ) yield - print('killing server') + print("killing server") # Signal the server to stop proc.kill() proc.join(timeout=2) if proc.is_alive(): print("server process failed to terminate") + @pytest.fixture() async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, None]: """Create test client""" async with httpx.AsyncClient(base_url=server_url) as client: yield client + # Tests @pytest.mark.anyio async def test_raw_sse_connection(http_client: httpx.AsyncClient): """Test the SSE connection establishment simply with an HTTP client.""" async with anyio.create_task_group() as tg: + async def connection_test(): async with http_client.stream("GET", "/sse") as response: assert response.status_code == 200 - assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + assert ( + response.headers["content-type"] + == "text/event-stream; charset=utf-8" + ) line_number = 0 async for line in response.aiter_lines(): @@ -177,23 +206,32 @@ async def test_sse_client_basic_connection(server, server_url): ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) + @pytest.fixture -async def initialized_sse_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]: +async def initialized_sse_client_session( + server, server_url: str +) -> AsyncGenerator[ClientSession, None]: async with sse_client(server_url + "/sse") as streams: async with ClientSession(*streams) as session: await session.initialize() yield session + @pytest.mark.anyio -async def test_sse_client_happy_request_and_response(initialized_sse_client_session: ClientSession): +async def test_sse_client_happy_request_and_response( + initialized_sse_client_session: ClientSession, +): session = initialized_sse_client_session response = await session.read_resource(uri=AnyUrl("foobar://should-work")) assert len(response.contents) == 1 assert isinstance(response.contents[0], TextResourceContents) assert response.contents[0].text == "Read should-work" + @pytest.mark.anyio -async def test_sse_client_exception_handling(initialized_sse_client_session: ClientSession): +async def test_sse_client_exception_handling( + initialized_sse_client_session: ClientSession, +): session = initialized_sse_client_session with pytest.raises(McpError, match="OOPS! no resource with that URI was found"): await session.read_resource(uri=AnyUrl("xxx://will-not-work")) From 3fa26a5a9771ab34e06c3011545a5b374f0758f8 Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 11:49:04 -0500 Subject: [PATCH 13/16] remove unused imports --- tests/shared/test_sse.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 8f9221a8e..aad558f49 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,15 +1,10 @@ -import re import multiprocessing import socket import time -import json import anyio -import threading import uvicorn import pytest from pydantic import AnyUrl -from pydantic_core import Url -import pytest import httpx from typing import AsyncGenerator from starlette.applications import Starlette From aa7869a62fb13f94718a684e942c9ce7efbe8e25 Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 11:54:09 -0500 Subject: [PATCH 14/16] add type hints --- tests/shared/test_sse.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index aad558f49..b6e9af7f8 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -2,11 +2,12 @@ import socket import time import anyio +from starlette.requests import Request import uvicorn import pytest from pydantic import AnyUrl import httpx -from typing import AsyncGenerator +from typing import AsyncGenerator, Generator from starlette.applications import Starlette from starlette.routing import Mount, Route @@ -56,7 +57,7 @@ async def handle_read_resource(uri: AnyUrl) -> str | bytes: ) @self.list_tools() - async def handle_list_tools(): + async def handle_list_tools() -> list[Tool]: return [ Tool( name="test_tool", @@ -66,7 +67,7 @@ async def handle_list_tools(): ] @self.call_tool() - async def handle_call_tool(name: str, args: dict): + async def handle_call_tool(name: str, args: dict) -> list[TextContent]: return [TextContent(type="text", text=f"Called {name}")] @@ -76,7 +77,7 @@ def make_server_app() -> Starlette: sse = SseServerTransport("/messages/") server = TestServer() - async def handle_sse(request): + async def handle_sse(request: Request) -> None: async with sse.connect_sse( request.scope, request.receive, request._send ) as streams: @@ -94,14 +95,7 @@ async def handle_sse(request): return app -@pytest.fixture(autouse=True) -def space_around_test(): - time.sleep(0.1) - yield - time.sleep(0.1) - - -def run_server(server_port: int): +def run_server(server_port: int) -> None: app = make_server_app() server = uvicorn.Server( config=uvicorn.Config( @@ -118,7 +112,7 @@ def run_server(server_port: int): @pytest.fixture() -def server(server_port: int): +def server(server_port: int) -> Generator[None, None, None]: proc = multiprocessing.Process( target=run_server, kwargs={"server_port": server_port}, daemon=True ) @@ -161,11 +155,11 @@ async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, N # Tests @pytest.mark.anyio -async def test_raw_sse_connection(http_client: httpx.AsyncClient): +async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None: """Test the SSE connection establishment simply with an HTTP client.""" async with anyio.create_task_group() as tg: - async def connection_test(): + async def connection_test() -> None: async with http_client.stream("GET", "/sse") as response: assert response.status_code == 200 assert ( @@ -189,7 +183,7 @@ async def connection_test(): @pytest.mark.anyio -async def test_sse_client_basic_connection(server, server_url): +async def test_sse_client_basic_connection(server: None, server_url: str) -> None: async with sse_client(server_url + "/sse") as streams: async with ClientSession(*streams) as session: # Test initialization @@ -215,7 +209,7 @@ async def initialized_sse_client_session( @pytest.mark.anyio async def test_sse_client_happy_request_and_response( initialized_sse_client_session: ClientSession, -): +) -> None: session = initialized_sse_client_session response = await session.read_resource(uri=AnyUrl("foobar://should-work")) assert len(response.contents) == 1 @@ -226,7 +220,7 @@ async def test_sse_client_happy_request_and_response( @pytest.mark.anyio async def test_sse_client_exception_handling( initialized_sse_client_session: ClientSession, -): +) -> None: session = initialized_sse_client_session with pytest.raises(McpError, match="OOPS! no resource with that URI was found"): await session.read_resource(uri=AnyUrl("xxx://will-not-work")) From e798d20cbbbb09fdca5e2c2eb38ed4fe4989273d Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 11:56:44 -0500 Subject: [PATCH 15/16] ruff --- tests/shared/test_sse.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index b6e9af7f8..28aaee427 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,21 +1,22 @@ import multiprocessing import socket import time +from typing import AsyncGenerator, Generator + import anyio -from starlette.requests import Request -import uvicorn +import httpx import pytest +import uvicorn from pydantic import AnyUrl -import httpx -from typing import AsyncGenerator, Generator from starlette.applications import Starlette +from starlette.requests import Request from starlette.routing import Mount, Route -from mcp.shared.exceptions import McpError from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.server import Server from mcp.server.sse import SseServerTransport +from mcp.shared.exceptions import McpError from mcp.types import ( EmptyResult, ErrorData, @@ -157,8 +158,7 @@ async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, N @pytest.mark.anyio async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None: """Test the SSE connection establishment simply with an HTTP client.""" - async with anyio.create_task_group() as tg: - + async with anyio.create_task_group(): async def connection_test() -> None: async with http_client.stream("GET", "/sse") as response: assert response.status_code == 200 From d01d49ea6e046f0a691a84fe0c4e5a6583450e07 Mon Sep 17 00:00:00 2001 From: Nick Merrill Date: Tue, 14 Jan 2025 12:07:14 -0500 Subject: [PATCH 16/16] add timeout test --- tests/shared/test_sse.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 28aaee427..9d32fff34 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -21,6 +21,7 @@ EmptyResult, ErrorData, InitializeResult, + ReadResourceResult, TextContent, TextResourceContents, Tool, @@ -50,7 +51,11 @@ def __init__(self): async def handle_read_resource(uri: AnyUrl) -> str | bytes: if uri.scheme == "foobar": return f"Read {uri.host}" - # TODO: make this an error + elif uri.scheme == "slow": + # Simulate a slow resource + await anyio.sleep(2.0) + return f"Slow response from {uri.host}" + raise McpError( error=ErrorData( code=404, message="OOPS! no resource with that URI was found" @@ -200,12 +205,13 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non async def initialized_sse_client_session( server, server_url: str ) -> AsyncGenerator[ClientSession, None]: - async with sse_client(server_url + "/sse") as streams: + async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: async with ClientSession(*streams) as session: await session.initialize() yield session + @pytest.mark.anyio async def test_sse_client_happy_request_and_response( initialized_sse_client_session: ClientSession, @@ -226,4 +232,23 @@ async def test_sse_client_exception_handling( await session.read_resource(uri=AnyUrl("xxx://will-not-work")) -# TODO: test that timeouts are respected and that the error comes back +@pytest.mark.anyio +@pytest.mark.skip( + "this test highlights a possible bug in SSE read timeout exception handling" +) +async def test_sse_client_timeout( + initialized_sse_client_session: ClientSession, +) -> None: + session = initialized_sse_client_session + + # sanity check that normal, fast responses are working + response = await session.read_resource(uri=AnyUrl("foobar://1")) + assert isinstance(response, ReadResourceResult) + + with anyio.move_on_after(3): + with pytest.raises(McpError, match="Read timed out"): + response = await session.read_resource(uri=AnyUrl("slow://2")) + # we should receive an error here + return + + pytest.fail("the client should have timed out and returned an error already")