Skip to content

Commit f71ac17

Browse files
committed
Refactor the streaming mechanism
1 parent 3c6cd5a commit f71ac17

File tree

13 files changed

+136
-50
lines changed

13 files changed

+136
-50
lines changed

coagent/agents/messages.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Any
24

35
from pydantic import Field
@@ -14,6 +16,15 @@ class ChatMessage(Message):
1416
default=False, description="Whether the message is sent directly to user."
1517
)
1618

19+
def __add__(self, other: ChatMessage) -> ChatMessage:
20+
return ChatMessage(
21+
role=self.role,
22+
content=self.content + other.content,
23+
type=self.type,
24+
sender=self.sender,
25+
to_user=self.to_user,
26+
)
27+
1728
def model_dump(self, **kwargs) -> dict[str, Any]:
1829
return super().model_dump(include={"role", "content"}, **kwargs)
1930

coagent/agents/parallel.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
GenericMessage,
88
Message,
99
RawMessage,
10+
Reply,
1011
SetReplyAgent,
1112
)
1213

1314

1415
class StartAggregation(Message):
1516
candidates: list[str]
16-
reply_addr: Address
17+
reply_info: Reply | None
1718

1819

1920
class AggregationStatus(Message):
@@ -62,17 +63,17 @@ async def handle(self, msg: GenericMessage, ctx: Context) -> None:
6263
self._results.append(msg.encode())
6364

6465
if len(self._results) == len(self._data.candidates):
65-
if self._data.reply_addr:
66+
if self._data.reply_info:
6667
result = await self._aggregate(self._results)
67-
await self.channel.publish(self._data.reply_addr, result)
68+
await self.send_reply(self._data.reply_info, result)
6869
self._busy = False
6970

70-
async def aggregate(self, results: list[RawMessage]) -> RawMessage:
71+
async def aggregate(self, results: list[RawMessage]) -> Message:
7172
"""Aggregate the results to a single one.
7273
7374
Override this method to provide custom aggregation logic.
7475
"""
75-
return AggregationResult(results=results).encode()
76+
return AggregationResult(results=results)
7677

7778

7879
class Parallel(BaseAgent):
@@ -87,12 +88,13 @@ def __init__(self, *agent_types: str, aggregator: str = ""):
8788

8889
async def started(self) -> None:
8990
aggregator_addr = Address(name=self._aggregator_type, id=self.address.id)
91+
aggregator_reply = Reply(address=aggregator_addr)
9092
# Make each agent reply to the aggregator agent.
9193
for agent_type in self._agent_types:
9294
addr = Address(name=agent_type, id=self.address.id)
9395
await self.channel.publish(
9496
addr,
95-
SetReplyAgent(address=aggregator_addr).encode(),
97+
SetReplyAgent(reply_info=aggregator_reply).encode(),
9698
)
9799

98100
@handler
@@ -101,16 +103,14 @@ async def handle(self, msg: GenericMessage, ctx: Context) -> None:
101103
return
102104

103105
# Let the aggregator agent reply to the sending agent, if asked.
104-
reply_address = self.reply_address or msg.reply
105-
if reply_address:
106+
reply = self.reply or msg.reply
107+
if reply:
106108
# Reset the reply address of the message, since it will be replied by the aggregator agent.
107109
msg.reply = None
108110

109111
result = await self.channel.publish(
110112
Address(name=self._aggregator_type, id=self.address.id),
111-
StartAggregation(
112-
candidates=self._agent_types, reply_addr=reply_address
113-
).encode(),
113+
StartAggregation(candidates=self._agent_types, reply_info=reply).encode(),
114114
request=True,
115115
)
116116
status = AggregationStatus.decode(result)

coagent/agents/sequential.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
BaseAgent,
44
Context,
55
GenericMessage,
6+
Reply,
67
SetReplyAgent,
78
handler,
89
)
@@ -20,9 +21,10 @@ async def started(self) -> None:
2021
# Set the reply address of the current agent to be the next agent.
2122
addr = Address(name=self._agent_types[i], id=self.address.id)
2223
next_addr = Address(name=self._agent_types[i + 1], id=self.address.id)
24+
reply = Reply(address=next_addr)
2325
await self.channel.publish(
2426
addr,
25-
SetReplyAgent(address=next_addr).encode(),
27+
SetReplyAgent(reply_info=reply).encode(),
2628
)
2729

2830
@handler
@@ -31,12 +33,12 @@ async def handle(self, msg: GenericMessage, ctx: Context) -> None:
3133
return
3234

3335
# Let the last agent reply to the sending agent, if asked.
34-
reply_address = self.reply_address or msg.reply
35-
if reply_address:
36+
reply = self.reply or msg.reply
37+
if reply:
3638
last_addr = Address(name=self._agent_types[-1], id=self.address.id)
3739
await self.channel.publish(
3840
last_addr,
39-
SetReplyAgent(address=reply_address).encode(),
41+
SetReplyAgent(reply_info=reply).encode(),
4042
)
4143
# Reset the reply address of the message, since it will be replied by the last agent.
4244
msg.reply = None

