33import time
44import json
55import anyio
6+ from pydantic import AnyUrl
7+ from pydantic_core import Url
68import pytest
79import httpx
810from typing import AsyncGenerator
911from starlette .applications import Starlette
1012from starlette .routing import Mount , Route
1113
14+ from mcp .client .session import ClientSession
15+ from mcp .client .sse import sse_client
1216from mcp .server import Server
1317from mcp .server .sse import SseServerTransport
14- from mcp .types import TextContent , Tool
18+ from mcp .types import EmptyResult , InitializeResult , TextContent , TextResourceContents , Tool
19+
20+ SERVER_URL = "http://127.0.0.1:8765"
21+ SERVER_SSE_URL = f"{ SERVER_URL } /sse"
22+
23+ SERVER_NAME = "test_server_for_SSE"
1524
1625# Test server implementation
1726class TestServer (Server ):
1827 def __init__ (self ):
19- super ().__init__ ("test_server" )
28+ super ().__init__ (SERVER_NAME )
29+
30+ @self .read_resource ()
31+ async def handle_read_resource (uri : AnyUrl ) -> str | bytes :
32+ if uri .scheme == "foobar" :
33+ return f"Read { uri .host } "
34+ # TODO: make this an error
35+ return "NOT FOUND"
2036
2137 @self .list_tools ()
2238 async def handle_list_tools ():
@@ -76,18 +92,18 @@ def server(server_app: Starlette):
7692 server_thread .join (timeout = 0.1 )
7793
7894@pytest .fixture ()
79- async def client (server ) -> AsyncGenerator [httpx .AsyncClient , None ]:
95+ async def http_client (server ) -> AsyncGenerator [httpx .AsyncClient , None ]:
8096 """Create test client"""
81- async with httpx .AsyncClient (base_url = "http://127.0.0.1:8765" ) as client :
97+ async with httpx .AsyncClient (base_url = SERVER_URL ) as client :
8298 yield client
8399
84100# Tests
85101@pytest .mark .anyio
86- async def test_sse_connection ( client : httpx .AsyncClient ):
87- """Test SSE connection establishment"""
102+ async def test_raw_sse_connection ( http_client : httpx .AsyncClient ):
103+ """Test the SSE connection establishment simply with an HTTP client. """
88104 async with anyio .create_task_group () as tg :
89105 async def connection_test ():
90- async with client .stream ("GET" , "/sse" ) as response :
106+ async with http_client .stream ("GET" , "/sse" ) as response :
91107 assert response .status_code == 200
92108 assert response .headers ["content-type" ] == "text/event-stream; charset=utf-8"
93109
@@ -105,84 +121,33 @@ async def connection_test():
105121 with anyio .fail_after (3 ):
106122 await connection_test ()
107123
108- @pytest .mark .anyio
109- async def test_message_exchange (client : httpx .AsyncClient ):
110- """Test full message exchange flow"""
111- # Connect to SSE endpoint
112- session_id = None
113- endpoint_url = None
114-
115- async with client .stream ("GET" , "/sse" ) as sse_response :
116- assert sse_response .status_code == 200
117-
118- # Get endpoint URL and session ID
119- async for line in sse_response .aiter_lines ():
120- if line .startswith ("data: " ):
121- endpoint_url = json .loads (line [6 :])
122- session_id = endpoint_url .split ("session_id=" )[1 ]
123- break
124-
125- assert endpoint_url and session_id
126-
127- # Send initialize request
128- init_request = {
129- "jsonrpc" : "2.0" ,
130- "id" : 1 ,
131- "method" : "initialize" ,
132- "params" : {
133- "protocolVersion" : "2024-11-05" ,
134- "capabilities" : {},
135- "clientInfo" : {
136- "name" : "test_client" ,
137- "version" : "1.0"
138- }
139- }
140- }
141-
142- response = await client .post (
143- endpoint_url ,
144- json = init_request
145- )
146- assert response .status_code == 202
147-
148- # Get initialize response from SSE stream
149- async for line in sse_response .aiter_lines ():
150- if line .startswith ("event: message" ):
151- data_line = next (sse_response .aiter_lines ())
152- response = json .loads (data_line [6 :]) # Strip "data: " prefix
153- assert response ["jsonrpc" ] == "2.0"
154- assert response ["id" ] == 1
155- assert "result" in response
156- break
157124
158125@pytest .mark .anyio
159- async def test_invalid_session (client : httpx .AsyncClient ):
160- """Test sending message with invalid session ID"""
161- response = await client .post (
162- "/messages/?session_id=invalid" ,
163- json = {"jsonrpc" : "2.0" , "method" : "ping" }
164- )
165- assert response .status_code == 400
126+ async def test_sse_client_basic_connection (server ):
127+ async with sse_client (SERVER_SSE_URL ) as streams :
128+ async with ClientSession (* streams ) as session :
129+ # Test initialization
130+ result = await session .initialize ()
131+ assert isinstance (result , InitializeResult )
132+ assert result .serverInfo .name == SERVER_NAME
133+
134+ # Test ping
135+ ping_result = await session .send_ping ()
136+ assert isinstance (ping_result , EmptyResult )
137+
138+ @pytest .fixture
139+ async def initialized_sse_client_session (server ) -> AsyncGenerator [ClientSession , None ]:
140+ async with sse_client (SERVER_SSE_URL ) as streams :
141+ async with ClientSession (* streams ) as session :
142+ await session .initialize ()
143+ yield session
166144
167145@pytest .mark .anyio
168- async def test_connection_cleanup (server_app ):
169- """Test that resources are cleaned up when client disconnects"""
170- sse = next (
171- route .app for route in server_app .routes
172- if isinstance (route , Mount ) and route .path == "/messages/"
173- ).transport
174-
175- async with httpx .AsyncClient (app = server_app , base_url = "http://test" ) as client :
176- # Connect and get session ID
177- async with client .stream ("GET" , "/sse" ) as response :
178- for line in response .iter_lines ():
179- if line .startswith ("data: " ):
180- endpoint_url = json .loads (line [6 :])
181- session_id = endpoint_url .split ("session_id=" )[1 ]
182- break
183-
184- assert len (sse ._read_stream_writers ) == 1
185-
186- # After connection closes, writer should be cleaned up
187- await anyio .sleep (0.1 ) # Give cleanup a moment
188- assert len (sse ._read_stream_writers ) == 0
146+ async def test_sse_client_request_and_response (initialized_sse_client_session : ClientSession ):
147+ session = initialized_sse_client_session
148+ # TODO: expect raise
149+ await session .read_resource (uri = AnyUrl ("xxx://will-not-work" ))
150+ response = await session .read_resource (uri = AnyUrl ("foobar://should-work" ))
151+ assert len (response .contents ) == 1
152+ assert isinstance (response .contents [0 ], TextResourceContents )
153+ assert response .contents [0 ].text == "Read should-work"
0 commit comments