77import httpx
88from anyio .abc import TaskStatus
99from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
10+ from exceptiongroup import ExceptionGroup , catch
1011from httpx_sse import aconnect_sse
1112
1213import mcp .types as types
@@ -18,6 +19,14 @@ def remove_request_params(url: str) -> str:
1819 return urljoin (url , urlparse (url ).path )
1920
2021
22+ def handle_exception (exc : Exception ) -> str :
23+ """Handle ExceptionGroup and Exceptions for Client transport for SSE"""
24+ if isinstance (exc , ExceptionGroup ):
25+ messages = "; " .join (str (e ) for e in exc .exceptions )
26+ raise Exception (f"TaskGroup failed with: { messages } " ) from None
27+ else :
28+ raise Exception (f"TaskGroup failed with: { exc } " ) from None
29+
2130@asynccontextmanager
2231async def sse_client (
2332 url : str ,
@@ -40,115 +49,115 @@ async def sse_client(
4049 read_stream_writer , read_stream = anyio .create_memory_object_stream (0 )
4150 write_stream , write_stream_reader = anyio .create_memory_object_stream (0 )
4251
43- errors : list [Exception ] = []
44-
45- async with anyio .create_task_group () as tg :
46- try :
47- logger .info (f"Connecting to SSE endpoint: { remove_request_params (url )} " )
48- async with httpx .AsyncClient (headers = headers ) as client :
49- async with aconnect_sse (
50- client ,
51- "GET" ,
52- url ,
53- timeout = httpx .Timeout (timeout , read = sse_read_timeout ),
54- ) as event_source :
55- event_source .response .raise_for_status ()
56- logger .debug ("SSE connection established" )
57-
58- async def sse_reader (
59- task_status : TaskStatus [str ] = anyio .TASK_STATUS_IGNORED ,
60- ):
61- try :
62- async for sse in event_source .aiter_sse ():
63- logger .debug (f"Received SSE event: { sse .event } " )
64- match sse .event :
65- case "endpoint" :
66- endpoint_url = urljoin (url , sse .data )
67- logger .info (
68- f"Received endpoint URL: { endpoint_url } "
69- )
70-
71- url_parsed = urlparse (url )
72- endpoint_parsed = urlparse (endpoint_url )
73- if (
74- url_parsed .netloc != endpoint_parsed .netloc
75- or url_parsed .scheme
76- != endpoint_parsed .scheme
77- ):
78- error_msg = (
79- "Endpoint origin does not match "
80- f"connection origin: { endpoint_url } "
52+ with catch ({
53+ Exception : handle_exception ,
54+ }):
55+ async with anyio .create_task_group () as tg :
56+ try :
57+ logger .info (f"Connecting to SSE endpoint: { remove_request_params (url )} " )
58+ async with httpx .AsyncClient (headers = headers ) as client :
59+ async with aconnect_sse (
60+ client ,
61+ "GET" ,
62+ url ,
63+ timeout = httpx .Timeout (timeout , read = sse_read_timeout ),
64+ ) as event_source :
65+ event_source .response .raise_for_status ()
66+ logger .debug ("SSE connection established" )
67+
68+ async def sse_reader (
69+ task_status : TaskStatus [str ] = anyio .TASK_STATUS_IGNORED ,
70+ ):
71+ try :
72+ async for sse in event_source .aiter_sse ():
73+ logger .debug (f"Received SSE event: { sse .event } " )
74+ match sse .event :
75+ case "endpoint" :
76+ endpoint_url = urljoin (url , sse .data )
77+ logger .info (
78+ f"Received endpoint URL: { endpoint_url } "
8179 )
82- logger .error (error_msg )
83- raise ValueError (error_msg )
84-
85- task_status .started (endpoint_url )
8680
87- case "message" :
88- try :
89- message = types .JSONRPCMessage .model_validate_json ( # noqa: E501
90- sse .data
91- )
92- logger .debug (
93- f"Received server message: { message } "
81+ url_parsed = urlparse (url )
82+ endpoint_parsed = urlparse (endpoint_url )
83+ if (
84+ url_parsed .netloc
85+ != endpoint_parsed .netloc
86+ or url_parsed .scheme
87+ != endpoint_parsed .scheme
88+ ):
89+ error_msg = (
90+ "Endpoint origin does not match "
91+ f"connection origin: { endpoint_url } "
92+ )
93+ logger .error (error_msg )
94+ raise ValueError (error_msg )
95+
96+ task_status .started (endpoint_url )
97+
98+ case "message" :
99+ try :
100+ message = types .JSONRPCMessage .model_validate_json ( # noqa: E501
101+ sse .data
102+ )
103+ logger .debug (
104+ f"Received server message: "
105+ f"{ message } "
106+ )
107+ except Exception as exc :
108+ logger .error (
109+ f"Error parsing server message: "
110+ f"{ exc } "
111+ )
112+ await read_stream_writer .send (exc )
113+ continue
114+
115+ await read_stream_writer .send (message )
116+ case _:
117+ logger .warning (
118+ f"Unknown SSE event: { sse .event } "
94119 )
95- except Exception as exc :
96- logger .error (
97- f"Error parsing server message: { exc } "
98- )
99- await read_stream_writer .send (exc )
100- continue
101-
102- await read_stream_writer .send (message )
103- case _:
104- logger .warning (
105- f"Unknown SSE event: { sse .event } "
120+ except Exception as exc :
121+ logger .error (f"Error in sse_reader: { exc } " )
122+ raise
123+ finally :
124+ await read_stream_writer .aclose ()
125+
126+ async def post_writer (endpoint_url : str ):
127+ try :
128+ async with write_stream_reader :
129+ async for message in write_stream_reader :
130+ logger .debug (
131+ f"Sending client message: { message } "
106132 )
107- except Exception as exc :
108- logger .error (f"Error in sse_reader: { exc } " )
109- raise
110- finally :
111- await read_stream_writer .aclose ()
133+ response = await client .post (
134+ endpoint_url ,
135+ json = message .model_dump (
136+ by_alias = True ,
137+ mode = "json" ,
138+ exclude_none = True ,
139+ ),
140+ )
141+ response .raise_for_status ()
142+ logger .debug (
143+ "Client message sent successfully: "
144+ f"{ response .status_code } "
145+ )
146+ except Exception as exc :
147+ logger .error (f"Error in post_writer: { exc } " )
148+ finally :
149+ await write_stream .aclose ()
150+
151+ endpoint_url = await tg .start (sse_reader )
152+ logger .info (
153+ f"Starting post writer with endpoint URL: { endpoint_url } "
154+ )
155+ tg .start_soon (post_writer , endpoint_url )
112156
113- async def post_writer (endpoint_url : str ):
114157 try :
115- async with write_stream_reader :
116- async for message in write_stream_reader :
117- logger .debug (f"Sending client message: { message } " )
118- response = await client .post (
119- endpoint_url ,
120- json = message .model_dump (
121- by_alias = True ,
122- mode = "json" ,
123- exclude_none = True ,
124- ),
125- )
126- response .raise_for_status ()
127- logger .debug (
128- "Client message sent successfully: "
129- f"{ response .status_code } "
130- )
131- except Exception as exc :
132- logger .error (f"Error in post_writer: { exc } " )
158+ yield read_stream , write_stream
133159 finally :
134- await write_stream .aclose ()
135-
136- endpoint_url = await tg .start (sse_reader )
137- logger .info (
138- f"Starting post writer with endpoint URL: { endpoint_url } "
139- )
140- tg .start_soon (post_writer , endpoint_url )
141-
142- try :
143- yield read_stream , write_stream
144- finally :
145- tg .cancel_scope .cancel ()
146- except* ValueError as eg :
147- errors .extend (eg .exceptions )
148- except* Exception as eg :
149- errors .extend (eg .exceptions )
150- finally :
151- await read_stream_writer .aclose ()
152- await write_stream .aclose ()
153- if errors :
154- raise Exception ("TaskGroup failed with: " + " " .join ([str (e ) for e in errors ]))
160+ tg .cancel_scope .cancel ()
161+ finally :
162+ await read_stream_writer .aclose ()
163+ await write_stream .aclose ()
0 commit comments