coagent/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
MessageHeader,
1818
new,
1919
RawMessage,
20+
Reply,
2021
Subscription,
2122
)
2223
from .util import idle_loop

coagent/core/agent.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from pydantic import BaseModel, ValidationError
66

7-
from .exceptions import MessageDecodeError, InternalError
7+
from .exceptions import MessageDecodeError, InternalError, StreamError
88
from .logger import logger
99
from .messages import (
1010
Cancel,
@@ -17,7 +17,7 @@
1717
SetReplyAgent,
1818
StopIteration,
1919
)
20-
from .types import Address, Agent, Channel, RawMessage, State, Subscription
20+
from .types import Address, Agent, Channel, RawMessage, Reply, State, Subscription
2121

2222

2323
class Context:
@@ -122,8 +122,8 @@ def __init__(self, timeout: float = 60):
122122
# this would result in a lot of messages.
123123
self._lock: asyncio.Lock = asyncio.Lock()
124124

125-
# Normally reply_address is set by an orchestration agent by sending a `SetReplyAgent` message.
126-
self.reply_address: Address | None = None
125+
# Normally `reply` is set by an orchestration agent by sending a `SetReplyAgent` message.
126+
self.reply: Reply | None = None
127127

128128
handlers, message_types = self.__collect_handlers()
129129
# A list of handlers that are registered to handle messages.
@@ -269,7 +269,7 @@ async def _handle_data(self) -> None:
269269
await self.stopped()
270270

271271
case SetReplyAgent():
272-
self.reply_address = msg.address
272+
self.reply = msg.reply_info
273273

274274
case ProbeAgent() | Empty():
275275
# Do not handle probes and empty messages.
@@ -282,18 +282,41 @@ async def _handle_data_custom(self, msg: Message, ctx: Context) -> None:
282282
"""Handle user-defined DATA messages."""
283283
h: Handler = self.__get_handler(msg)
284284
result = h(self, msg, ctx)
285+
if not is_async_iterator(result):
286+
result = await result or Empty()
287+
await self.__send_reply(msg.reply, result)
288+
289+
async def __send_reply(
290+
self, in_msg_reply: Reply, result: Message | AsyncIterator[Message]
291+
) -> bool:
292+
reply = self.reply or in_msg_reply
293+
if not reply:
294+
return False
295+
296+
# Reply to the sender if asked.
297+
await self.send_reply(reply, result)
298+
return True
285299

286-
async def pub(x: Message):
287-
await self.__send_reply(msg.reply, x)
300+
async def send_reply(
301+
self,
302+
to: Reply,
303+
result: Message | AsyncIterator[Message],
304+
) -> None:
305+
async def pub(msg: Message):
306+
await self.channel.publish(to.address, msg.encode())
288307

289308
async def pub_exc(exc: BaseException):
290309
err = InternalError.from_exception(exc)
291310
await pub(err.encode_message())
292311

293-
if is_async_iterator(result):
312+
if to.stream:
294313
try:
295-
async for x in result:
296-
await pub(x)
314+
if is_async_iterator(result):
315+
async for msg in result:
316+
await pub(msg)
317+
else:
318+
msg = result
319+
await pub(msg)
297320
except asyncio.CancelledError as exc:
298321
await pub_exc(exc)
299322
raise
@@ -304,23 +327,26 @@ async def pub_exc(exc: BaseException):
304327
await pub(StopIteration())
305328
else:
306329
try:
307-
x = await result or Empty()
308-
await pub(x)
330+
if is_async_iterator(result):
331+
accumulated: RawMessage | None = None
332+
async for msg in result:
333+
if not accumulated:
334+
accumulated = msg
335+
else:
336+
try:
337+
accumulated += msg
338+
except TypeError:
339+
await pub_exc(StreamError("Streaming mode is required"))
340+
await pub(accumulated)
341+
else:
342+
msg = result
343+
await pub(msg)
309344
except asyncio.CancelledError as exc:
310345
await pub_exc(exc)
311346
raise
312347
except Exception as exc:
313348
await pub_exc(exc)
314349

315-
async def __send_reply(self, in_msg_reply: Address, out_msg: Message) -> bool:
316-
reply_address = self.reply_address or in_msg_reply
317-
if not reply_address:
318-
return False
319-
320-
# Reply to the sending agent if asked.
321-
await self.channel.publish(reply_address, out_msg.encode())
322-
return True
323-
324350
def __get_handler(self, msg: Message) -> Handler | None:
325351
msg_type: Type[Any] = type(msg)
326352

