Skip to content

Commit 150dcb9

Browse files
authored
Unified create_retrying for all solvers (#1501)
We're now implementing solvers for new APIs we're calling (Anthropic, Gemini, ...). Each solver was implementing the same logic for backing off and retrying when the API query limit was hit. This PR created a generic create_retrying function, which retries when specific exceptions are raised. These exceptions are passed as arguments. This uses the changes from #1482
1 parent ac44aae commit 150dcb9

File tree

3 files changed

+59
-79
lines changed

3 files changed

+59
-79
lines changed

evals/completion_fns/openai.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import logging
12
from typing import Any, Optional, Union
23

4+
import openai
35
from openai import OpenAI
46

57
from evals.api import CompletionFn, CompletionResult
@@ -12,12 +14,44 @@
1214
Prompt,
1315
)
1416
from evals.record import record_sampling
15-
from evals.utils.api_utils import (
16-
openai_chat_completion_create_retrying,
17-
openai_completion_create_retrying,
17+
from evals.utils.api_utils import create_retrying
18+
19+
OPENAI_TIMEOUT_EXCEPTIONS = (
20+
openai.RateLimitError,
21+
openai.APIConnectionError,
22+
openai.APITimeoutError,
23+
openai.InternalServerError,
1824
)
1925

2026

27+
def openai_completion_create_retrying(client: OpenAI, *args, **kwargs):
28+
"""
29+
Helper function for creating a completion.
30+
`args` and `kwargs` match what is accepted by `openai.Completion.create`.
31+
"""
32+
result = create_retrying(
33+
client.completions.create, retry_exceptions=OPENAI_TIMEOUT_EXCEPTIONS, *args, **kwargs
34+
)
35+
if "error" in result:
36+
logging.warning(result)
37+
raise openai.APIError(result["error"])
38+
return result
39+
40+
41+
def openai_chat_completion_create_retrying(client: OpenAI, *args, **kwargs):
42+
"""
43+
Helper function for creating a completion.
44+
`args` and `kwargs` match what is accepted by `openai.Completion.create`.
45+
"""
46+
result = create_retrying(
47+
client.chat.completions.create, retry_exceptions=OPENAI_TIMEOUT_EXCEPTIONS, *args, **kwargs
48+
)
49+
if "error" in result:
50+
logging.warning(result)
51+
raise openai.APIError(result["error"])
52+
return result
53+
54+
2155
class OpenAIBaseCompletionResult(CompletionResult):
2256
def __init__(self, raw_data: Any, prompt: Any):
2357
self.raw_data = raw_data

evals/solvers/providers/anthropic/anthropic_solver.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,25 @@
11
from typing import Any, Optional, Union
22

3-
from evals.solvers.solver import Solver, SolverResult
4-
from evals.task_state import TaskState, Message
5-
from evals.record import record_sampling
6-
from evals.utils.api_utils import request_with_timeout
7-
83
import anthropic
94
from anthropic import Anthropic
105
from anthropic.types import ContentBlock, MessageParam, Usage
11-
import backoff
6+
7+
from evals.record import record_sampling
8+
from evals.solvers.solver import Solver, SolverResult
9+
from evals.task_state import Message, TaskState
10+
from evals.utils.api_utils import create_retrying
1211

1312
oai_to_anthropic_role = {
1413
"system": "user",
1514
"user": "user",
1615
"assistant": "assistant",
1716
}
17+
ANTHROPIC_TIMEOUT_EXCEPTIONS = (
18+
anthropic.RateLimitError,
19+
anthropic.APIConnectionError,
20+
anthropic.APITimeoutError,
21+
anthropic.InternalServerError,
22+
)
1823

1924

2025
class AnthropicSolver(Solver):
@@ -59,9 +64,7 @@ def _solve(self, task_state: TaskState, **kwargs) -> SolverResult:
5964
)
6065

