Skip to content

Commit b888ce6

Browse files
committed
feat(evaluate): add backoff for file reading
1 parent 13f07c9 commit b888ce6

File tree

1 file changed

+30
-16
lines changed

1 file changed

+30
-16
lines changed

bigcodebench/evaluate.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from warnings import warn
1313
from gradio_client import Client, handle_file
1414

15+
import httpx
1516
import numpy as np
1617
from termcolor import cprint
1718
from tqdm import tqdm
@@ -149,22 +150,27 @@ def evaluate(
149150
result_path = samples.replace(".jsonl", "_eval_results.json")
150151

151152
if not local_execute:
152-
153-
client = Client(remote_execute_api)
154-
results, pass_at_k = client.predict(
155-
split=split,
156-
subset=subset,
157-
samples=handle_file(samples),
158-
pass_k=pass_k,
159-
parallel=parallel,
160-
min_time_limit=min_time_limit,
161-
max_as_limit=max_as_limit,
162-
max_data_limit=max_data_limit,
163-
max_stack_limit=max_stack_limit,
164-
check_gt_only=check_gt_only,
165-
no_gt=no_gt,
166-
api_name="/predict"
167-
)
153+
while True:
154+
try:
155+
client = Client(remote_execute_api)
156+
results, pass_at_k = client.predict(
157+
split=split,
158+
subset=subset,
159+
samples=handle_file(samples),
160+
pass_k=pass_k,
161+
parallel=parallel,
162+
min_time_limit=min_time_limit,
163+
max_as_limit=max_as_limit,
164+
max_data_limit=max_data_limit,
165+
max_stack_limit=max_stack_limit,
166+
check_gt_only=check_gt_only,
167+
no_gt=no_gt,
168+
api_name="/predict"
169+
)
170+
break
171+
except httpx.ReadTimeout:
172+
print("Read timeout error. Retrying in 4s...")
173+
time.sleep(4)
168174
gt_pass_rate = pass_at_k["gt_pass_rate"]
169175
failed_tasks = pass_at_k["failed_tasks"]
170176

@@ -388,3 +394,11 @@ def main():
388394

389395
if __name__ == "__main__":
390396
main()
397+
398+
def main():
399+
from fire import Fire
400+
401+
Fire(evaluate)
402+
403+
if __name__ == "__main__":
404+
main()

0 commit comments

Comments
 (0)