Skip to content

Commit 105ef1c

Browse files
committed
Improve ChatAgent to handle StructuredOutput
1 parent d27c336 commit 105ef1c

File tree

3 files changed

+81
-21
lines changed

3 files changed

+81
-21
lines changed

coagent/agents/chat_agent.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from .aswarm import Agent as SwarmAgent, Swarm
1414
from .aswarm.util import function_to_jsonschema
15-
from .messages import ChatMessage, ChatHistory
15+
from .messages import ChatMessage, ChatHistory, StructuredOutput
1616
from .model_client import default_model_client, ModelClient
1717
from .util import is_user_confirmed
1818

@@ -233,21 +233,38 @@ async def agent(self, agent_type: str) -> AsyncIterator[ChatMessage]:
233233
async def handle_history(
234234
self, msg: ChatHistory, ctx: Context
235235
) -> AsyncIterator[ChatMessage]:
236-
response = self._handle_history(msg, ctx)
236+
response = self._handle_history(msg)
237237
async for resp in response:
238238
yield resp
239239

240240
@handler
241241
async def handle_message(
242242
self, msg: ChatMessage, ctx: Context
243243
) -> AsyncIterator[ChatMessage]:
244-
history = ChatHistory(messages=[msg], response_format=msg.response_format)
245-
response = self._handle_history(history, ctx)
244+
history = ChatHistory(messages=[msg])
245+
response = self._handle_history(history)
246246
async for resp in response:
247247
yield resp
248248

249+
@handler
250+
async def handle_structured_output(
251+
self, msg: StructuredOutput, ctx: Context
252+
) -> AsyncIterator[ChatMessage]:
253+
match msg.input:
254+
case ChatMessage():
255+
history = ChatHistory(messages=[msg.input])
256+
response = self._handle_history(history, msg.output_schema)
257+
async for resp in response:
258+
yield resp
259+
case ChatHistory():
260+
response = self._handle_history(msg.input, msg.output_schema)
261+
async for resp in response:
262+
yield resp
263+
249264
async def _handle_history(
250-
self, msg: ChatHistory, ctx: Context
265+
self,
266+
msg: ChatHistory,
267+
response_format: dict | None = None,
251268
) -> AsyncIterator[ChatMessage]:
252269
# For now, we assume that the agent is processing messages sequentially.
253270
self._history: ChatHistory = msg
@@ -260,7 +277,7 @@ async def _handle_history(
260277
response = self._swarm_client.run_and_stream(
261278
agent=swarm_agent,
262279
messages=[m.model_dump() for m in msg.messages],
263-
response_format=msg.response_format,
280+
response_format=response_format,
264281
context_variables=msg.extensions,
265282
)
266283
async for resp in response:

coagent/agents/messages.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
from __future__ import annotations
22

3-
from typing import Any
3+
from typing import Any, Type
44

5-
from pydantic import Field
6-
from coagent.core import Message
5+
from coagent.core import logger, Message
6+
from pydantic import BaseModel, Field, field_validator, field_serializer
77

88

99
class ChatMessage(Message):
1010
role: str
1111
content: str
12-
response_format: dict | None = None
1312

1413
type: str = Field(default="", description="The type of the message. e.g. confirm")
1514
sender: str = Field(default="", description="The sending agent of the message.")
@@ -30,4 +29,57 @@ def to_llm_message(self) -> dict[str, Any]:
3029

3130
class ChatHistory(Message):
3231
messages: list[ChatMessage]
33-
response_format: dict | None = None
32+
33+
34+
class StructuredOutput(Message):
35+
input: ChatMessage | ChatHistory = Field(..., description="Input message.")
36+
output_type: Type[BaseModel] | None = Field(
37+
None,
38+
description="Output schema specified as a Pydantic model. Equivalent to OpenAI's `response_format`.",
39+
)
40+
output_schema: dict | None = Field(
41+
None,
42+
description="Output schema specified as a dict. Setting this suppresses `output_type`.",
43+
)
44+
45+
@field_serializer("input")
46+
def serialize_input(self, value: Message, _info) -> dict:
47+
data = value.model_dump(exclude_defaults=True)
48+
data["__message_type__"] = value.__class__.__name__
49+
return data
50+
51+
@field_validator("input", mode="before")
52+
@classmethod
53+
def validate_input(cls, value: Message | dict) -> Message:
54+
if isinstance(value, dict):
55+
message_type = value.pop("__message_type__", None)
56+
match message_type:
57+
# Only support ChatMessage and ChatHistory for now.
58+
case "ChatMessage":
59+
return ChatMessage.model_validate(value)
60+
case "ChatHistory":
61+
return ChatHistory.model_validate(value)
62+
return value
63+
64+
@field_serializer("output_type")
65+
def serialize_output_type(self, value: Type[BaseModel] | None, _info) -> None:
66+
# Always return None for `output_type` since it will be converted to `output_schema`.
67+
return None
68+
69+
@field_serializer("output_schema")
70+
def serialize_output_schema(self, value: dict | None, _info) -> dict | None:
71+
if self.output_type:
72+
if value:
73+
logger.warning("Setting output_schema suppresses output_type")
74+
return value
75+
return type_to_response_format_param(self.output_type)
76+
77+
return value
78+
79+
80+
def type_to_response_format_param(
81+
response_format: Type[BaseModel] | dict | None,
82+
) -> dict | None:
83+
import litellm.utils
84+
85+
return litellm.utils.type_to_response_format_param(response_format)

coagent/agents/util.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,13 @@
11
import asyncio
22
import functools
3-
from typing import AsyncIterator, Type
3+
from typing import AsyncIterator
44

55
from coagent.core import logger
6-
from pydantic import BaseModel
76

87
from .messages import ChatMessage
98
from .model_client import default_model_client, ModelClient
109

1110

12-
def type_to_response_format_param(
13-
response_format: Type[BaseModel] | dict | None,
14-
) -> dict | None:
15-
import litellm.utils
16-
17-
return litellm.utils.type_to_response_format_param(response_format)
18-
19-
2011
async def chat(
2112
messages: list[ChatMessage], client: ModelClient = default_model_client
2213
) -> ChatMessage:

0 commit comments

Comments (0)