@@ -47,9 +47,7 @@ async def sse_client(
4747 read_stream_writer , read_stream = anyio .create_memory_object_stream (0 )
4848 write_stream , write_stream_reader = anyio .create_memory_object_stream (0 )
4949
50- with catch ({
51- Exception : handle_exception
52- }):
50+ with catch ({Exception : handle_exception }):
5351 async with anyio .create_task_group () as tg :
5452 try :
5553 logger .info (f"Connecting to SSE endpoint: { remove_request_params (url )} " )
@@ -99,72 +97,46 @@ async def sse_reader(
9997 sse .data
10098 )
10199 logger .debug (
102- f "Received server message: "
100+ "Received server message: "
103101 f"{ message } "
102+
104103 )
105104 except Exception as exc :
106105 logger .error (
107- f "Error parsing server message: "
106+ "Error parsing server message: "
108107 f"{ exc } "
109108 )
110109 await read_stream_writer .send (exc )
111110 continue
112111
113- await read_stream_writer .send (message )
112+ session_message = SessionMessage (message )
113+ await read_stream_writer .send (
114+ session_message
115+ )
114116 case _:
115117 logger .warning (
116118 f"Unknown SSE event: { sse .event } "
117119 )
118120 except Exception as exc :
119121 logger .error (f"Error in sse_reader: { exc } " )
120- raise
122+ await read_stream_writer . send ( exc )
121123 finally :
122124 await read_stream_writer .aclose ()
123125
124126 async def post_writer (endpoint_url : str ):
125127 try :
126128 async with write_stream_reader :
127- async for message in write_stream_reader :
129+ async for session_message in write_stream_reader :
128130 logger .debug (
129- f"Sending client message: { message } "
131+ f"Sending client message: { session_message } "
130132 )
131-
132- url_parsed = urlparse (url )
133- endpoint_parsed = urlparse (endpoint_url )
134- if (
135- url_parsed .netloc != endpoint_parsed .netloc
136- or url_parsed .scheme
137- != endpoint_parsed .scheme
138- ):
139- error_msg = (
140- "Endpoint origin does not match "
141- f"connection origin: { endpoint_url } "
142- )
143- logger .error (error_msg )
144- raise ValueError (error_msg )
145-
146- task_status .started (endpoint_url )
147-
148- case "message" :
149- try :
150- message = types .JSONRPCMessage .model_validate_json ( # noqa: E501
151- sse .data
152- )
153- logger .debug (
154- f"Received server message: { message } "
155- )
156- except Exception as exc :
157- logger .error (
158- f"Error parsing server message: { exc } "
159- )
160- await read_stream_writer .send (exc )
161- continue
162-
163- session_message = SessionMessage (message )
164- await read_stream_writer .send (session_message )
165- case _ :
166- logger .warning (
167- f"Unknown SSE event: { sse .event } "
133+ response = await client .post (
134+ endpoint_url ,
135+ json = session_message .message .model_dump (
136+ by_alias = True ,
137+ mode = "json" ,
138+ exclude_none = True ,
139+ ),
168140 )
169141 response .raise_for_status ()
170142 logger .debug (
@@ -183,26 +155,7 @@ async def post_writer(endpoint_url: str):
183155 tg .start_soon (post_writer , endpoint_url )
184156
185157 try :
186- async with write_stream_reader :
187- async for session_message in write_stream_reader :
188- logger .debug (
189- f"Sending client message: { session_message } "
190- )
191- response = await client .post (
192- endpoint_url ,
193- json = session_message .message .model_dump (
194- by_alias = True ,
195- mode = "json" ,
196- exclude_none = True ,
197- ),
198- )
199- response .raise_for_status ()
200- logger .debug (
201- "Client message sent successfully: "
202- f"{ response .status_code } "
203- )
204- except Exception as exc :
205- logger .error (f"Error in post_writer: { exc } " )
158+ yield read_stream , write_stream
206159 finally :
207160 tg .cancel_scope .cancel ()
208161 finally :
0 commit comments