Skip to content

Commit 47c7542

Browse files
committed
Improve Channel.publish() to support both streaming and non-streaming
1 parent 0460324 commit 47c7542

File tree

12 files changed

+87
-100
lines changed

12 files changed

+87
-100
lines changed

coagent/agents/chat_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def __init__(self, host_agent: ChatAgent, agent_type: str):
125125

126126
async def handle(self, msg: ChatHistory) -> AsyncIterator[ChatMessage]:
127127
addr = Address(name=self.agent_type, id=self.host_agent.address.id)
128-
result = self.host_agent.channel.publish_multi(addr, msg.encode())
128+
result = await self.host_agent.channel.publish(addr, msg.encode(), stream=True)
129129
full_content = ""
130130
async for chunk in result:
131131
resp = ChatMessage.decode(chunk)

coagent/cli/main.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,18 @@ async def run(
6767
async with runtime:
6868
addr = Address(name=agent_type, id=session_id)
6969
try:
70+
response = await runtime.channel.publish(
71+
addr,
72+
msg,
73+
stream=stream,
74+
request=True,
75+
timeout=10,
76+
probe=probe,
77+
)
7078
if not stream:
71-
response = await runtime.channel.publish(
72-
addr,
73-
msg,
74-
request=True,
75-
timeout=10,
76-
probe=probe,
77-
)
7879
print_msg(response, oneline, filter)
7980
else:
80-
async for chunk in runtime.channel.publish_multi(addr, msg):
81+
async for chunk in response:
8182
print_msg(chunk, oneline, filter)
8283
except asyncio.CancelledError:
8384
await runtime.channel.cancel(addr)

coagent/core/runtime.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import abc
12
import asyncio
23
from typing import AsyncIterator
34

@@ -71,19 +72,63 @@ def channel(self) -> Channel:
7172

7273

7374
class BaseChannel(Channel):
74-
async def publish_multi(
75+
async def publish(
76+
self,
77+
addr: Address,
78+
msg: RawMessage,
79+
stream: bool = False,
80+
request: bool = False,
81+
reply: str = "",
82+
timeout: float = 0.5,
83+
probe: bool = True,
84+
) -> AsyncIterator[RawMessage] | RawMessage | None:
85+
if stream:
86+
return self._publish_stream(addr, msg, probe=probe)
87+
else:
88+
return await self._publish(
89+
addr,
90+
msg,
91+
request=request,
92+
stream=stream,
93+
reply=reply,
94+
timeout=timeout,
95+
probe=probe,
96+
)
97+
98+
@abc.abstractmethod
99+
async def _publish(
100+
self,
101+
addr: Address,
102+
msg: RawMessage,
103+
request: bool = False,
104+
stream: bool = False,
105+
reply: str = "",
106+
timeout: float = 0.5,
107+
probe: bool = True,
108+
) -> RawMessage | None:
109+
pass
110+
111+
async def _publish_stream(
75112
self,
76113
addr: Address,
77114
msg: RawMessage,
78115
probe: bool = True,
79116
) -> AsyncIterator[RawMessage]:
80-
"""A default implementation that leverages the channel's own subscribe and publish methods."""
117+
"""Publish a message and wait for multiple reply messages.
118+
119+
Args:
120+
addr (Address): The address of the agent.
121+
msg (RawMessage): The raw message to send.
122+
probe (bool, optional): Whether to probe the agent before sending the message. Defaults to True.
123+
124+
This is a default implementation that leverages the channel's own subscribe and _publish methods.
125+
"""
81126
queue: QueueSubscriptionIterator = QueueSubscriptionIterator()
82127

83128
inbox = await self.new_reply_topic()
84129
sub = await self.subscribe(addr=Address(name=inbox), handler=queue.receive)
85130

86-
await self.publish(
131+
await self._publish(
87132
addr,
88133
msg,
89134
request=True,

coagent/core/types.py

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -236,41 +236,25 @@ async def publish(
236236
self,
237237
addr: Address,
238238
msg: RawMessage,
239-
request: bool = False,
240239
stream: bool = False,
240+
request: bool = False,
241241
reply: str = "",
242242
timeout: float = 0.5,
243243
probe: bool = True,
244-
) -> RawMessage | None:
245-
"""Publish a message.
244+
) -> AsyncIterator[RawMessage] | RawMessage | None:
245+
"""Publish a message to the given address.
246246
247247
Args:
248248
addr (Address): The address of the agent.
249249
msg (RawMessage): The raw message to send.
250-
request (bool, optional): Whether this is a request. Defaults to False.
251250
stream (bool, optional): Whether to request a streaming result. Defaults to False.
251+
request (bool, optional): Whether this is a request. Defaults to False. If `stream` is True, then this is always True.
252252
reply (str, optional): If `request` is True, then this will be the subject to reply to. Defaults to "".
253253
timeout (float, optional): If `request` is True, then this will be the timeout for the response. Defaults to 0.5.
254254
probe (bool, optional): Whether to probe the agent before sending the message. Defaults to True.
255255
"""
256256
pass
257257

258-
@abc.abstractmethod
259-
async def publish_multi(
260-
self,
261-
addr: Address,
262-
msg: RawMessage,
263-
probe: bool = True,
264-
) -> AsyncIterator[RawMessage]:
265-
"""Publish a message and wait for multiple reply messages.
266-
267-
Args:
268-
addr (Address): The address of the agent.
269-
msg (RawMessage): The raw message to send.
270-
probe (bool, optional): Whether to probe the agent before sending the message. Defaults to True.
271-
"""
272-
pass
273-
274258
@abc.abstractmethod
275259
async def subscribe(
276260
self,
@@ -310,20 +294,17 @@ async def run(
310294
stream: bool = False,
311295
session_id: str = "",
312296
timeout: float = 0.5,
313-
) -> RawMessage | AsyncIterator[RawMessage]:
297+
) -> AsyncIterator[RawMessage] | RawMessage | None:
314298
"""Create an agent and run it with the given message."""
315299
if self.__runtime is None:
316300
raise ValueError(f"AgentSpec {self.name} is not registered to a runtime.")
317301

318302
session_id = session_id or uuid.uuid4().hex
319303
addr = Address(name=self.name, id=session_id)
320304

321-
if stream:
322-
return self.__runtime.channel.publish_multi(addr, msg)
323-
else:
324-
return await self.__runtime.channel.publish(
325-
addr, msg, request=True, timeout=timeout
326-
)
305+
return await self.__runtime.channel.publish(
306+
addr, msg, stream=stream, request=True, timeout=timeout
307+
)
327308

328309

329310
class Runtime(abc.ABC):

coagent/cos/runtime.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,10 @@ async def publish_multi(self, request: Request):
161161
await self._update_message_header_extensions(msg, request)
162162

163163
addr = Address.decode(data["addr"])
164-
msgs = self._runtime.channel.publish_multi(
164+
msgs = await self._runtime.channel.publish(
165165
addr=addr,
166166
msg=msg,
167+
stream=True,
167168
probe=data.get("probe", True),
168169
)
169170

coagent/runtimes/http_runtime.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def from_server(cls, server: str, auth: str = "") -> HTTPRuntime:
4343
class HTTPChannel(BaseChannel):
4444
"""An HTTP-based channel.
4545
46-
publish: POST /publish
47-
publish_multi: POST /publish_multi
46+
_publish: POST /publish
47+
_publish_stream: POST /publish_multi
4848
subscribe: POST /subscribe
4949
new_reply_topic: POST /reply-topics
5050
"""
@@ -64,7 +64,7 @@ async def connect(self) -> None:
6464
async def close(self) -> None:
6565
pass
6666

67-
async def publish(
67+
async def _publish(
6868
self,
6969
addr: Address,
7070
msg: RawMessage,
@@ -95,7 +95,7 @@ async def publish(
9595
if resp.is_error:
9696
raise_http_error(resp, resp.text)
9797

98-
async def publish_multi(
98+
async def _publish_stream(
9999
self,
100100
addr: Address,
101101
msg: RawMessage,
@@ -290,9 +290,10 @@ async def publish_multi(
290290
msg: RawMessage,
291291
probe: bool = True,
292292
) -> AsyncIterator[RawMessage]:
293-
msgs = self._channel.publish_multi(
293+
msgs = await self._channel.publish(
294294
addr,
295295
msg,
296+
stream=True,
296297
probe=probe,
297298
)
298299
async for msg in msgs:

coagent/runtimes/local_runtime.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,26 +44,6 @@ async def connect(self) -> None:
4444
async def close(self) -> None:
4545
pass
4646

47-
async def publish(
48-
self,
49-
addr: Address,
50-
msg: RawMessage,
51-
request: bool = False,
52-
stream: bool = False,
53-
reply: str = "",
54-
timeout: float = 0.5,
55-
probe: bool = True,
56-
) -> RawMessage | None:
57-
return await self._publish(
58-
addr,
59-
msg,
60-
request=request,
61-
stream=stream,
62-
reply=reply,
63-
timeout=timeout,
64-
probe=probe,
65-
)
66-
6747
async def subscribe(
6848
self,
6949
addr: Address,

coagent/runtimes/nats_runtime.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -57,26 +57,6 @@ async def close(self) -> None:
5757
except ConnectionClosedError:
5858
pass
5959

60-
async def publish(
61-
self,
62-
addr: Address,
63-
msg: RawMessage,
64-
request: bool = False,
65-
stream: bool = False,
66-
reply: str = "",
67-
timeout: float = 0.5,
68-
probe: bool = True,
69-
) -> RawMessage | None:
70-
return await self._publish(
71-
addr,
72-
msg,
73-
request=request,
74-
stream=stream,
75-
reply=reply,
76-
timeout=timeout,
77-
probe=probe,
78-
)
79-
8060
async def subscribe(
8161
self,
8262
addr: Address,
@@ -102,14 +82,13 @@ async def _publish(
10282
reply: str = "",
10383
timeout: float = 0.5,
10484
probe: bool = True,
105-
nonblocking: bool = False,
10685
) -> RawMessage | None:
10786
if addr.is_reply or not probe or await self._probe(addr):
10887
return await self._nats_publish(
10988
addr, msg, request=request, stream=stream, reply=reply, timeout=timeout
11089
)
11190

112-
if request or not nonblocking:
91+
if request:
11392
# If in request-reply (or non-blocking) mode, always wait for the reply.
11493
return await self._create_and_publish(
11594
addr,

examples/rich_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,10 @@ async def ainit(cls):
4040
async def asend(cls, query: str) -> AsyncIterator[ChatMessage]:
4141
msg = ChatMessage(role="user", content=query)
4242
cls.history.messages.append(msg)
43-
result = cls.runtime.channel.publish_multi(
43+
result = await cls.runtime.channel.publish(
4444
cls.addr,
4545
cls.history.encode(),
46+
stream=True,
4647
)
4748
full_reply = ChatMessage(role="assistant", content="")
4849
async for chunk in result:

examples/rich_client_textarea.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,10 @@ async def ainit(cls):
3737
async def asend(cls, query: str) -> AsyncIterator[str]:
3838
msg = ChatMessage(role="user", content=query)
3939
cls.history.messages.append(msg)
40-
result = cls.runtime.channel.publish_multi(
40+
result = await cls.runtime.channel.publish(
4141
cls.addr,
4242
cls.history.encode(),
43+
stream=True,
4344
)
4445
content = ""
4546
async for chunk in result:

0 commit comments

Comments
 (0)