|
12 | 12 | from warnings import warn |
13 | 13 | from gradio_client import Client, handle_file |
14 | 14 |
|
| 15 | +import httpx |
15 | 16 | import numpy as np |
16 | 17 | from termcolor import cprint |
17 | 18 | from tqdm import tqdm |
@@ -149,22 +150,27 @@ def evaluate( |
149 | 150 | result_path = samples.replace(".jsonl", "_eval_results.json") |
150 | 151 |
|
151 | 152 | if not local_execute: |
152 | | - |
153 | | - client = Client(remote_execute_api) |
154 | | - results, pass_at_k = client.predict( |
155 | | - split=split, |
156 | | - subset=subset, |
157 | | - samples=handle_file(samples), |
158 | | - pass_k=pass_k, |
159 | | - parallel=parallel, |
160 | | - min_time_limit=min_time_limit, |
161 | | - max_as_limit=max_as_limit, |
162 | | - max_data_limit=max_data_limit, |
163 | | - max_stack_limit=max_stack_limit, |
164 | | - check_gt_only=check_gt_only, |
165 | | - no_gt=no_gt, |
166 | | - api_name="/predict" |
167 | | - ) |
| 153 | + while True: |
| 154 | + try: |
| 155 | + client = Client(remote_execute_api) |
| 156 | + results, pass_at_k = client.predict( |
| 157 | + split=split, |
| 158 | + subset=subset, |
| 159 | + samples=handle_file(samples), |
| 160 | + pass_k=pass_k, |
| 161 | + parallel=parallel, |
| 162 | + min_time_limit=min_time_limit, |
| 163 | + max_as_limit=max_as_limit, |
| 164 | + max_data_limit=max_data_limit, |
| 165 | + max_stack_limit=max_stack_limit, |
| 166 | + check_gt_only=check_gt_only, |
| 167 | + no_gt=no_gt, |
| 168 | + api_name="/predict" |
| 169 | + ) |
| 170 | + break |
| 171 | + except httpx.ReadTimeout: |
| 172 | + print("Read timeout error. Retrying in 4s...") |
| 173 | + time.sleep(4) |
168 | 174 | gt_pass_rate = pass_at_k["gt_pass_rate"] |
169 | 175 | failed_tasks = pass_at_k["failed_tasks"] |
170 | 176 |
|
@@ -388,3 +394,11 @@ def main(): |
388 | 394 |
|
389 | 395 | if __name__ == "__main__": |
390 | 396 | main() |
| 397 | + |
| 398 | +def main(): |
| 399 | + from fire import Fire |
| 400 | + |
| 401 | + Fire(evaluate) |
| 402 | + |
| 403 | +if __name__ == "__main__": |
| 404 | + main() |
0 commit comments