Skip to content

Commit 66ccd1c

Browse files
committed
test_sse_connection is passing
1 parent a0e2f7f commit 66ccd1c

File tree

2 files changed

+188
-173
lines changed

2 files changed

+188
-173
lines changed

tests/client/test_sse_attempt.py

Lines changed: 0 additions & 173 deletions
This file was deleted.

tests/shared/test_sse.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# test_sse.py
2+
import re
3+
import time
4+
import json
5+
import anyio
6+
import pytest
7+
import httpx
8+
from typing import AsyncGenerator
9+
from starlette.applications import Starlette
10+
from starlette.routing import Mount, Route
11+
12+
from mcp.server import Server
13+
from mcp.server.sse import SseServerTransport
14+
from mcp.types import TextContent, Tool
15+
16+
# Test server implementation
17+
class TestServer(Server):
18+
def __init__(self):
19+
super().__init__("test_server")
20+
21+
@self.list_tools()
22+
async def handle_list_tools():
23+
return [
24+
Tool(
25+
name="test_tool",
26+
description="A test tool",
27+
inputSchema={"type": "object", "properties": {}}
28+
)
29+
]
30+
31+
@self.call_tool()
32+
async def handle_call_tool(name: str, args: dict):
33+
return [TextContent(type="text", text=f"Called {name}")]
34+
35+
import threading
36+
import uvicorn
37+
import pytest
38+
39+
40+
# Test fixtures
41+
@pytest.fixture
42+
async def server_app()-> Starlette:
43+
"""Create test Starlette app with SSE transport"""
44+
sse = SseServerTransport("/messages/")
45+
server = TestServer()
46+
47+
async def handle_sse(request):
48+
async with sse.connect_sse(
49+
request.scope, request.receive, request._send
50+
) as streams:
51+
await server.run(
52+
streams[0],
53+
streams[1],
54+
server.create_initialization_options()
55+
)
56+
57+
app = Starlette(routes=[
58+
Route("/sse", endpoint=handle_sse),
59+
Mount("/messages/", app=sse.handle_post_message),
60+
])
61+
62+
return app
63+
64+
@pytest.fixture()
65+
def server(server_app: Starlette):
66+
server = uvicorn.Server(config=uvicorn.Config(app=server_app, host="127.0.0.1", port=8765, log_level="error"))
67+
server_thread = threading.Thread( target=server.run, daemon=True )
68+
print('starting server')
69+
server_thread.start()
70+
# Give server time to start
71+
while not server.started:
72+
print('waiting for server to start')
73+
time.sleep(0.5)
74+
yield
75+
print('killing server')
76+
server_thread.join(timeout=0.1)
77+
78+
@pytest.fixture()
79+
async def client(server) -> AsyncGenerator[httpx.AsyncClient, None]:
80+
"""Create test client"""
81+
async with httpx.AsyncClient(base_url="http://127.0.0.1:8765") as client:
82+
yield client
83+
84+
# Tests
85+
@pytest.mark.anyio
86+
async def test_sse_connection(client: httpx.AsyncClient):
87+
"""Test SSE connection establishment"""
88+
async with anyio.create_task_group() as tg:
89+
async def connection_test():
90+
async with client.stream("GET", "/sse") as response:
91+
assert response.status_code == 200
92+
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
93+
94+
line_number = 0
95+
async for line in response.aiter_lines():
96+
if line_number == 0:
97+
assert line == "event: endpoint"
98+
elif line_number == 1:
99+
assert line.startswith("data: /messages/?session_id=")
100+
else:
101+
return
102+
line_number += 1
103+
104+
# Add timeout to prevent test from hanging if it fails
105+
with anyio.fail_after(3):
106+
await connection_test()
107+
108+
@pytest.mark.anyio
109+
async def test_message_exchange(client: httpx.AsyncClient):
110+
"""Test full message exchange flow"""
111+
# Connect to SSE endpoint
112+
session_id = None
113+
endpoint_url = None
114+
115+
async with client.stream("GET", "/sse") as sse_response:
116+
assert sse_response.status_code == 200
117+
118+
# Get endpoint URL and session ID
119+
async for line in sse_response.aiter_lines():
120+
if line.startswith("data: "):
121+
endpoint_url = json.loads(line[6:])
122+
session_id = endpoint_url.split("session_id=")[1]
123+
break
124+
125+
assert endpoint_url and session_id
126+
127+
# Send initialize request
128+
init_request = {
129+
"jsonrpc": "2.0",
130+
"id": 1,
131+
"method": "initialize",
132+
"params": {
133+
"protocolVersion": "2024-11-05",
134+
"capabilities": {},
135+
"clientInfo": {
136+
"name": "test_client",
137+
"version": "1.0"
138+
}
139+
}
140+
}
141+
142+
response = await client.post(
143+
endpoint_url,
144+
json=init_request
145+
)
146+
assert response.status_code == 202
147+
148+
# Get initialize response from SSE stream
149+
async for line in sse_response.aiter_lines():
150+
if line.startswith("event: message"):
151+
data_line = next(sse_response.aiter_lines())
152+
response = json.loads(data_line[6:]) # Strip "data: " prefix
153+
assert response["jsonrpc"] == "2.0"
154+
assert response["id"] == 1
155+
assert "result" in response
156+
break
157+
158+
@pytest.mark.anyio
159+
async def test_invalid_session(client: httpx.AsyncClient):
160+
"""Test sending message with invalid session ID"""
161+
response = await client.post(
162+
"/messages/?session_id=invalid",
163+
json={"jsonrpc": "2.0", "method": "ping"}
164+
)
165+
assert response.status_code == 400
166+
167+
@pytest.mark.anyio
168+
async def test_connection_cleanup(server_app):
169+
"""Test that resources are cleaned up when client disconnects"""
170+
sse = next(
171+
route.app for route in server_app.routes
172+
if isinstance(route, Mount) and route.path == "/messages/"
173+
).transport
174+
175+
async with httpx.AsyncClient(app=server_app, base_url="http://test") as client:
176+
# Connect and get session ID
177+
async with client.stream("GET", "/sse") as response:
178+
for line in response.iter_lines():
179+
if line.startswith("data: "):
180+
endpoint_url = json.loads(line[6:])
181+
session_id = endpoint_url.split("session_id=")[1]
182+
break
183+
184+
assert len(sse._read_stream_writers) == 1
185+
186+
# After connection closes, writer should be cleaned up
187+
await anyio.sleep(0.1) # Give cleanup a moment
188+
assert len(sse._read_stream_writers) == 0

0 commit comments

Comments
 (0)