6166
# for logging purposes: prepend the task desc to the orig msgs as a system message
62-
orig_msgs.insert(
63-
0, Message(role="system", content=task_state.task_description).to_dict()
64-
)
67+
orig_msgs.insert(0, Message(role="system", content=task_state.task_description).to_dict())
6568
record_sampling(
6669
prompt=orig_msgs, # original message format, supported by our logviz
6770
sampled=[solver_result.output],
@@ -113,23 +116,14 @@ def _convert_msgs_to_anthropic_format(msgs: list[Message]) -> list[MessageParam]
113116
return alt_msgs
114117

115118

116-
@backoff.on_exception(
117-
wait_gen=backoff.expo,
118-
exception=(
119-
anthropic.RateLimitError,
120-
anthropic.APIConnectionError,
121-
anthropic.APITimeoutError,
122-
anthropic.InternalServerError,
123-
),
124-
max_value=60,
125-
factor=1.5,
126-
)
127119
def anthropic_create_retrying(client: Anthropic, *args, **kwargs):
128120
"""
129121
Helper function for creating a backoff-retry enabled message request.
130122
`args` and `kwargs` match what is accepted by `client.messages.create`.
131123
"""
132-
result = request_with_timeout(client.messages.create, *args, **kwargs)
124+
result = create_retrying(
125+
client.messages.create, retry_exceptions=ANTHROPIC_TIMEOUT_EXCEPTIONS, *args, **kwargs
126+
)
133127
if "error" in result:
134128
raise Exception(result["error"])
135129
return result

evals/utils/api_utils.py

Lines changed: 7 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,22 @@
1-
"""
2-
This file defines various helper functions for interacting with the OpenAI API.
3-
"""
41
import logging
52
import os
63

74
import backoff
8-
import openai
9-
from openai import OpenAI
105

116
EVALS_THREAD_TIMEOUT = float(os.environ.get("EVALS_THREAD_TIMEOUT", "40"))
127
logging.getLogger("httpx").setLevel(logging.WARNING) # suppress "OK" logs from openai API calls
138

149

15-
@backoff.on_exception(
10+
@backoff.on_predicate(
1611
wait_gen=backoff.expo,
17-
exception=(
18-
openai.RateLimitError,
19-
openai.APIConnectionError,
20-
openai.APITimeoutError,
21-
openai.InternalServerError,
22-
),
2312
max_value=60,
2413
factor=1.5,
2514
)
26-
def openai_completion_create_retrying(client: OpenAI, *args, **kwargs):
15+
def create_retrying(func: callable, retry_exceptions: tuple[Exception], *args, **kwargs):
2716
"""
28-
Helper function for creating a completion.
29-
`args` and `kwargs` match what is accepted by `openai.Completion.create`.
17+
Retries given function if one of given exceptions is raised
3018
"""
31-
result = client.completions.create(*args, **kwargs)
32-
if "error" in result:
33-
logging.warning(result)
34-
raise openai.error.APIError(result["error"])
35-
return result
36-
37-
38-
def request_with_timeout(func, *args, timeout=EVALS_THREAD_TIMEOUT, **kwargs):
39-
"""
40-
Function for making a single request within allotted time.
41-
"""
42-
while True:
43-
try:
44-
result = func(*args, timeout=timeout, **kwargs)
45-
return result
46-
except openai.APITimeoutError as e:
47-
continue
48-
49-
50-
@backoff.on_exception(
51-
wait_gen=backoff.expo,
52-
exception=(
53-
openai.RateLimitError,
54-
openai.APIConnectionError,
55-
openai.APITimeoutError,
56-
openai.InternalServerError,
57-
),
58-
max_value=60,
59-
factor=1.5,
60-
)
61-
def openai_chat_completion_create_retrying(client: OpenAI, *args, **kwargs):
62-
"""
63-
Helper function for creating a chat completion.
64-
`args` and `kwargs` match what is accepted by `openai.ChatCompletion.create`.
65-
"""
66-
result = request_with_timeout(client.chat.completions.create, *args, **kwargs)
67-
if "error" in result:
68-
logging.warning(result)
69-
raise openai.error.APIError(result["error"])
70-
return result
19+
try:
20+
return func(*args, **kwargs)
21+
except retry_exceptions:
22+
return False

0 commit comments

Comments
 (0)