1414 RawMessage ,
1515 logger ,
1616)
17+ from coagent .core .messages import Cancel
1718from coagent .core .exceptions import BaseError
1819from coagent .core .factory import DeleteAgent
1920from coagent .core .types import Runtime
@@ -137,8 +138,9 @@ async def publish(self, request: Request):
137138 msg = RawMessage .decode (data ["msg" ])
138139 await self ._update_message_header_extensions (msg , request )
139140
141+ addr = Address .decode (data ["addr" ])
140142 resp : RawMessage | None = await self ._runtime .channel .publish (
141- addr = Address . decode ( data [ " addr" ]) ,
143+ addr = addr ,
142144 msg = msg ,
143145 request = data .get ("request" , False ),
144146 reply = data .get ("reply" , "" ),
@@ -147,6 +149,11 @@ async def publish(self, request: Request):
147149 )
148150 except BaseError as exc :
149151 return JSONResponse (exc .encode (mode = "json" ), status_code = 404 )
152+ except asyncio .CancelledError :
153+ # Disconnected from the client.
154+
155+ # Cancel the ongoing operation.
156+ await self ._runtime .channel .publish (addr , Cancel ().encode ())
150157
151158 if resp is None :
152159 return Response (status_code = 204 )
@@ -158,8 +165,9 @@ async def publish_multi(self, request: Request):
158165 msg = RawMessage .decode (data ["msg" ])
159166 await self ._update_message_header_extensions (msg , request )
160167
168+ addr = Address .decode (data ["addr" ])
161169 msgs = self ._runtime .channel .publish_multi (
162- addr = Address . decode ( data [ " addr" ]) ,
170+ addr = addr ,
163171 msg = msg ,
164172 probe = data .get ("probe" , True ),
165173 )
@@ -170,11 +178,16 @@ async def event_stream() -> AsyncIterator[str]:
170178 yield dict (data = raw .encode_json ())
171179 except BaseError as exc :
172180 yield dict (event = "error" , data = exc .encode_json ())
181+ except asyncio .CancelledError :
182+ # Disconnected from the client.
183+
184+ # Cancel the ongoing operation.
185+ await self ._runtime .channel .publish (addr , Cancel ().encode ())
173186
174187 return EventSourceResponse (event_stream ())
175188
176189 async def _update_message_header_extensions (
177190 self , msg : RawMessage , request : Request
178191 ) -> None :
179- """Update the message header extensions according to the data from the request."""
192+ """Update the message header extensions according to the request data ."""
180193 pass
0 commit comments