diff --git a/src/codegen/agents/code_agent.py b/src/codegen/agents/code_agent.py
index cb7be3ffa..cc072c0fc 100644
--- a/src/codegen/agents/code_agent.py
+++ b/src/codegen/agents/code_agent.py
@@ -16,7 +16,7 @@ class CodeAgent:
     """Agent for interacting with a codebase."""
 
-    def __init__(self, codebase: "Codebase", model_provider: str = "anthropic", model_name: str = "claude-3-5-sonnet-latest", memory: bool = True, tools: Optional[list[BaseTool]] = None, **kwargs):
+    def __init__(self, codebase: "Codebase", model_provider: str = "anthropic", model_name: str = "claude-3-7-sonnet-latest", memory: bool = True, tools: Optional[list[BaseTool]] = None, **kwargs):
         """Initialize a CodeAgent.
 
         Args:
diff --git a/src/codegen/extensions/langchain/llm.py b/src/codegen/extensions/langchain/llm.py
index 1aafce31c..dd8972a03 100644
--- a/src/codegen/extensions/langchain/llm.py
+++ b/src/codegen/extensions/langchain/llm.py
@@ -1,5 +1,6 @@
 """LLM implementation supporting both OpenAI and Anthropic models."""
 
+import logging
 import os
 from collections.abc import Sequence
 from typing import Any, Optional
@@ -15,6 +16,10 @@
 from langchain_openai import ChatOpenAI
 from pydantic import Field
 
+from codegen.extensions.langchain.utils.retry import retry_on_rate_limit
+
+logger = logging.getLogger(__name__)
+
 
 class LLM(BaseChatModel):
     """A unified chat model that supports both OpenAI and Anthropic."""
@@ -31,6 +36,10 @@ class LLM(BaseChatModel):
     max_tokens: Optional[int] = Field(default=None, description="Maximum number of tokens to generate.", ge=1)
 
+    max_retries: int = Field(default=3, description="Maximum number of retries for rate limit errors.")
+
+    retry_base_delay: float = Field(default=45.0, description="Base delay in seconds for retry backoff.")
+
     def __init__(self, model_provider: str = "anthropic", model_name: str = "claude-3-5-sonnet-latest", **kwargs: Any) -> None:
         """Initialize the LLM.
 
         Args:
@@ -42,13 +51,15 @@ def __init__(self, model_provider: str = "anthropic", model_name: str = "claude-
             - top_p: Top-p sampling parameter (0-1)
             - top_k: Top-k sampling parameter (>= 1)
             - max_tokens: Maximum number of tokens to generate
+            - max_retries: Maximum number of retries for rate limit errors
+            - retry_base_delay: Base delay in seconds for retry backoff
         """
         # Set model provider and name before calling super().__init__
         kwargs["model_provider"] = model_provider
         kwargs["model_name"] = model_name
 
         # Filter out unsupported kwargs
-        supported_kwargs = {"model_provider", "model_name", "temperature", "top_p", "top_k", "max_tokens", "callbacks", "tags", "metadata"}
+        supported_kwargs = {"model_provider", "model_name", "temperature", "top_p", "top_k", "max_tokens", "callbacks", "tags", "metadata", "max_retries", "retry_base_delay"}
 
         filtered_kwargs = {k: v for k, v in kwargs.items() if k in supported_kwargs}
         super().__init__(**filtered_kwargs)
@@ -114,7 +125,14 @@ def _generate(
         Returns:
             ChatResult containing the generated completion
         """
-        return self._model._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
+        # Build a retry decorator from the instance-specific settings
+        retry_decorator = retry_on_rate_limit(max_retries=self.max_retries, base_delay=self.retry_base_delay)
+
+        # Wrap the underlying model's _generate at call time so the retry
+        # logic honors per-instance settings rather than hardcoded defaults
+        generate_with_retry = retry_decorator(self._model._generate)
+
+        return generate_with_retry(messages, stop=stop, run_manager=run_manager, **kwargs)
 
     def bind_tools(
         self,
diff --git a/src/codegen/extensions/langchain/utils/retry.py b/src/codegen/extensions/langchain/utils/retry.py
new file mode 100644
index 000000000..5c2f5a3b9
--- /dev/null
+++ b/src/codegen/extensions/langchain/utils/retry.py
@@ -0,0 +1,69 @@
+"""Retry utilities for handling rate limits and other transient errors."""
+
+import asyncio
+import functools
+import logging
+import time
+from typing import Any, Callable, TypeVar, cast
+
+import anthropic
+import openai
+
+logger = logging.getLogger(__name__)
+
+# Type variable for the decorator
+T = TypeVar("T")
+
+
+def retry_on_rate_limit(max_retries: int = 3, base_delay: float = 45.0) -> Callable[[Callable[..., T]], Callable[..., T]]:
+    """Decorator to retry functions on rate limit errors with exponential backoff.
+
+    Args:
+        max_retries: Maximum number of retry attempts
+        base_delay: Base delay in seconds between retries (scaled by 2^(retry_count - 1))
+
+    Returns:
+        Decorated function with retry logic
+    """
+
+    def decorator(func: Callable[..., T]) -> Callable[..., T]:
+        @functools.wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> T:
+            retries = 0
+            while True:
+                try:
+                    return func(*args, **kwargs)
+                except (openai.RateLimitError, anthropic.RateLimitError):
+                    retries += 1
+                    if retries > max_retries:
+                        logger.exception(f"Rate limit exceeded after {max_retries} retries. Giving up.")
+                        raise
+
+                    # Exponential backoff: base_delay * 2^(retry_count - 1)
+                    delay = base_delay * (2 ** (retries - 1))
+                    logger.warning(f"Rate limit hit. Retrying in {delay:.1f} seconds... (Attempt {retries}/{max_retries})")
+                    time.sleep(delay)
+
+        @functools.wraps(func)
+        async def async_wrapper(*args: Any, **kwargs: Any) -> T:
+            retries = 0
+            while True:
+                try:
+                    return await func(*args, **kwargs)
+                except (openai.RateLimitError, anthropic.RateLimitError):
+                    retries += 1
+                    if retries > max_retries:
+                        logger.exception(f"Rate limit exceeded after {max_retries} retries. Giving up.")
+                        raise
+
+                    # Exponential backoff: base_delay * 2^(retry_count - 1)
+                    delay = base_delay * (2 ** (retries - 1))
+                    logger.warning(f"Rate limit hit. Retrying in {delay:.1f} seconds... (Attempt {retries}/{max_retries})")
+                    await asyncio.sleep(delay)
+
+        # Return the appropriate wrapper based on whether the function is async or not
+        if asyncio.iscoroutinefunction(func):
+            return cast(Callable[..., T], async_wrapper)
+        return cast(Callable[..., T], wrapper)
+
+    return decorator
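
Usage sketch (illustration only, not part of the diff). The summarize helper below is hypothetical; the import paths assume the src/ layout above maps to the importable codegen package, and invoke/.content come from LangChain's BaseChatModel interface. With the defaults (max_retries=3, base_delay=45.0) the backoff waits are 45s, 90s, and 180s before the last error is re-raised.

    from codegen.extensions.langchain.llm import LLM
    from codegen.extensions.langchain.utils.retry import retry_on_rate_limit

    # Standalone use: retry any callable that may raise a provider
    # RateLimitError; async functions get the async wrapper automatically.
    @retry_on_rate_limit(max_retries=2, base_delay=10.0)
    def summarize(llm: LLM, text: str) -> str:  # hypothetical helper
        return llm.invoke(f"Summarize: {text}").content

    # Per-instance tuning: _generate wraps the underlying model's call
    # with these settings via the fields added in this diff.
    llm = LLM(model_provider="anthropic", model_name="claude-3-7-sonnet-latest", max_retries=5, retry_base_delay=30.0)
    print(summarize(llm, "Rate limits are retried with exponential backoff."))

Re-raising after the final attempt, rather than swallowing the error, leaves callers free to layer their own fallback on top.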