44from starlette .routing import Mount , Route
55import httpx
66from httpx import ReadTimeout , ASGITransport
7+ from starlette .responses import Response
8+ from sse_starlette .sse import EventSourceResponse
79
810from mcp .client .sse import sse_client
911from mcp .server .sse import SseServerTransport
@@ -21,17 +23,33 @@ async def sse_app(sse_transport):
2123 """Fixture that creates a Starlette app with SSE endpoints."""
2224 async def handle_sse (request ):
2325 """Handler for SSE connections."""
24- async with sse_transport .connect_sse (
25- request .scope , request .receive , request ._send
26- ) as streams :
27- client_to_server , server_to_client = streams
28- async for message in client_to_server :
29- # Echo messages back for testing
30- await server_to_client .send (message )
26+ async def event_generator ():
27+ # Send initial connection event
28+ yield {
29+ "event" : "endpoint" ,
30+ "data" : "/messages" ,
31+ }
32+
33+ # Keep connection alive
34+ async with sse_transport .connect_sse (
35+ request .scope , request .receive , request ._send
36+ ) as streams :
37+ client_to_server , server_to_client = streams
38+ async for message in client_to_server :
39+ yield {
40+ "event" : "message" ,
41+ "data" : message .model_dump_json (),
42+ }
43+
44+ return EventSourceResponse (event_generator ())
45+
46+ async def handle_post (request ):
47+ """Handler for POST messages."""
48+ return Response (status_code = 200 )
3149
3250 routes = [
3351 Route ("/sse" , endpoint = handle_sse ),
34- Mount ("/messages" , app = sse_transport . handle_post_message ),
52+ Route ("/messages" , endpoint = handle_post , methods = [ "POST" ] ),
3553 ]
3654
3755 return Starlette (routes = routes )
@@ -40,9 +58,11 @@ async def handle_sse(request):
4058@pytest .fixture
4159async def test_client (sse_app ):
4260 """Create a test client with ASGI transport."""
61+ transport = ASGITransport (app = sse_app )
4362 async with httpx .AsyncClient (
44- transport = ASGITransport ( app = sse_app ) ,
63+ transport = transport ,
4564 base_url = "http://testserver" ,
65+ timeout = 5.0 ,
4666 ) as client :
4767 yield client
4868
@@ -53,7 +73,8 @@ async def test_sse_connection(test_client):
5373 async with sse_client (
5474 "http://testserver/sse" ,
5575 headers = {"Host" : "testserver" },
56- timeout = 5 ,
76+ timeout = 2 ,
77+ sse_read_timeout = 1 ,
5778 client = test_client ,
5879 ) as (read_stream , write_stream ):
5980 # Send a test message
@@ -74,7 +95,7 @@ async def test_sse_read_timeout(test_client):
7495 async with sse_client (
7596 "http://testserver/sse" ,
7697 headers = {"Host" : "testserver" },
77- timeout = 5 ,
98+ timeout = 2 ,
7899 sse_read_timeout = 1 ,
79100 client = test_client ,
80101 ) as (read_stream , write_stream ):
@@ -90,7 +111,7 @@ async def test_sse_connection_error(test_client):
90111 async with sse_client (
91112 "http://testserver/nonexistent" ,
92113 headers = {"Host" : "testserver" },
93- timeout = 5 ,
114+ timeout = 2 ,
94115 client = test_client ,
95116 ):
96117 pass # Should not reach here
@@ -102,7 +123,8 @@ async def test_sse_multiple_messages(test_client):
102123 async with sse_client (
103124 "http://testserver/sse" ,
104125 headers = {"Host" : "testserver" },
105- timeout = 5 ,
126+ timeout = 2 ,
127+ sse_read_timeout = 1 ,
106128 client = test_client ,
107129 ) as (read_stream , write_stream ):
108130 # Send multiple test messages
0 commit comments