Skip to content

Commit ddee4b7

Browse files
committed
Fix agent deletion
1 parent b35ff6b commit ddee4b7

File tree

3 files changed

+31
-8
lines changed

3 files changed

+31
-8
lines changed

coagent/core/agent.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ async def delete(self) -> None:
196196

197197
if self.factory_address:
198198
msg = DeleteAgent(session_id=self.address.id).encode()
199-
await self.channel.publish(self.factory_address, msg)
199+
await self.channel.publish(self.factory_address, msg, probe=False)
200200

201201
async def started(self) -> None:
202202
"""This handler is called after the agent is started."""
@@ -243,8 +243,6 @@ async def _handle_control(self, msg: ControlMessage) -> None:
243243
"""Handle CONTROL messages."""
244244
match msg:
245245
case Cancel():
246-
if self._handle_data_task:
247-
self._handle_data_task.cancel()
248246
# Delete the agent when cancelled.
249247
await self.delete()
250248

coagent/core/factory.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@ async def start(self) -> None:
6363
await super().start()
6464

6565
# Generate a unique address and create an instance subscription.
66-
self._instance_address = Address(name=self.address.name, id=uuid.uuid4().hex)
67-
self._instance_sub = self.channel.subscribe(
66+
unique_id = uuid.uuid4().hex
67+
self._instance_address = Address(name=f"{self.address.name}_{unique_id}")
68+
self._instance_sub = await self.channel.subscribe(
6869
self._instance_address, handler=self.receive
6970
)
7071

tests/core/test_agent.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pytest
55

6-
from coagent.core.types import Address
6+
from coagent.core.types import Address, Agent, Channel, RawMessage
77
from coagent.core.agent import BaseAgent, Context, handler
88
from coagent.core.exceptions import BaseError
99
from coagent.core.messages import Cancel, Message
@@ -43,6 +43,22 @@ async def handle(self, msg: Query, ctx: Context) -> AsyncIterator[Reply]:
4343
yield Reply()
4444

4545

46+
class _TestFactory:
47+
def __init__(self, channel: Channel, address: Address):
48+
self.channel = channel
49+
self.address = address
50+
51+
self.agent = None
52+
self.sub = None
53+
54+
async def receive(self, msg: RawMessage) -> None:
55+
await self.agent.stop()
56+
57+
async def start(self, agent: Agent) -> None:
58+
self.agent = agent
59+
self.sub = await self.channel.subscribe(self.address, self.receive)
60+
61+
4662
class TestTrivialAgent:
4763
@pytest.mark.asyncio
4864
async def test_normal(self, local_channel, run_agent_in_task, yield_control):
@@ -60,9 +76,13 @@ async def test_normal(self, local_channel, run_agent_in_task, yield_control):
6076

6177
@pytest.mark.asyncio
6278
async def test_cancel(self, local_channel, run_agent_in_task, yield_control):
79+
test_factory = _TestFactory(local_channel, Address(name="test_1"))
80+
6381
agent = TrivialAgent(wait_s=10)
6482
addr = Address(name="test", id="1")
65-
agent.init(local_channel, addr)
83+
agent.init(local_channel, addr, test_factory.address)
84+
85+
await test_factory.start(agent)
6686

6787
_task = run_agent_in_task(agent)
6888
await yield_control()
@@ -97,9 +117,13 @@ async def test_normal(self, local_channel, run_agent_in_task, yield_control):
97117

98118
@pytest.mark.asyncio
99119
async def test_cancel(self, local_channel, run_agent_in_task, yield_control):
120+
test_factory = _TestFactory(local_channel, Address(name="test_3"))
121+
100122
agent = StreamAgent(wait_s=10)
101123
addr = Address(name="test", id="3")
102-
agent.init(local_channel, addr)
124+
agent.init(local_channel, addr, test_factory.address)
125+
126+
await test_factory.start(agent)
103127

104128
_task = run_agent_in_task(agent)
105129
await yield_control()

0 commit comments

Comments
 (0)