diff --git a/airtbench/main.py b/airtbench/main.py
index ea269df..d7e6f30 100644
--- a/airtbench/main.py
+++ b/airtbench/main.py
@@ -202,6 +202,8 @@ async def run_step(
     challenge: Challenge,
     pipeline: rg.ChatPipeline,
     kernel: PythonKernel,
+    generator: rg.Generator | None = None,
+    backoff_wrapper=None,
 ) -> rg.ChatPipeline | None:
     # If we are limiting the model to a single code
     # execution entry per step, we can safely stop
@@ -247,6 +249,31 @@ async def run_step(
         dn.log_metric("max_tokens", 1)
         return None
 
+    # Handle caching-related errors by disabling cache and retrying
+    if "cache_control" in str(chat.error) and args.enable_cache and generator is not None:
+        logger.warning(f"|- Caching not supported by provider, disabling cache and retrying: {chat.error}")
+        dn.log_metric("cache_unsupported", 1)
+        # Create new pipeline without caching
+        retry_pipeline = (
+            generator.wrap(backoff_wrapper)
+            .chat(pipeline.chat.messages)
+            .cache(False)
+        )
+        try:
+            retry_chat = await retry_pipeline.catch(
+                litellm.exceptions.InternalServerError,
+                litellm.exceptions.BadRequestError,
+                litellm.exceptions.Timeout,
+                litellm.exceptions.ServiceUnavailableError,
+                litellm.exceptions.APIConnectionError,
+                on_failed="include",
+            ).run()
+            if not retry_chat.failed:
+                logger.info("|- Successfully retried without cache")
+                return retry_pipeline
+        except Exception as e:
+            logger.warning(f"|- Retry without cache also failed: {e}")
+
     logger.warning(f"|- Chat failed: {chat.error}")
     dn.log_metric("failed_chats", 1)
     pipeline.chat.generated = []
@@ -645,6 +672,8 @@ def on_backoff(details: backoff.types.Details) -> None:
                 challenge,
                 pipeline,
                 kernel,
+                generator,
+                backoff_wrapper,
             )
         else:
             logger.warning("|- Max steps reached")
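
Reviewer note: the new branch is an instance of a generic "disable the unsupported feature and retry once" pattern: detect a provider error specific to the feature (here, prompt caching via "cache_control"), rebuild the request without it, and try again. Below is a minimal, self-contained sketch of that pattern for reference. send_chat and ProviderError are hypothetical stand-ins for illustration, not rigging or litellm APIs; the actual change relies on the pipeline calls shown in the diff (generator.wrap(...).chat(...).cache(False)).

    # Sketch of the fallback-and-retry pattern implemented by the diff above.
    # All names here are hypothetical stand-ins for illustration.

    class ProviderError(Exception):
        """Hypothetical provider-side error carrying the provider's message."""

    async def send_with_cache_fallback(send_chat, messages, *, use_cache: bool = True):
        """Try once with caching enabled; on a cache-related error, retry without it."""
        try:
            return await send_chat(messages, cache=use_cache)
        except ProviderError as exc:
            # Fall back only when the failure is specifically cache-related,
            # mirroring the `"cache_control" in str(chat.error)` check above.
            if use_cache and "cache_control" in str(exc):
                return await send_chat(messages, cache=False)
            raise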