2 changes: 1 addition & 1 deletion src/codegen/agents/code_agent.py
@@ -16,7 +16,7 @@
class CodeAgent:
"""Agent for interacting with a codebase."""

def __init__(self, codebase: "Codebase", model_provider: str = "anthropic", model_name: str = "claude-3-5-sonnet-latest", memory: bool = True, tools: Optional[list[BaseTool]] = None, **kwargs):
def __init__(self, codebase: "Codebase", model_provider: str = "anthropic", model_name: str = "claude-3-7-sonnet-latest", memory: bool = True, tools: Optional[list[BaseTool]] = None, **kwargs):
"""Initialize a CodeAgent.

Args:
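This diff only bumps the default model_name; callers that pin a model explicitly are unaffected. A minimal sketch of both call styles, assuming a Codebase instance is already in hand (the illustrative variable name codebase is not from the diff):

    # Picks up the new default, claude-3-7-sonnet-latest
    agent = CodeAgent(codebase)

    # Pinning a model explicitly is unaffected by the default bump
    pinned = CodeAgent(codebase, model_provider="anthropic", model_name="claude-3-5-sonnet-latest")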
23 changes: 21 additions & 2 deletions src/codegen/extensions/langchain/llm.py
@@ -1,5 +1,6 @@
"""LLM implementation supporting both OpenAI and Anthropic models."""

import logging
import os
from collections.abc import Sequence
from typing import Any, Optional
@@ -15,6 +16,10 @@
from langchain_openai import ChatOpenAI
from pydantic import Field

from codegen.extensions.langchain.utils.retry import retry_on_rate_limit

logger = logging.getLogger(__name__)


class LLM(BaseChatModel):
"""A unified chat model that supports both OpenAI and Anthropic."""
@@ -31,6 +36,10 @@ class LLM(BaseChatModel):

max_tokens: Optional[int] = Field(default=None, description="Maximum number of tokens to generate.", ge=1)

max_retries: int = Field(default=3, description="Maximum number of retries for rate limit errors.")

retry_base_delay: float = Field(default=45.0, description="Base delay in seconds for retry backoff.")

def __init__(self, model_provider: str = "anthropic", model_name: str = "claude-3-5-sonnet-latest", **kwargs: Any) -> None:
"""Initialize the LLM.

@@ -42,13 +51,15 @@ def __init__(self, model_provider: str = "anthropic", model_name: str = "claude-
- top_p: Top-p sampling parameter (0-1)
- top_k: Top-k sampling parameter (>= 1)
- max_tokens: Maximum number of tokens to generate
- max_retries: Maximum number of retries for rate limit errors
- retry_base_delay: Base delay in seconds for retry backoff
"""
# Set model provider and name before calling super().__init__
kwargs["model_provider"] = model_provider
kwargs["model_name"] = model_name

# Filter out unsupported kwargs
supported_kwargs = {"model_provider", "model_name", "temperature", "top_p", "top_k", "max_tokens", "callbacks", "tags", "metadata"}
supported_kwargs = {"model_provider", "model_name", "temperature", "top_p", "top_k", "max_tokens", "callbacks", "tags", "metadata", "max_retries", "retry_base_delay"}
filtered_kwargs = {k: v for k, v in kwargs.items() if k in supported_kwargs}

super().__init__(**filtered_kwargs)
@@ -96,6 +107,7 @@ def _get_model(self) -> BaseChatModel:
msg = f"Unknown model provider: {self.model_provider}. Must be one of: anthropic, openai"
raise ValueError(msg)

@retry_on_rate_limit(max_retries=3, base_delay=45.0)
def _generate(
self,
messages: list[BaseMessage],
@@ -114,7 +126,14 @@
Returns:
ChatResult containing the generated completion
"""
return self._model._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
# Use instance-specific retry settings if provided
retry_decorator = retry_on_rate_limit(max_retries=self.max_retries, base_delay=self.retry_base_delay)

# Apply the retry decorator to the underlying model's _generate method
# This is a bit of a hack, but it allows us to use the decorator with the instance settings
generate_with_retry = retry_decorator(self._model._generate)

return generate_with_retry(messages, stop=stop, run_manager=run_manager, **kwargs)

def bind_tools(
self,
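For reference, a hedged usage sketch of the new per-instance retry knobs. The parameter names max_retries and retry_base_delay come from the diff above; the message content is illustrative:

    from langchain_core.messages import HumanMessage

    from codegen.extensions.langchain.llm import LLM

    # max_retries and retry_base_delay pass the supported_kwargs filter and are
    # picked up by _generate, which builds a per-call retry decorator from them.
    llm = LLM(
        model_provider="anthropic",
        model_name="claude-3-7-sonnet-latest",
        max_retries=5,          # up to 5 retries on RateLimitError
        retry_base_delay=30.0,  # waits of 30s, 60s, 120s, ...
    )

    result = llm.invoke([HumanMessage(content="Summarize the retry changes.")])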
75 changes: 75 additions & 0 deletions src/codegen/extensions/langchain/utils/retry.py
@@ -0,0 +1,75 @@
"""Retry utilities for handling rate limits and other transient errors."""

import asyncio
import functools
import logging
import time
from typing import Any, Callable, TypeVar, cast

import anthropic
import openai

logger = logging.getLogger(__name__)

# Type variable for the decorator
T = TypeVar("T")


def retry_on_rate_limit(max_retries: int = 3, base_delay: float = 45.0) -> Callable[[Callable[..., T]], Callable[..., T]]:
"""Decorator to retry functions on rate limit errors with exponential backoff.

Args:
max_retries: Maximum number of retry attempts
base_delay: Base delay in seconds between retries (will be multiplied by 2^retry_count)

Returns:
Decorated function with retry logic
"""

def decorator(func: Callable[..., T]) -> Callable[..., T]:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> T:
retries = 0
while True:
try:
return func(*args, **kwargs)
except (openai.RateLimitError, anthropic.RateLimitError) as e:
retries += 1
if retries > max_retries:
logger.exception(f"Rate limit exceeded after {max_retries} retries. Giving up.")
raise

# Calculate delay with exponential backoff: base_delay * 2^(retry_count-1)
delay = base_delay * (2 ** (retries - 1))
logger.warning(f"Rate limit hit. Retrying in {delay:.1f} seconds... (Attempt {retries}/{max_retries})")
time.sleep(delay)
except Exception as e:
# Re-raise other exceptions
raise
Contributor: Can we raise a rate limit error here?

Member (Author): sure


@functools.wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> T:
retries = 0
while True:
try:
return await func(*args, **kwargs)
except (openai.RateLimitError, anthropic.RateLimitError) as e:
retries += 1
if retries > max_retries:
logger.exception(f"Rate limit exceeded after {max_retries} retries. Giving up.")
raise

# Calculate delay with exponential backoff: base_delay * 2^(retry_count-1)
delay = base_delay * (2 ** (retries - 1))
logger.warning(f"Rate limit hit. Retrying in {delay:.1f} seconds... (Attempt {retries}/{max_retries})")
await asyncio.sleep(delay)
except Exception as e:
# Re-raise other exceptions
raise
Contributor: Here as well

Member (Author): sure


# Return the appropriate wrapper based on whether the function is async or not
if asyncio.iscoroutinefunction(func):
return cast(Callable[..., T], async_wrapper)
return cast(Callable[..., T], wrapper)

return decorator
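With the defaults (max_retries=3, base_delay=45.0), the backoff formula base_delay * 2^(retries - 1) yields waits of 45s, 90s, and 180s before retries 1, 2, and 3. A minimal sketch of applying the decorator outside the LLM class, using the real anthropic SDK but a hypothetical prompt:

    import anthropic

    from codegen.extensions.langchain.utils.retry import retry_on_rate_limit

    client = anthropic.Anthropic()

    @retry_on_rate_limit(max_retries=3, base_delay=45.0)
    def create_message(prompt: str):
        # An anthropic.RateLimitError raised here enters the backoff loop;
        # any other exception propagates immediately.
        return client.messages.create(
            model="claude-3-7-sonnet-latest",
            max_tokens=1024,
            messages=[{"role": "user", "content": prompt}],
        )

    @retry_on_rate_limit(max_retries=3, base_delay=45.0)
    async def create_message_async(prompt: str):
        # asyncio.iscoroutinefunction(func) is True here, so the decorator
        # returns async_wrapper and backs off with asyncio.sleep instead.
        async_client = anthropic.AsyncAnthropic()
        return await async_client.messages.create(
            model="claude-3-7-sonnet-latest",
            max_tokens=1024,
            messages=[{"role": "user", "content": prompt}],
        )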