diff --git a/src/codegen/agents/code_agent.py b/src/codegen/agents/code_agent.py index 717c507b2..99406ef40 100644 --- a/src/codegen/agents/code_agent.py +++ b/src/codegen/agents/code_agent.py @@ -8,6 +8,8 @@ from langgraph.graph.graph import CompiledGraph from langsmith import Client +from codegen.agents.loggers import ExternalLogger +from codegen.agents.tracer import MessageStreamTracer from codegen.extensions.langchain.agent import create_codebase_agent from codegen.extensions.langchain.utils.get_langsmith_url import ( find_and_print_langsmith_run_url, @@ -30,6 +32,7 @@ class CodeAgent: run_id: str | None = None instance_id: str | None = None difficulty: int | None = None + logger: Optional[ExternalLogger] = None def __init__( self, @@ -42,6 +45,7 @@ def __init__( metadata: Optional[dict] = {}, agent_config: Optional[AgentConfig] = None, thread_id: Optional[str] = None, + logger: Optional[ExternalLogger] = None, **kwargs, ): """Initialize a CodeAgent. @@ -92,6 +96,9 @@ def __init__( # Initialize tags for agent trace self.tags = [*tags, self.model_name] + # set logger if provided + self.logger = logger + # Initialize metadata for agent trace self.metadata = { "project": self.project_name, @@ -123,19 +130,26 @@ def run(self, prompt: str) -> str: config = RunnableConfig(configurable={"thread_id": self.thread_id}, tags=self.tags, metadata=self.metadata, recursion_limit=200) # we stream the steps instead of invoke because it allows us to access intermediate nodes + stream = self.agent.stream(input, config=config, stream_mode="values") + _tracer = MessageStreamTracer(logger=self.logger) + + # Process the stream with the tracer + traced_stream = _tracer.process_stream(stream) + # Keep track of run IDs from the stream run_ids = [] - for s in stream: + for s in traced_stream: if len(s["messages"]) == 0 or isinstance(s["messages"][-1], HumanMessage): message = HumanMessage(content=prompt) else: message = s["messages"][-1] if isinstance(message, tuple): - print(message) + # print(message) + pass else: if isinstance(message, AIMessage) and isinstance(message.content, list) and len(message.content) > 0 and "text" in message.content[0]: AIMessage(message.content[0]["text"]).pretty_print() @@ -149,7 +163,7 @@ def run(self, prompt: str) -> str: # Get the last message content result = s["final_answer"] - # Try to find run IDs in the LangSmith client's recent runs + # # Try to find run IDs in the LangSmith client's recent runs try: # Find and print the LangSmith run URL find_and_print_langsmith_run_url(self.langsmith_client, self.project_name) diff --git a/src/codegen/agents/data.py b/src/codegen/agents/data.py new file mode 100644 index 000000000..6ac9b1d81 --- /dev/null +++ b/src/codegen/agents/data.py @@ -0,0 +1,71 @@ +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import Literal, Optional, Union + + +# Base dataclass for all message types +@dataclass +class BaseMessage: + """Base class for all message types.""" + + type: str + timestamp: str = field(default_factory=lambda: datetime.now(tz=UTC).isoformat()) + content: str = "" + + +@dataclass +class UserMessage(BaseMessage): + """Represents a message from the user.""" + + type: Literal["user"] = field(default="user") + + +@dataclass +class SystemMessageData(BaseMessage): + """Represents a system message.""" + + type: Literal["system"] = field(default="system") + + +@dataclass +class ToolCall: + """Represents a tool call within an assistant message.""" + + name: Optional[str] = None + arguments: Optional[str] = None + id: Optional[str] = None + + +@dataclass +class AssistantMessage(BaseMessage): + """Represents a message from the assistant.""" + + type: Literal["assistant"] = field(default="assistant") + tool_calls: list[ToolCall] = field(default_factory=list) + + +@dataclass +class ToolMessageData(BaseMessage): + """Represents a tool response message.""" + + type: Literal["tool"] = field(default="tool") + tool_name: Optional[str] = None + tool_response: Optional[str] = None + tool_id: Optional[str] = None + + +@dataclass +class FunctionMessageData(BaseMessage): + """Represents a function message.""" + + type: Literal["function"] = field(default="function") + + +@dataclass +class UnknownMessage(BaseMessage): + """Represents an unknown message type.""" + + type: Literal["unknown"] = field(default="unknown") + + +type AgentRunMessage = Union[UserMessage, SystemMessageData, AssistantMessage, ToolMessageData, FunctionMessageData, UnknownMessage] diff --git a/src/codegen/agents/loggers.py b/src/codegen/agents/loggers.py new file mode 100644 index 000000000..b507c427c --- /dev/null +++ b/src/codegen/agents/loggers.py @@ -0,0 +1,16 @@ +from typing import Protocol + +from .data import AgentRunMessage + + +# Define the interface for ExternalLogger +class ExternalLogger(Protocol): + """Protocol defining the interface for external loggers.""" + + def log(self, data: AgentRunMessage) -> None: + """Log structured data to an external system. + + Args: + data: The structured data to log, either as a dictionary or a BaseMessage + """ + pass diff --git a/src/codegen/agents/scratch.ipynb b/src/codegen/agents/scratch.ipynb new file mode 100644 index 000000000..1b50f8920 --- /dev/null +++ b/src/codegen/agents/scratch.ipynb @@ -0,0 +1,91 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from codegen.agents.code_agent import CodeAgent\n", + "\n", + "\n", + "CodeAgent" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from codegen.sdk.core.codebase import Codebase\n", + "\n", + "\n", + "codebase = Codebase.from_repo(\"codegen-sh/Kevin-s-Adventure-Game\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any, Dict, Union\n", + "from codegen.agents.data import BaseMessage\n", + "from codegen.agents.loggers import ExternalLogger\n", + "\n", + "\n", + "class ConsoleLogger(ExternalLogger):\n", + " def log(self, data: Union[Dict[str, Any], BaseMessage]) -> None:\n", + " print(data.content)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent = CodeAgent(codebase)\n", + "agent.run(\"What is the main character's name? also show the source code where you find the answer\", logger=ConsoleLogger())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent.run(\"What is the main character's name?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/codegen/agents/tracer.py b/src/codegen/agents/tracer.py new file mode 100644 index 000000000..816835c41 --- /dev/null +++ b/src/codegen/agents/tracer.py @@ -0,0 +1,136 @@ +from collections.abc import Generator +from typing import Any, Optional + +from langchain.schema import AIMessage, HumanMessage +from langchain.schema import FunctionMessage as LCFunctionMessage +from langchain.schema import SystemMessage as LCSystemMessage +from langchain_core.messages import ToolMessage as LCToolMessage + +from .data import AssistantMessage, BaseMessage, FunctionMessageData, SystemMessageData, ToolCall, ToolMessageData, UnknownMessage, UserMessage +from .loggers import ExternalLogger + + +class MessageStreamTracer: + def __init__(self, logger: Optional[ExternalLogger] = None): + self.traces = [] + self.logger = logger + + def process_stream(self, message_stream: Generator) -> Generator: + """Process the stream of messages from the LangGraph agent, + extract structured data, and pass through the messages. + """ + for chunk in message_stream: + # Process the chunk + structured_data = self.extract_structured_data(chunk) + + # Log the structured data + if structured_data: + self.traces.append(structured_data) + + # If there's an external logger, send the data there + if self.logger: + self.logger.log(structured_data) + + # Pass through the chunk to maintain the original stream behavior + yield chunk + + def extract_structured_data(self, chunk: dict[str, Any]) -> Optional[BaseMessage]: + """Extract structured data from a message chunk. + Returns None if the chunk doesn't contain useful information. + Returns a BaseMessage subclass instance based on the message type. + """ + # Get the messages from the chunk if available + messages = chunk.get("messages", []) + if not messages and isinstance(chunk, dict): + # Sometimes the message might be in a different format + for key, value in chunk.items(): + if isinstance(value, list) and all(hasattr(item, "type") for item in value if hasattr(item, "__dict__")): + messages = value + break + + if not messages: + return None + + # Get the latest message + latest_message = messages[-1] if messages else None + + if not latest_message: + return None + + # Determine message type + message_type = self._get_message_type(latest_message) + content = self._get_message_content(latest_message) + + # Create the appropriate message type + if message_type == "user": + return UserMessage(type=message_type, content=content) + elif message_type == "system": + return SystemMessageData(type=message_type, content=content) + elif message_type == "assistant": + tool_calls_data = self._extract_tool_calls(latest_message) + tool_calls = [ToolCall(name=tc.get("name"), arguments=tc.get("arguments"), id=tc.get("id")) for tc in tool_calls_data] + return AssistantMessage(type=message_type, content=content, tool_calls=tool_calls) + elif message_type == "tool": + return ToolMessageData(type=message_type, content=content, tool_name=getattr(latest_message, "name", None), tool_response=content, tool_id=getattr(latest_message, "tool_call_id", None)) + elif message_type == "function": + return FunctionMessageData(type=message_type, content=content) + else: + return UnknownMessage(type=message_type, content=content) + + def _get_message_type(self, message) -> str: + """Determine the type of message.""" + if isinstance(message, HumanMessage): + return "user" + elif isinstance(message, AIMessage): + return "assistant" + elif isinstance(message, LCSystemMessage): + return "system" + elif isinstance(message, LCFunctionMessage): + return "function" + elif isinstance(message, LCToolMessage): + return "tool" + elif hasattr(message, "type") and message.type: + return message.type + else: + return "unknown" + + def _get_message_content(self, message) -> str: + """Extract content from a message.""" + if hasattr(message, "content"): + return message.content + elif hasattr(message, "message") and hasattr(message.message, "content"): + return message.message.content + else: + return str(message) + + def _extract_tool_calls(self, message) -> list[dict[str, Any]]: + """Extract tool calls from an assistant message.""" + tool_calls = [] + + # Check different possible locations for tool calls + if hasattr(message, "additional_kwargs") and "tool_calls" in message.additional_kwargs: + raw_tool_calls = message.additional_kwargs["tool_calls"] + for tc in raw_tool_calls: + tool_calls.append({"name": tc.get("function", {}).get("name"), "arguments": tc.get("function", {}).get("arguments"), "id": tc.get("id")}) + + # Also check for function_call which is used in some models + elif hasattr(message, "additional_kwargs") and "function_call" in message.additional_kwargs: + fc = message.additional_kwargs["function_call"] + if isinstance(fc, dict): + tool_calls.append( + { + "name": fc.get("name"), + "arguments": fc.get("arguments"), + "id": "function_call_1", # Assigning a default ID + } + ) + + return tool_calls + + def get_traces(self) -> list[BaseMessage]: + """Get all collected traces.""" + return self.traces + + def clear_traces(self) -> None: + """Clear all traces.""" + self.traces = []