
Commit 36e9e2d

Optimize evaluation performance using AsyncOpenAI for concurrent processing
1 parent 946d6c4 commit 36e9e2d

File tree

evaluation/evals.py
server/api/services/llm_services.py

2 files changed: +24 -18 lines changed

evaluation/evals.py

Lines changed: 16 additions & 10 deletions

@@ -23,6 +23,7 @@
 
 import argparse
 import logging
+import asyncio
 
 import pandas as pd
 
@@ -37,15 +38,15 @@
 )
 
 
-def evaluate_response(model: str, instructions: str, input: str) -> pd.DataFrame:
+async def evaluate_response(model: str, instructions: str, input: str) -> pd.DataFrame:
     """
     Test a prompt with a set of test data by scoring each item in the data set
     """
 
     try:
         handler = ModelFactory.get_handler(model)
 
-        generated_text, token_usage, pricing, duration = handler.handle_request(
+        generated_text, token_usage, pricing, duration = await handler.handle_request(
            instructions, input
        )
 
@@ -145,7 +146,7 @@ def load_csv(file_path: str, required_columns: list) -> pd.DataFrame:
     return df
 
 
-if __name__ == "__main__":
+async def main():
     # TODO: Add test evaluation argument to run on the first 10 rows of the dataset file
 
     parser = argparse.ArgumentParser()
@@ -177,14 +178,15 @@ def load_csv(file_path: str, required_columns: list) -> pd.DataFrame:
     # Bulk model and prompt experimentation: Cross join the experiment and dataset DataFrames
     df_in = df_experiment.merge(df_dataset, how="cross")
 
-    # Evaluate each row in the input DataFrame
-    results = []
-    for index, row in enumerate(df_in.itertuples(index=False)):
-        result = evaluate_response(row.MODEL, row.INSTRUCTIONS, row.INPUT)
-        results.append(result)
+    # Evaluate each row in the input DataFrame concurrently
+    logging.info(f"Starting evaluation of {len(df_in)} rows")
+    tasks = [
+        evaluate_response(row.MODEL, row.INSTRUCTIONS, row.INPUT)
+        for row in df_in.itertuples(index=False)
+    ]
 
-        # TODO: Use tqdm or similar library to show progress bar
-        logging.info(f"Processed row {index + 1}/{len(df_in)}")
+    results = await asyncio.gather(*tasks)
+    logging.info(f"Completed evaluation of {len(results)} rows")
 
     df_evals = pd.concat(results, axis=0, ignore_index=True)
 
@@ -195,3 +197,7 @@ def load_csv(file_path: str, required_columns: list) -> pd.DataFrame:
     df_out.to_csv(args.results, index=False)
     logging.info(f"Results saved to {args.results}")
     logging.info("Evaluation completed successfully.")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
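
Note: the committed change fans out one evaluate_response coroutine per row and awaits them all with a single asyncio.gather, so a large dataset means that many simultaneous API calls in flight. A common refinement is to cap concurrency with a semaphore. A minimal sketch of that variant follows; the MAX_CONCURRENCY constant, the bounded wrapper, and the stand-in evaluate_response are illustrative additions, not part of this commit:

import asyncio

MAX_CONCURRENCY = 10  # illustrative cap, not part of this commit


async def evaluate_response(model: str, instructions: str, input: str) -> dict:
    # Stand-in for the real coroutine in evals.py; sleeps to simulate an API call.
    await asyncio.sleep(0.1)
    return {"MODEL": model, "SCORE": 1.0}


async def bounded(sem: asyncio.Semaphore, coro):
    # Hold a semaphore slot for the duration of one evaluation.
    async with sem:
        return await coro


async def main() -> None:
    rows = [("gpt-4o-mini", "Grade the answer.", f"item {i}") for i in range(25)]
    sem = asyncio.Semaphore(MAX_CONCURRENCY)
    tasks = [
        bounded(sem, evaluate_response(model, instructions, input))
        for model, instructions, input in rows
    ]
    results = await asyncio.gather(*tasks)  # same gather as the commit, now capped
    print(f"Completed {len(results)} evaluations")


if __name__ == "__main__":
    asyncio.run(main())

The gather call itself is unchanged; each task simply waits for a semaphore slot before issuing its request, which keeps bursts within API rate limits.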

server/api/services/llm_services.py

Lines changed: 8 additions & 8 deletions

@@ -7,12 +7,12 @@
 import logging
 from abc import ABC, abstractmethod
 
-import openai
+from openai import AsyncOpenAI
 
 
 class BaseModelHandler(ABC):
     @abstractmethod
-    def handle_request(
+    async def handle_request(
         self, query: str, context: str
     ) -> tuple[str, dict[str, int], dict[str, float], float]:
         pass
@@ -31,9 +31,9 @@ class GPT4OMiniHandler(BaseModelHandler):
     PRICING_DOLLARS_PER_MILLION_TOKENS = {"input": 0.15, "output": 0.60}
 
     def __init__(self) -> None:
-        self.client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
+        self.client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
 
-    def handle_request(
+    async def handle_request(
         self, query: str, context: str
     ) -> tuple[str, dict[str, int], dict[str, float], float]:
         """
@@ -46,7 +46,7 @@ def handle_request(
         """
         start_time = time.time()
         # TODO: Add error handling for API requests and invalid responses
-        response = self.client.responses.create(
+        response = await self.client.responses.create(
             model=self.MODEL, instructions=query, input=context, temperature=0.0
         )
         duration = time.time() - start_time
@@ -123,9 +123,9 @@ class GPT41NanoHandler(BaseModelHandler):
     """
 
     def __init__(self) -> None:
-        self.client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
+        self.client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
 
-    def handle_request(
+    async def handle_request(
         self, query: str, context: str
     ) -> tuple[str, dict[str, int], dict[str, float], float]:
         """
@@ -144,7 +144,7 @@ def handle_request(
         start_time = time.time()
         # TODO: Add error handling for API requests and invalid responses
 
-        response = self.client.responses.create(
+        response = await self.client.responses.create(
             model=self.MODEL, instructions=query, input=context, temperature=0.0
         )
         duration = time.time() - start_time
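
Note: on the handler side, switching from openai.OpenAI to AsyncOpenAI is what makes responses.create awaitable, letting the evaluation script overlap requests. A self-contained sketch of that request path follows, assuming an openai-python version with the async Responses API client (roughly 1.66+); the pricing computation from the real handlers is omitted, and the standalone handle_request function and usage field names are stated as assumptions rather than taken from this diff:

import asyncio
import os
import time

from openai import AsyncOpenAI


async def handle_request(query: str, context: str) -> tuple[str, dict[str, int], float]:
    # Mirrors the committed handlers: one awaited Responses API call, timed.
    client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
    start_time = time.time()
    response = await client.responses.create(
        model="gpt-4o-mini", instructions=query, input=context, temperature=0.0
    )
    duration = time.time() - start_time
    token_usage = {
        "input": response.usage.input_tokens,   # assumed Responses API usage fields
        "output": response.usage.output_tokens,
    }
    return response.output_text, token_usage, duration


if __name__ == "__main__":
    text, usage, secs = asyncio.run(
        handle_request("Summarize the context in one line.", "AsyncOpenAI smoke test")
    )
    print(text, usage, f"{secs:.2f}s")

Unlike this sketch, the committed handlers construct the client once in __init__ and reuse it across calls, which is the cheaper pattern when many concurrent requests share one event loop.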
