Skip to content

Commit fface44

Browse files
feat: add cache retry logic for unsupported providers (#40)
* feat: add cache retry logic for unsupported providers
* chore: modify run_step() function signature to accept generator and backoff_wrapper parameters
1 parent a10bf46 commit fface44

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

airtbench/main.py

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -202,6 +202,8 @@ async def run_step(
202202
challenge: Challenge,
203203
pipeline: rg.ChatPipeline,
204204
kernel: PythonKernel,
205+
generator: rg.Generator = None,
206+
backoff_wrapper=None,
205207
) -> rg.ChatPipeline | None:
206208
# If we are limiting the model to a single code
207209
# execution entry per step, we can safely stop
@@ -247,6 +249,31 @@ async def run_step(
247249
dn.log_metric("max_tokens", 1)
248250
return None
249251

252+
# Handle caching-related errors by disabling cache and retrying
253+
if "cache_control" in str(chat.error) and args.enable_cache:
254+
logger.warning(f"|- Caching not supported by provider, disabling cache and retrying: {chat.error}")
255+
dn.log_metric("cache_unsupported", 1)
256+
# Create new pipeline without caching
257+
retry_pipeline = (
258+
generator.wrap(backoff_wrapper)
259+
.chat(pipeline.chat.messages)
260+
.cache(False)
261+
)
262+
try:
263+
retry_chat = await retry_pipeline.catch(
264+
litellm.exceptions.InternalServerError,
265+
litellm.exceptions.BadRequestError,
266+
litellm.exceptions.Timeout,
267+
litellm.exceptions.ServiceUnavailableError,
268+
litellm.exceptions.APIConnectionError,
269+
on_failed="include",
270+
).run()
271+
if not retry_chat.failed:
272+
logger.info("|- Successfully retried without cache")
273+
return retry_pipeline
274+
except Exception as e:
275+
logger.warning(f"|- Retry without cache also failed: {e}")
276+
250277
logger.warning(f"|- Chat failed: {chat.error}")
251278
dn.log_metric("failed_chats", 1)
252279
pipeline.chat.generated = []
@@ -645,6 +672,8 @@ def on_backoff(details: backoff.types.Details) -> None:
645672
challenge,
646673
pipeline,
647674
kernel,
675+
generator,
676+
backoff_wrapper,
648677
)
649678
else:
650679
logger.warning("|- Max steps reached")

0 commit comments

Comments
 (0)