coagent/core/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,7 @@ def from_exception(
9999

100100
class DeadlineExceededError(BaseError):
101101
"""Raised when a context deadline is exceeded."""
102+
103+
104+
class StreamError(BaseError):
105+
"""Raised when the sender requests a non-streaming result but the receiver sends a stream."""

coagent/core/messages.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,26 @@
33
import json
44
from pydantic import BaseModel, ConfigDict, Field, ValidationError
55

6-
from .types import Address, MessageHeader, RawMessage
6+
from .types import MessageHeader, RawMessage, Reply
77

88

99
class Message(BaseModel):
1010
model_config = ConfigDict(extra="forbid")
1111

12-
reply: Address | None = Field(default=None, description="Reply address.")
12+
reply: Reply | None = Field(default=None, description="Reply information.")
1313
extensions: dict = Field(
1414
default_factory=dict, description="Extension fields from RawMessage header."
1515
)
1616

17+
def __add__(self, other: Message) -> Message:
18+
"""Concatenate two messages.
19+
20+
This binary operator is mainly used to aggregate multiple streaming
21+
messages into one message when the sender requests to receive a
22+
non-streaming result.
23+
"""
24+
return NotImplemented
25+
1726
def encode(
1827
self, content_type: str = "application/json", exclude_defaults: bool = True
1928
) -> RawMessage:
@@ -114,9 +123,12 @@ class ProbeAgent(Message):
114123

115124

116125
class SetReplyAgent(Message):
117-
"""A message to set the agent to reply to."""
126+
"""A message to set the reply information of an agent.
127+
128+
This is mainly useful when orchestrating multiple agents to work together.
129+
"""
118130

119-
address: Address
131+
reply_info: Reply
120132

121133

122134
class Empty(Message):

coagent/core/runtime.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ async def publish_multi(
8787
addr,
8888
msg,
8989
request=True,
90+
stream=True,
9091
reply=inbox,
9192
probe=probe,
9293
)

coagent/core/types.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ def decode(cls, data: dict) -> Address:
8888
return cls.model_validate(data)
8989

9090

91+
class Reply(BaseModel):
92+
address: Address = Field(..., description="Reply address.")
93+
stream: bool = Field(
94+
False, description="Whether the sender requests a streaming result."
95+
)
96+
97+
9198
class MessageHeader(BaseModel):
9299
type: str = Field(..., description="Message type name.")
93100
content_type: str = Field(
@@ -98,7 +105,7 @@ class MessageHeader(BaseModel):
98105

99106
class RawMessage(BaseModel):
100107
header: MessageHeader = Field(..., description="Message header.")
101-
reply: Address | None = Field(default=None, description="Reply address.")
108+
reply: Reply | None = Field(default=None, description="Reply information.")
102109
content: bytes = Field(default=b"", description="Message content.")
103110

104111
def encode(self, mode: str = "python", exclude_defaults: bool = True) -> dict:
@@ -223,6 +230,7 @@ async def publish(
223230
addr: Address,
224231
msg: RawMessage,
225232
request: bool = False,
233+
stream: bool = False,
226234
reply: str = "",
227235
timeout: float = 0.5,
228236
probe: bool = True,
@@ -233,6 +241,7 @@ async def publish(
233241
addr (Address): The address of the agent.
234242
msg (RawMessage): The raw message to send.
235243
request (bool, optional): Whether this is a request. Defaults to False.
244+
stream (bool, optional): Whether to request a streaming result. Defaults to False.
236245
reply (str, optional): If `request` is True, then this will be the subject to reply to. Defaults to "".
237246
timeout (float, optional): If `request` is True, then this will be the timeout for the response. Defaults to 0.5.
238247
probe (bool, optional): Whether to probe the agent before sending the message. Defaults to True.

coagent/runtimes/http_runtime.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ async def publish(
6969
addr: Address,
7070
msg: RawMessage,
7171
request: bool = False,
72+
stream: bool = False,
7273
reply: str = "",
7374
timeout: float = 5.0,
7475
probe: bool = True,
@@ -77,6 +78,7 @@ async def publish(
7778
addr=addr.encode(mode="json"),
7879
msg=msg.encode(mode="json"),
7980
request=request,
81+
stream=stream,
8082
reply=reply,
8183
timeout=timeout,
8284
probe=probe,
@@ -267,6 +269,7 @@ async def publish(
267269
addr: Address,
268270
msg: RawMessage,
269271
request: bool = False,
272+
stream: bool = False,
270273
reply: str = "",
271274
timeout: float = 5.0,
272275
probe: bool = True,
@@ -275,6 +278,7 @@ async def publish(
275278
addr,
276279
msg,
277280
request=request,
281+
stream=stream,
278282
reply=reply,
279283
timeout=timeout,
280284
probe=probe,

0 commit comments

Comments
 (0)