33
44import pytest
55
6- from coagent .core .types import Address
6+ from coagent .core .types import Address , Agent , Channel , RawMessage
77from coagent .core .agent import BaseAgent , Context , handler
88from coagent .core .exceptions import BaseError
99from coagent .core .messages import Cancel , Message
@@ -43,6 +43,22 @@ async def handle(self, msg: Query, ctx: Context) -> AsyncIterator[Reply]:
4343 yield Reply ()
4444
4545
46+ class _TestFactory :
47+ def __init__ (self , channel : Channel , address : Address ):
48+ self .channel = channel
49+ self .address = address
50+
51+ self .agent = None
52+ self .sub = None
53+
54+ async def receive (self , msg : RawMessage ) -> None :
55+ await self .agent .stop ()
56+
57+ async def start (self , agent : Agent ) -> None :
58+ self .agent = agent
59+ self .sub = await self .channel .subscribe (self .address , self .receive )
60+
61+
4662class TestTrivialAgent :
4763 @pytest .mark .asyncio
4864 async def test_normal (self , local_channel , run_agent_in_task , yield_control ):
@@ -60,9 +76,13 @@ async def test_normal(self, local_channel, run_agent_in_task, yield_control):
6076
6177 @pytest .mark .asyncio
6278 async def test_cancel (self , local_channel , run_agent_in_task , yield_control ):
79+ test_factory = _TestFactory (local_channel , Address (name = "test_1" ))
80+
6381 agent = TrivialAgent (wait_s = 10 )
6482 addr = Address (name = "test" , id = "1" )
65- agent .init (local_channel , addr )
83+ agent .init (local_channel , addr , test_factory .address )
84+
85+ await test_factory .start (agent )
6686
6787 _task = run_agent_in_task (agent )
6888 await yield_control ()
@@ -97,9 +117,13 @@ async def test_normal(self, local_channel, run_agent_in_task, yield_control):
97117
98118 @pytest .mark .asyncio
99119 async def test_cancel (self , local_channel , run_agent_in_task , yield_control ):
120+ test_factory = _TestFactory (local_channel , Address (name = "test_3" ))
121+
100122 agent = StreamAgent (wait_s = 10 )
101123 addr = Address (name = "test" , id = "3" )
102- agent .init (local_channel , addr )
124+ agent .init (local_channel , addr , test_factory .address )
125+
126+ await test_factory .start (agent )
103127
104128 _task = run_agent_in_task (agent )
105129 await yield_control ()
0 commit comments