diff --git a/codegen-examples/examples/langchain_agent/run.py b/codegen-examples/examples/langchain_agent/run.py
index 0d4d4f837..5c6891889 100644
--- a/codegen-examples/examples/langchain_agent/run.py
+++ b/codegen-examples/examples/langchain_agent/run.py
@@ -20,7 +20,7 @@
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph.graph import CompiledGraph
-from langgraph.prebuilt import create_react_agent
+from codegen.extensions.langchain.graph import create_react_agent
 from langchain_core.messages import SystemMessage
@@ -70,7 +70,7 @@ def create_codebase_agent(
     memory = MemorySaver() if memory else None

-    return create_react_agent(model=llm, tools=tools, prompt=system_message, checkpointer=memory, debug=debug)
+    return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)


 if __name__ == "__main__":
diff --git a/codegen-examples/examples/swebench_agent_run/local_run.ipynb b/codegen-examples/examples/swebench_agent_run/local_run.ipynb
index 1f27f470c..0b212fa40 100644
--- a/codegen-examples/examples/swebench_agent_run/local_run.ipynb
+++ b/codegen-examples/examples/swebench_agent_run/local_run.ipynb
@@ -7,7 +7,14 @@
    "outputs": [],
    "source": [
     "%load_ext autoreload\n",
-    "%autoreload 2"
+    "%autoreload 2\n",
+    "\n",
+    "from dotenv import load_dotenv  # type: ignore\n",
+    "\n",
+    "load_dotenv()\n",
+    "\n",
+    "from codegen.extensions.swebench.utils import SWEBenchDataset, get_swe_bench_examples  # noqa: E402\n",
+    "from run_eval import run_eval  # noqa: E402"
   ]
  },
  {
@@ -16,9 +23,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "from codegen.sdk.core.codebase import Codebase\n",
-    "from codegen.extensions.swebench.utils import SWEBenchDataset, get_swe_bench_examples\n",
-    "from run_eval import run_eval"
+    "examples = get_swe_bench_examples(dataset=SWEBenchDataset.LITE, split=\"test\", offset=0, length=10)"
   ]
  },
  {
@@ -27,43 +32,8 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "examples = get_swe_bench_examples(dataset=SWEBenchDataset.LITE, split=\"test\", offset=0, length=1)"
+    "await run_eval(use_existing_preds=None, dataset=\"lite\", length=20, repo=\"django/django\", num_workers=10, model=\"claude-3-7-sonnet-latest\")"
   ]
- },
- {
-  "cell_type": "code",
-  "execution_count": null,
-  "metadata": {},
-  "outputs": [],
-  "source": [
-   "codebase = Codebase.from_repo(examples[0].repo, commit=examples[0].base_commit, tmp_dir=f\"/tmp/{examples[0].instance_id}\")\n",
-   "# this will allow us to reuse the codebase for multiple examples\n",
-   "codebases = {examples[0].instance_id: codebase}"
-  ]
- },
- {
-  "cell_type": "code",
-  "execution_count": null,
-  "metadata": {},
-  "outputs": [],
-  "source": [
-   "await run_eval(use_existing_preds=None, dataset=\"lite\", length=None, instance_id=examples[0].instance_id, local=True, codebases=codebases)\n",
-   "codebases[examples[0].instance_id].reset()"
-  ]
- },
- {
-  "cell_type": "code",
-  "execution_count": null,
-  "metadata": {},
-  "outputs": [],
-  "source": []
- },
- {
-  "cell_type": "code",
-  "execution_count": null,
-  "metadata": {},
-  "outputs": [],
-  "source": []
  }
 ],
 "metadata": {
@@ -82,7 +52,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.13.1"
+   "version": "3.13.0"
   }
  },
 "nbformat": 4,
diff --git a/pyproject.toml b/pyproject.toml
index bbbb9ccde..3f30ba581 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,6 +10,7 @@ dependencies = [
     "tiktoken<1.0.0,>=0.5.1",
     "tabulate>=0.9.0,<1.0.0",
     "codeowners<1.0.0,>=0.6.0",
+    "anthropic",
     "dataclasses-json<1.0.0,>=0.6.4",
     "dicttoxml<2.0.0,>=1.7.16",
     "xmltodict<1.0.0,>=0.13.0",
diff --git a/src/codegen/agents/chat_agent.py b/src/codegen/agents/chat_agent.py
new file mode 100644
index 000000000..24ecada26
--- /dev/null
+++ b/src/codegen/agents/chat_agent.py
@@ -0,0 +1,95 @@
+from typing import TYPE_CHECKING, Optional
+from uuid import uuid4
+
+from langchain.tools import BaseTool
+from langchain_core.messages import AIMessage
+
+from codegen.extensions.langchain.agent import create_chat_agent
+
+if TYPE_CHECKING:
+    from codegen import Codebase
+
+
+class ChatAgent:
+    """Agent for interacting with a codebase."""
+
+    def __init__(self, codebase: "Codebase", model_provider: str = "anthropic", model_name: str = "claude-3-5-sonnet-latest", memory: bool = True, tools: Optional[list[BaseTool]] = None, **kwargs):
+        """Initialize a ChatAgent.
+
+        Args:
+            codebase: The codebase to operate on
+            model_provider: The model provider to use ("anthropic" or "openai")
+            model_name: Name of the model to use
+            memory: Whether to let the LLM keep track of the conversation history
+            tools: Additional tools to use
+            **kwargs: Additional LLM configuration options. Supported options:
+                - temperature: Temperature parameter (0-1)
+                - top_p: Top-p sampling parameter (0-1)
+                - top_k: Top-k sampling parameter (>= 1)
+                - max_tokens: Maximum number of tokens to generate
+        """
+        self.codebase = codebase
+        self.agent = create_chat_agent(self.codebase, model_provider=model_provider, model_name=model_name, memory=memory, additional_tools=tools, **kwargs)
+
+    def run(self, prompt: str, thread_id: Optional[str] = None) -> str:
+        """Run the agent with a prompt.
+
+        Args:
+            prompt: The prompt to run
+            thread_id: Optional thread ID for message history. If None, a new thread is created.
+
+        Returns:
+            The agent's response
+        """
+        if thread_id is None:
+            thread_id = str(uuid4())
+
+        input = {"query": prompt}
+        stream = self.agent.stream(input, config={"configurable": {"thread_id": thread_id}}, stream_mode="values")
+
+        for s in stream:
+            message = s["messages"][-1]
+            if isinstance(message, tuple):
+                print(message)
+            else:
+                if isinstance(message, AIMessage) and isinstance(message.content, list) and "text" in message.content[0]:
+                    AIMessage(message.content[0]["text"]).pretty_print()
+                else:
+                    message.pretty_print()
+
+        return s["final_answer"]
+
+    def chat(self, prompt: str, thread_id: Optional[str] = None) -> tuple[str, str]:
+        """Chat with the agent, maintaining conversation history.
+
+        Args:
+            prompt: The user message
+            thread_id: Optional thread ID for message history. If None, a new thread is created.
+
+        Returns:
+            A tuple of (response_content, thread_id) to allow continued conversation
+        """
+        if thread_id is None:
+            thread_id = str(uuid4())
+            print(f"Starting new chat thread: {thread_id}")
+        else:
+            print(f"Continuing chat thread: {thread_id}")
+
+        response = self.run(prompt, thread_id=thread_id)
+        return response, thread_id
+
+    def get_chat_history(self, thread_id: str) -> list:
+        """Retrieve the chat history for a specific thread.
+
+        Args:
+            thread_id: The thread ID to retrieve history for
+
+        Returns:
+            List of messages in the conversation history
+        """
+        # Access the agent's memory to get conversation history
+        if hasattr(self.agent, "get_state"):
+            state = self.agent.get_state({"configurable": {"thread_id": thread_id}})
+            if state and "messages" in state:
+                return state["messages"]
+        return []
diff --git a/src/codegen/agents/code_agent.py b/src/codegen/agents/code_agent.py
index d02cd08a2..03f061cab 100644
--- a/src/codegen/agents/code_agent.py
+++ b/src/codegen/agents/code_agent.py
@@ -3,7 +3,7 @@
 from uuid import uuid4

 from langchain.tools import BaseTool
-from langchain_core.messages import AIMessage
+from langchain_core.messages import AIMessage, HumanMessage
 from langchain_core.runnables.config import RunnableConfig
 from langsmith import Client
@@ -94,8 +94,17 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str:
         # this message has a reducer which appends the current message to the existing history
         # see more https://langchain-ai.github.io/langgraph/concepts/low_level/#reducers
-        input = {"messages": [("user", prompt)]}
-        tags, metadata = self.get_tags_metadata()
+        input = {"query": prompt}
+        metadata = {"project": self.project_name}
+        tags = []
+        # Add SWEBench run ID and instance ID to the metadata and tags for filtering
+        if self.run_id is not None:
+            metadata["swebench_run_id"] = self.run_id
+            tags.append(self.run_id)
+
+        if self.instance_id is not None:
+            metadata["swebench_instance_id"] = self.instance_id
+            tags.append(self.instance_id)

         config = RunnableConfig(configurable={"thread_id": thread_id}, tags=tags, metadata=metadata, recursion_limit=100)
         # we stream the steps instead of invoke because it allows us to access intermediate nodes
@@ -105,7 +114,11 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str:
         run_ids = []

         for s in stream:
-            message = s["messages"][-1]
+            if len(s["messages"]) == 0:
+                message = HumanMessage(content=prompt)
+            else:
+                message = s["messages"][-1]
+
             if isinstance(message, tuple):
                 print(message)
             else:
@@ -119,7 +132,7 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str:
                     run_ids.append(message.additional_kwargs["run_id"])

         # Get the last message content
-        result = s["messages"][-1].content
+        result = s["final_answer"]

         # Try to find run IDs in the LangSmith client's recent runs
         try:
diff --git a/src/codegen/extensions/langchain/agent.py b/src/codegen/extensions/langchain/agent.py
index aabc57847..fe44594b1 100644
--- a/src/codegen/extensions/langchain/agent.py
+++ b/src/codegen/extensions/langchain/agent.py
@@ -6,11 +6,10 @@
 from langchain_core.messages import SystemMessage
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph.graph import CompiledGraph
-from langgraph.prebuilt import create_react_agent

-from .llm import LLM
-from .prompts import REASONER_SYSTEM_MESSAGE
-from .tools import (
+from codegen.extensions.langchain.llm import LLM
+from codegen.extensions.langchain.prompts import REASONER_SYSTEM_MESSAGE
+from codegen.extensions.langchain.tools import (
     CreateFileTool,
     DeleteFileTool,
     ListDirectoryTool,
@@ -25,6 +24,8 @@
     ViewFileTool,
 )

+from .graph import create_react_agent
+
 if TYPE_CHECKING:
     from codegen import Codebase
@@ -88,7 +89,7 @@ def create_codebase_agent(
     memory = MemorySaver() if memory else None

-    return create_react_agent(model=llm, tools=tools, prompt=system_message, checkpointer=memory, debug=debug)
+    return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)


 def create_chat_agent(
@@ -137,7 +138,7 @@ def create_chat_agent(
     memory = MemorySaver() if memory else None

-    return create_react_agent(model=llm, tools=tools, prompt=system_message, checkpointer=memory, debug=debug)
+    return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)


 def create_codebase_inspector_agent(
@@ -174,7 +175,7 @@ def create_codebase_inspector_agent(
     ]

     memory = MemorySaver() if memory else None
-    return create_react_agent(model=llm, tools=tools, prompt=system_message, checkpointer=memory, debug=debug)
+    return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)


 def create_agent_with_tools(
@@ -208,4 +209,4 @@
     memory = MemorySaver() if memory else None

-    return create_react_agent(model=llm, tools=tools, prompt=system_message, checkpointer=memory, debug=debug)
+    return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)
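Reviewer note, not part of the patch: a minimal usage sketch for the ChatAgent added above, showing how run/chat and thread IDs fit together. The repository name is a hypothetical placeholder; Codebase.from_repo is used as in the notebook, and ANTHROPIC_API_KEY is assumed to be set.

    from codegen import Codebase
    from codegen.agents.chat_agent import ChatAgent

    codebase = Codebase.from_repo("owner/repo")  # hypothetical repository
    agent = ChatAgent(codebase)

    # The first call creates a new thread; pass the returned thread_id back in
    # to continue the same conversation against the MemorySaver checkpoint.
    answer, thread_id = agent.chat("Which modules define CLI entry points?")
    follow_up, _ = agent.chat("Show the first one's imports.", thread_id=thread_id)
    history = agent.get_chat_history(thread_id)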
diff --git a/src/codegen/extensions/langchain/graph.py b/src/codegen/extensions/langchain/graph.py
new file mode 100644
index 000000000..3685ea322
--- /dev/null
+++ b/src/codegen/extensions/langchain/graph.py
@@ -0,0 +1,102 @@
+"""Demo implementation of an agent with Codegen tools."""
+
+from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional
+
+import anthropic
+import openai
+from langchain.tools import BaseTool
+from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import END, START
+from langgraph.graph.message import add_messages
+from langgraph.graph.state import CompiledGraph, StateGraph
+from langgraph.prebuilt import ToolNode
+from langgraph.pregel import RetryPolicy
+
+
+class GraphState(dict[str, Any]):
+    """State of the graph."""
+
+    query: str
+    final_answer: str
+    messages: Annotated[list[AnyMessage], add_messages]
+
+
+class AgentGraph:
+    """Main graph class for the agent."""
+
+    def __init__(self, model: "LLM", tools: list[BaseTool], system_message: SystemMessage):
+        self.model = model.bind_tools(tools)
+        self.tools = tools
+        self.system_message = system_message
+
+    # =================================== NODES ====================================
+
+    # Reasoner node
+    def reasoner(self, state: GraphState) -> dict[str, Any]:
+        new_turn = len(state["messages"]) == 0 or isinstance(state["messages"][-1], AIMessage)
+        messages = state["messages"]
+        if new_turn:
+            query = state["query"]
+            messages.append(HumanMessage(content=query))
+
+        result = self.model.invoke([self.system_message, *messages])
+
+        if isinstance(result, AIMessage):
+            return {"messages": [*messages, result], "final_answer": result.content}
+
+        return {"messages": [*messages, result]}
+
+    # =================================== EDGE CONDITIONS ====================================
+
+    def should_continue(self, state: GraphState) -> Literal["tools", END]:
+        messages = state["messages"]
+        last_message = messages[-1]
+        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
+            return "tools"
+        return END
+
+    # =================================== COMPILE GRAPH ====================================
+
+    def create(self, checkpointer: Optional[MemorySaver] = None, debug: bool = False) -> CompiledGraph:
+        """Create and compile the graph."""
+        builder = StateGraph(GraphState)
+
+        # The retry policy has an initial interval, a backoff factor, and a max interval
+        # controlling the amount of time between retries
+        retry_policy = RetryPolicy(
+            retry_on=[anthropic.RateLimitError, openai.RateLimitError],
+            max_attempts=10,
+            initial_interval=30.0,  # Start with a 30 second wait
+            backoff_factor=2,  # Double the wait time on each retry
+            max_interval=1000.0,  # Cap the wait at 1000 seconds
+            jitter=True,
+        )
+
+        # Add nodes
+        builder.add_node("reasoner", self.reasoner, retry=retry_policy)
+        builder.add_node("tools", ToolNode(self.tools), retry=retry_policy)
+
+        # Add edges
+        builder.add_edge(START, "reasoner")
+        builder.add_edge("tools", "reasoner")
+        builder.add_conditional_edges(
+            "reasoner",
+            self.should_continue,
+        )
+
+        return builder.compile(checkpointer=checkpointer, debug=debug)
+
+
+def create_react_agent(
+    model: "LLM",
+    tools: list[BaseTool],
+    system_message: SystemMessage,
+    checkpointer: Optional[MemorySaver] = None,
+    debug: bool = False,
+) -> CompiledGraph:
+    """Create a reactive agent graph."""
+    graph = AgentGraph(model, tools, system_message)
+    return graph.create(checkpointer=checkpointer, debug=debug)
+
+
+if TYPE_CHECKING:
+    from codegen.extensions.langchain.llm import LLM
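Reviewer note, not part of the patch: the new graph establishes a state contract where callers seed state with "query" and read the final text from "final_answer", while the messages channel accumulates via the add_messages reducer. A minimal sketch of driving the compiled graph, assuming llm and tools are already constructed as in agent.py above:

    from langchain_core.messages import SystemMessage
    from codegen.extensions.langchain.graph import create_react_agent

    # llm: an LLM instance; tools: list[BaseTool] -- built as in agent.py (assumed here)
    agent = create_react_agent(model=llm, tools=tools, system_message=SystemMessage(content="You are a helpful code agent."))

    # stream_mode="values" yields the full state after each node, so the last
    # state seen carries the reasoner's final_answer (as chat_agent.run relies on).
    state = None
    for state in agent.stream({"query": "List the top-level packages."}, config={"configurable": {"thread_id": "demo"}}, stream_mode="values"):
        pass
    print(state["final_answer"])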
diff --git a/src/codegen/extensions/langchain/llm.py b/src/codegen/extensions/langchain/llm.py
index d2cec02a3..54b9a91a2 100644
--- a/src/codegen/extensions/langchain/llm.py
+++ b/src/codegen/extensions/langchain/llm.py
@@ -89,13 +89,13 @@ def _get_model(self) -> BaseChatModel:
         if not os.getenv("ANTHROPIC_API_KEY"):
             msg = "ANTHROPIC_API_KEY not found in environment. Please set it in your .env file or environment variables."
             raise ValueError(msg)
-        return ChatAnthropic(**self._get_model_kwargs())
+        return ChatAnthropic(**self._get_model_kwargs(), max_retries=10, timeout=1000)

     elif self.model_provider == "openai":
         if not os.getenv("OPENAI_API_KEY"):
             msg = "OPENAI_API_KEY not found in environment. Please set it in your .env file or environment variables."
             raise ValueError(msg)
-        return ChatOpenAI(**self._get_model_kwargs())
+        return ChatOpenAI(**self._get_model_kwargs(), max_retries=10, timeout=1000)

     elif self.model_provider == "xai":
         if not os.getenv("XAI_API_KEY"):
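Reviewer note, not part of the patch: these constructors now layer client-side HTTP retries under the graph-level RetryPolicy added in graph.py. For reference, the equivalent standalone construction (parameter names exactly as passed in this diff; the model name is the default used elsewhere in the PR):

    from langchain_anthropic import ChatAnthropic

    # max_retries/timeout apply inside the HTTP client; rate-limit retries
    # between graph nodes are handled separately by RetryPolicy in graph.py.
    llm = ChatAnthropic(model="claude-3-5-sonnet-latest", max_retries=10, timeout=1000)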
diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py
index 7cda6d7f9..2ab5b0179 100644
--- a/src/codegen/extensions/langchain/tools.py
+++ b/src/codegen/extensions/langchain/tools.py
@@ -116,7 +116,8 @@ class SearchInput(BaseModel):
     query: str = Field(
         ...,
-        description="The search query to find in the codebase. When ripgrep is available, this will be passed as a ripgrep pattern. For regex searches, set use_regex=True. Ripgrep is the preferred method.",
+        description="""The search query to find in the codebase. When ripgrep is available, this will be passed as a ripgrep pattern. For regex searches, set use_regex=True.
+        Ripgrep is the preferred method.""",
     )
     target_directories: Optional[list[str]] = Field(default=None, description="Optional list of directories to search in")
     file_extensions: Optional[list[str]] = Field(default=None, description="Optional list of file extensions to search (e.g. ['.py', '.ts'])")
@@ -849,31 +850,45 @@ def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]:

 class ReplacementEditInput(BaseModel):
-    filepath: str = Field(..., description="Path to the file to edit relative to the workspace root. The file must exist and be a text file.")
+    """Input for replacement editing."""
+
+    filepath: str = Field(
+        ...,
+        description=("Path to the file to edit relative to the workspace root. The file must exist and be a text file."),
+    )
     pattern: str = Field(
         ...,
-        description="""Regular expression pattern to match text that should be replaced.
-Supports all Python regex syntax including capture groups (\1, \2, etc). The pattern is compiled with re.MULTILINE flag by default.""",
+        description=(
+            "Regular expression pattern to match text that should be replaced. "
+            "Supports all Python regex syntax including capture groups (\\1, \\2, etc). "
+            "The pattern is compiled with re.MULTILINE flag by default."
+        ),
     )
     replacement: str = Field(
         ...,
-        description="""Text to replace matched patterns with.
-Can reference regex capture groups using \1, \2, etc. If using regex groups in pattern, make sure to preserve them in replacement if needed.""",
+        description=(
+            "Text to replace matched patterns with. Can reference regex capture groups using \\1, \\2, etc. If using regex groups in pattern, make sure to preserve them in replacement if needed."
+        ),
     )
     start: int = Field(
         default=1,
-        description="""Starting line number (1-indexed, inclusive) to begin replacements from.
-Use this with 'end' to limit changes to a specific region. Default is 1 (start of file).""",
+        description=("Starting line number (1-indexed, inclusive) to begin replacements from. Use this with 'end' to limit changes to a specific region. Default is 1 (start of file)."),
     )
     end: int = Field(
         default=-1,
-        description="""Ending line number (1-indexed, inclusive) to stop replacements at.
-Use -1 to indicate end of file. Use this with 'start' to limit changes to a specific region. Default is -1 (end of file).""",
+        description=(
+            "Ending line number (1-indexed, inclusive) to stop replacements at. "
+            "Use -1 to indicate end of file. Use this with 'start' to limit changes to a specific region. "
+            "Default is -1 (end of file)."
+        ),
     )
     count: Optional[int] = Field(
         default=None,
-        description="""Maximum number of replacements to make. Use None to replace all occurrences (default), or specify a number to limit replacements.
-Useful when you only want to replace the first N occurrences.""",
+        description=(
+            "Maximum number of replacements to make. "
+            "Use None to replace all occurrences (default), or specify a number to limit replacements. "
+            "Useful when you only want to replace the first N occurrences."
+        ),
     )
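Reviewer note, not part of the patch: the rewritten ReplacementEditInput descriptions promise standard re-module semantics (MULTILINE compilation, backslash capture groups, optional count). A self-contained illustration of exactly those semantics:

    import re

    text = "name: foo\nname: bar"
    # re.MULTILINE makes ^ and $ match at each line boundary, as the tool does by default.
    pattern = re.compile(r"^name: (\w+)$", re.MULTILINE)
    # \1 in the replacement references the first capture group; count=1 limits
    # the edit to the first occurrence, mirroring the tool's count field.
    print(pattern.sub(r"alias: \1", text, count=1))
    # prints:
    # alias: foo
    # name: bar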
diff --git a/src/codegen/extensions/tools/replacement_edit.py b/src/codegen/extensions/tools/replacement_edit.py
index 0c61c8f96..aa5cd98be 100644
--- a/src/codegen/extensions/tools/replacement_edit.py
+++ b/src/codegen/extensions/tools/replacement_edit.py
@@ -30,6 +30,14 @@ class ReplacementEditObservation(Observation):
         default=None,
         description="Message describing the result",
     )
+    error: Optional[str] = Field(
+        default=None,
+        description="Error message if an error occurred",
+    )
+    error_pattern: Optional[str] = Field(
+        default=None,
+        description="Regex pattern that failed to compile",
+    )

     str_template: ClassVar[str] = "{message}" if "{message}" else "Edited file {filepath}"
@@ -138,8 +146,13 @@ def replacement_edit(
         # Compile pattern for better error messages
         regex = re.compile(pattern, flags)
     except re.error as e:
-        msg = f"Invalid regex pattern: {e}"
-        raise ValueError(msg)
+        return ReplacementEditObservation(
+            status="error",
+            error=f"Invalid regex pattern: {e!s}",
+            error_pattern=pattern,
+            filepath=filepath,
+            message="Invalid regex pattern",
+        )

     # Perform the replacement
     if count is None:
diff --git a/src/codegen/extensions/tools/search.py b/src/codegen/extensions/tools/search.py
index 4bcdfb74e..08a52c48b 100644
--- a/src/codegen/extensions/tools/search.py
+++ b/src/codegen/extensions/tools/search.py
@@ -146,7 +146,7 @@ def _search_with_ripgrep(
         pass

     # Add the query and path
-    cmd.append(query)
+    cmd.append(f"{query}")
     cmd.append(search_path)

     # Run ripgrep
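Reviewer note, not part of the patch: since replacement_edit now returns an error observation instead of raising ValueError, callers should branch on the observation's status. A minimal sketch; the call signature is inferred from the input fields above, and codebase is assumed to be a Codebase instance as used throughout the extensions:

    from codegen.extensions.tools.replacement_edit import replacement_edit

    obs = replacement_edit(codebase, filepath="src/app.py", pattern="[unclosed", replacement="x")
    if obs.status == "error":
        # error and error_pattern are the new ReplacementEditObservation fields
        print(f"bad pattern {obs.error_pattern!r}: {obs.error}")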