Skip to content

Commit e9bb70c

Browse files
committed
Cancel the agent's processing once the user disconnects
1 parent 357840f commit e9bb70c

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

coagent/cos/runtime.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
RawMessage,
1515
logger,
1616
)
17+
from coagent.core.messages import Cancel
1718
from coagent.core.exceptions import BaseError
1819
from coagent.core.factory import DeleteAgent
1920
from coagent.core.types import Runtime
@@ -137,8 +138,9 @@ async def publish(self, request: Request):
137138
msg = RawMessage.decode(data["msg"])
138139
await self._update_message_header_extensions(msg, request)
139140

141+
addr = Address.decode(data["addr"])
140142
resp: RawMessage | None = await self._runtime.channel.publish(
141-
addr=Address.decode(data["addr"]),
143+
addr=addr,
142144
msg=msg,
143145
request=data.get("request", False),
144146
reply=data.get("reply", ""),
@@ -147,6 +149,11 @@ async def publish(self, request: Request):
147149
)
148150
except BaseError as exc:
149151
return JSONResponse(exc.encode(mode="json"), status_code=404)
152+
except asyncio.CancelledError:
153+
# Disconnected from the client.
154+
155+
# Cancel the ongoing operation.
156+
await self._runtime.channel.publish(addr, Cancel().encode())
150157

151158
if resp is None:
152159
return Response(status_code=204)
@@ -158,8 +165,9 @@ async def publish_multi(self, request: Request):
158165
msg = RawMessage.decode(data["msg"])
159166
await self._update_message_header_extensions(msg, request)
160167

168+
addr = Address.decode(data["addr"])
161169
msgs = self._runtime.channel.publish_multi(
162-
addr=Address.decode(data["addr"]),
170+
addr=addr,
163171
msg=msg,
164172
probe=data.get("probe", True),
165173
)
@@ -170,11 +178,16 @@ async def event_stream() -> AsyncIterator[str]:
170178
yield dict(data=raw.encode_json())
171179
except BaseError as exc:
172180
yield dict(event="error", data=exc.encode_json())
181+
except asyncio.CancelledError:
182+
# Disconnected from the client.
183+
184+
# Cancel the ongoing operation.
185+
await self._runtime.channel.publish(addr, Cancel().encode())
173186

174187
return EventSourceResponse(event_stream())
175188

176189
async def _update_message_header_extensions(
177190
self, msg: RawMessage, request: Request
178191
) -> None:
179-
"""Update the message header extensions according to the data from the request."""
192+
"""Update the message header extensions according to the request data."""
180193
pass

0 commit comments

Comments
 (0)