Skip to content

Commit e79a564

Browse files
committed
passing SSE client test
1 parent 66ccd1c commit e79a564

File tree

1 file changed

+49
-84
lines changed

1 file changed

+49
-84
lines changed

tests/shared/test_sse.py

Lines changed: 49 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,36 @@
33
import time
44
import json
55
import anyio
6+
from pydantic import AnyUrl
7+
from pydantic_core import Url
68
import pytest
79
import httpx
810
from typing import AsyncGenerator
911
from starlette.applications import Starlette
1012
from starlette.routing import Mount, Route
1113

14+
from mcp.client.session import ClientSession
15+
from mcp.client.sse import sse_client
1216
from mcp.server import Server
1317
from mcp.server.sse import SseServerTransport
14-
from mcp.types import TextContent, Tool
18+
from mcp.types import EmptyResult, InitializeResult, TextContent, TextResourceContents, Tool
19+
20+
SERVER_URL = "http://127.0.0.1:8765"
21+
SERVER_SSE_URL = f"{SERVER_URL}/sse"
22+
23+
SERVER_NAME = "test_server_for_SSE"
1524

1625
# Test server implementation
1726
class TestServer(Server):
1827
def __init__(self):
19-
super().__init__("test_server")
28+
super().__init__(SERVER_NAME)
29+
30+
@self.read_resource()
31+
async def handle_read_resource(uri: AnyUrl) -> str | bytes:
32+
if uri.scheme == "foobar":
33+
return f"Read {uri.host}"
34+
# TODO: make this an error
35+
return "NOT FOUND"
2036

2137
@self.list_tools()
2238
async def handle_list_tools():
@@ -76,18 +92,18 @@ def server(server_app: Starlette):
7692
server_thread.join(timeout=0.1)
7793

7894
@pytest.fixture()
79-
async def client(server) -> AsyncGenerator[httpx.AsyncClient, None]:
95+
async def http_client(server) -> AsyncGenerator[httpx.AsyncClient, None]:
8096
"""Create test client"""
81-
async with httpx.AsyncClient(base_url="http://127.0.0.1:8765") as client:
97+
async with httpx.AsyncClient(base_url=SERVER_URL) as client:
8298
yield client
8399

84100
# Tests
85101
@pytest.mark.anyio
86-
async def test_sse_connection(client: httpx.AsyncClient):
87-
"""Test SSE connection establishment"""
102+
async def test_raw_sse_connection(http_client: httpx.AsyncClient):
103+
"""Test the SSE connection establishment simply with an HTTP client."""
88104
async with anyio.create_task_group() as tg:
89105
async def connection_test():
90-
async with client.stream("GET", "/sse") as response:
106+
async with http_client.stream("GET", "/sse") as response:
91107
assert response.status_code == 200
92108
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
93109

@@ -105,84 +121,33 @@ async def connection_test():
105121
with anyio.fail_after(3):
106122
await connection_test()
107123

108-
@pytest.mark.anyio
109-
async def test_message_exchange(client: httpx.AsyncClient):
110-
"""Test full message exchange flow"""
111-
# Connect to SSE endpoint
112-
session_id = None
113-
endpoint_url = None
114-
115-
async with client.stream("GET", "/sse") as sse_response:
116-
assert sse_response.status_code == 200
117-
118-
# Get endpoint URL and session ID
119-
async for line in sse_response.aiter_lines():
120-
if line.startswith("data: "):
121-
endpoint_url = json.loads(line[6:])
122-
session_id = endpoint_url.split("session_id=")[1]
123-
break
124-
125-
assert endpoint_url and session_id
126-
127-
# Send initialize request
128-
init_request = {
129-
"jsonrpc": "2.0",
130-
"id": 1,
131-
"method": "initialize",
132-
"params": {
133-
"protocolVersion": "2024-11-05",
134-
"capabilities": {},
135-
"clientInfo": {
136-
"name": "test_client",
137-
"version": "1.0"
138-
}
139-
}
140-
}
141-
142-
response = await client.post(
143-
endpoint_url,
144-
json=init_request
145-
)
146-
assert response.status_code == 202
147-
148-
# Get initialize response from SSE stream
149-
async for line in sse_response.aiter_lines():
150-
if line.startswith("event: message"):
151-
data_line = next(sse_response.aiter_lines())
152-
response = json.loads(data_line[6:]) # Strip "data: " prefix
153-
assert response["jsonrpc"] == "2.0"
154-
assert response["id"] == 1
155-
assert "result" in response
156-
break
157124

158125
@pytest.mark.anyio
159-
async def test_invalid_session(client: httpx.AsyncClient):
160-
"""Test sending message with invalid session ID"""
161-
response = await client.post(
162-
"/messages/?session_id=invalid",
163-
json={"jsonrpc": "2.0", "method": "ping"}
164-
)
165-
assert response.status_code == 400
126+
async def test_sse_client_basic_connection(server):
127+
async with sse_client(SERVER_SSE_URL) as streams:
128+
async with ClientSession(*streams) as session:
129+
# Test initialization
130+
result = await session.initialize()
131+
assert isinstance(result, InitializeResult)
132+
assert result.serverInfo.name == SERVER_NAME
133+
134+
# Test ping
135+
ping_result = await session.send_ping()
136+
assert isinstance(ping_result, EmptyResult)
137+
138+
@pytest.fixture
139+
async def initialized_sse_client_session(server) -> AsyncGenerator[ClientSession, None]:
140+
async with sse_client(SERVER_SSE_URL) as streams:
141+
async with ClientSession(*streams) as session:
142+
await session.initialize()
143+
yield session
166144

167145
@pytest.mark.anyio
168-
async def test_connection_cleanup(server_app):
169-
"""Test that resources are cleaned up when client disconnects"""
170-
sse = next(
171-
route.app for route in server_app.routes
172-
if isinstance(route, Mount) and route.path == "/messages/"
173-
).transport
174-
175-
async with httpx.AsyncClient(app=server_app, base_url="http://test") as client:
176-
# Connect and get session ID
177-
async with client.stream("GET", "/sse") as response:
178-
for line in response.iter_lines():
179-
if line.startswith("data: "):
180-
endpoint_url = json.loads(line[6:])
181-
session_id = endpoint_url.split("session_id=")[1]
182-
break
183-
184-
assert len(sse._read_stream_writers) == 1
185-
186-
# After connection closes, writer should be cleaned up
187-
await anyio.sleep(0.1) # Give cleanup a moment
188-
assert len(sse._read_stream_writers) == 0
146+
async def test_sse_client_request_and_response(initialized_sse_client_session: ClientSession):
147+
session = initialized_sse_client_session
148+
# TODO: expect raise
149+
await session.read_resource(uri=AnyUrl("xxx://will-not-work"))
150+
response = await session.read_resource(uri=AnyUrl("foobar://should-work"))
151+
assert len(response.contents) == 1
152+
assert isinstance(response.contents[0], TextResourceContents)
153+
assert response.contents[0].text == "Read should-work"

0 commit comments

Comments
 (0)