Skip to content

Commit e1453d0

Browse files
committed
Differentiate ChatAgent from StreamChatAgent
1 parent 953c272 commit e1453d0

File tree

2 files changed

+39
-3
lines changed

2 files changed

+39
-3
lines changed

coagent/agents/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# ruff: noqa: F401
2-
from .chat_agent import ChatAgent, confirm, RunContext, tool
2+
from .chat_agent import ChatAgent, confirm, RunContext, StreamChatAgent, tool
33
from .dynamic_triage import DynamicTriage
44
from .messages import ChatHistory, ChatMessage
55
from .model_client import ModelClient

coagent/agents/chat_agent.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ async def run(*args: Any, **kwargs: Any) -> ChatMessage | str:
114114
return run
115115

116116

117-
class ChatAgent(BaseAgent):
117+
class StreamChatAgent(BaseAgent):
118118
def __init__(
119119
self,
120120
name: str = "",
@@ -168,7 +168,23 @@ async def agent(self, agent_type: str) -> AsyncIterator[ChatMessage]:
168168
yield chunk
169169

170170
@handler
171-
async def handle(
171+
async def handle_history(
172+
self, msg: ChatHistory, ctx: Context
173+
) -> AsyncIterator[ChatMessage]:
174+
response = self._handle_history(msg, ctx)
175+
async for resp in response:
176+
yield resp
177+
178+
@handler
179+
async def handle_message(
180+
self, msg: ChatMessage, ctx: Context
181+
) -> AsyncIterator[ChatMessage]:
182+
history = ChatHistory(messages=[msg])
183+
response = self._handle_history(history, ctx)
184+
async for resp in response:
185+
yield resp
186+
187+
async def _handle_history(
172188
self, msg: ChatHistory, ctx: Context
173189
) -> AsyncIterator[ChatMessage]:
174190
# For now, we assume that the agent is processing messages sequentially.
@@ -205,3 +221,23 @@ async def update_user_confirmed(self, history: ChatHistory) -> None:
205221
async def _has_confirm_message(self, history: ChatHistory) -> bool:
206222
"""Check if the penultimate message is a confirmation message."""
207223
return len(history.messages) > 1 and history.messages[-2].type == "confirm"
224+
225+
226+
class ChatAgent(StreamChatAgent):
227+
"""Non-streaming ChatAgent."""
228+
229+
@handler
230+
async def handle_history(self, msg: ChatHistory, ctx: Context) -> ChatMessage:
231+
accumulated_response = ChatMessage(role="assistant", content="")
232+
response = super().handle_history(msg, ctx)
233+
async for chunk in response:
234+
accumulated_response.content += chunk.content
235+
return accumulated_response
236+
237+
@handler
238+
async def handle_message(self, msg: ChatMessage, ctx: Context) -> ChatMessage:
239+
accumulated_response = ChatMessage(role="assistant", content="")
240+
response = super().handle_message(msg, ctx)
241+
async for chunk in response:
242+
accumulated_response.content += chunk.content
243+
return accumulated_response

0 commit comments

Comments
 (0)