diff --git a/codegen-examples/examples/swebench_agent_run/run_eval.py b/codegen-examples/examples/swebench_agent_run/run_eval.py
index b39cede6f..ff349266c 100644
--- a/codegen-examples/examples/swebench_agent_run/run_eval.py
+++ b/codegen-examples/examples/swebench_agent_run/run_eval.py
@@ -17,93 +17,106 @@
 run_agent_modal = modal.Function.from_name(app_name="swebench-agent-run", name="run_agent_modal")
 
 
-async def process_batch_modal(examples: list[SweBenchExample], batch_size=10):
-    """Process a batch of examples concurrently.
+async def process_batch_modal(examples: list[SweBenchExample], num_workers=10, max_retries=3):
+    """Process a batch of examples concurrently using a queue system.
 
     Args:
         examples: List of SweBenchExample objects to process
-        batch_size: Number of examples to process concurrently.
-            Default is 50 which provides good parallelization
-            while staying well within Modal's limits.
+        num_workers: Number of examples to process concurrently
+        max_retries: Maximum number of retries for failed requests
     """
-    results = []
-
-    # Process examples in batches
-    for i in range(0, len(examples), batch_size):
-        batch = examples[i : i + batch_size]
+    results = {}
+    queue = asyncio.Queue()
 
-        # Create tasks for this batch
-        batch_tasks = [run_agent_modal.remote.aio(example) for example in batch]
-
-        # Wait for all tasks in this batch to complete
-        print(f"Processing batch {i // batch_size + 1}/{len(examples) // batch_size + 1} (examples {i + 1}-{min(i + batch_size, len(examples))})")
+    # Initialize the queue with (example, attempt) tuples
+    for example in examples:
+        await queue.put((example, 0))  # 0 represents first attempt
 
+    async def process_example(example, attempt):
         try:
-            batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
-
-            # Store results
-            for example, result in zip(batch, batch_results):
-                error_info = None
-
-                if isinstance(result, Exception):
-                    error_type = type(result).__name__
-                    error_info = {
-                        "error_type": error_type,
-                        "error_message": str(result),
-                        "traceback": traceback.format_exception(type(result), result, result.__traceback__),
-                    }
-
-                    if isinstance(result, modal.exception.Error):
-                        error_info["modal_error_code"] = getattr(result, "code", None)
-                        error_info["modal_error_details"] = getattr(result, "details", None)
-
-                    print(f"Error processing {example.instance_id}:")
-                    print(f"Type: {error_type}")
-                    print(f"Message: {str(result)}")
-                    print("Traceback:")
-                    print("".join(error_info["traceback"]))
-
-                    results.append({"instance_id": example.instance_id, "status": "error", "error_info": error_info})
-                else:
-                    if result is None:
-                        print(f"Warning: Null result for {example.instance_id}")
-                        results.append({"instance_id": example.instance_id, "status": "error", "error_info": {"error_type": "NullResult", "error_message": "Process returned None"}})
-                    else:
-                        results.append(result)
+            result = await run_agent_modal.remote.aio(example)
+
+            if result is None:
+                print(f"Warning: Null result for {example.instance_id}")
+                return {"instance_id": example.instance_id, "status": "error", "error_info": {"error_type": "NullResult", "error_message": "Process returned None"}}
 
+            return result
         except Exception as e:
-            print("Batch processing error:")
-            print(f"Type: {type(e).__name__}")
+            error_type = type(e).__name__
+            error_info = {
+                "error_type": error_type,
+                "error_message": str(e),
+                "traceback": traceback.format_exception(type(e), e, e.__traceback__),
+            }
+
+            if isinstance(e, modal.exception.Error):
+                error_info["modal_error_code"] = getattr(e, "code", None)
+ error_info["modal_error_details"] = getattr(e, "details", None) + + print(f"Error processing {example.instance_id} (attempt {attempt + 1}):") + print(f"Type: {error_type}") print(f"Message: {str(e)}") - traceback.print_exc() - - # Mark all examples in the batch as failed - for example in batch: - results.append( - { - "instance_id": example.instance_id, - "status": "error", - "error_info": {"error_type": type(e).__name__, "error_message": str(e), "traceback": traceback.format_exc(), "batch_failure": True}, - } - ) + print("Traceback:") + print("".join(error_info["traceback"])) - return results + if attempt < max_retries: + await queue.put((example, attempt + 1)) + return None + + return {"instance_id": example.instance_id, "status": "error", "error_info": error_info} + + async def worker(): + while True: + try: + example, attempt = await queue.get() + + if example.instance_id in results: + queue.task_done() + continue + + result = await process_example(example, attempt) + + if result is not None: + results[example.instance_id] = result + + queue.task_done() + + except Exception as e: + print(f"Worker error: {str(e)}") + traceback.print_exc() + queue.task_done() + + # Start workers + workers = [asyncio.create_task(worker()) for _ in range(num_workers)] + + # Wait for queue to be fully processed + await queue.join() + + # Cancel workers + for w in workers: + w.cancel() + + # Wait for all workers to be cancelled + await asyncio.gather(*workers, return_exceptions=True) + + # Return results in the same order as input examples + return [results[example.instance_id] for example in examples] -def process_batch_local(examples: list[SweBenchExample], batch_size=10, codebases: dict[str, Codebase] = {}): +def process_batch_local(examples: list[SweBenchExample], num_workers=5, codebases: dict[str, Codebase] = {}): """Process a batch of examples synchronously. Args: examples: List of SweBenchExample objects to process - batch_size: Number of examples to process in each batch. + num_workers: Number of examples to process in each batch. Default is 10 to avoid overwhelming the system. 
""" results = [] # Process examples in batches - for i in range(0, len(examples), batch_size): - batch = examples[i : i + batch_size] - print(f"Processing batch {i // batch_size + 1}/{len(examples) // batch_size + 1} (examples {i + 1}-{min(i + batch_size, len(examples))})") + for i in range(0, len(examples), num_workers): + batch = examples[i : i + num_workers] + print(f"Processing batch {i // num_workers + 1}/{len(examples) // num_workers + 1} (examples {i + 1}-{min(i + num_workers, len(examples))})") # Process each example in the batch for example in batch: @@ -134,7 +147,9 @@ def process_batch_local(examples: list[SweBenchExample], batch_size=10, codebase return results -async def run_eval(use_existing_preds: str | None, dataset: str, length: int, instance_id: str | None = None, local: bool = False, codebases: dict[str, Codebase] = {}, repo: str | None = None): +async def run_eval( + use_existing_preds: str | None, dataset: str, length: int, instance_id: str | None = None, local: bool = False, codebases: dict[str, Codebase] = {}, repo: str | None = None, num_workers: int = 5 +): run_id = use_existing_preds or str(uuid.uuid4()) print(f"Run ID: {run_id}") predictions_dir = PREDS_DNAME / f"results_{run_id}" @@ -162,7 +177,7 @@ async def run_eval(use_existing_preds: str | None, dataset: str, length: int, in if local: results = process_batch_local(examples, codebases=codebases) else: - results = await process_batch_modal(examples) + results = await process_batch_modal(examples, num_workers=num_workers) # Save individual results for result in results: @@ -218,9 +233,12 @@ async def run_eval(use_existing_preds: str | None, dataset: str, length: int, in @click.option("--instance-id", help="The instance ID of the example to process.", type=str, default=None) @click.option("--local", help="Run the evaluation locally.", is_flag=True, default=False) @click.option("--repo", help="The repo to use.", type=str, default=None) -def run_eval_command(use_existing_preds, dataset, length, instance_id, local, repo): +@click.option( + "--num-workers", help="The number of workers to use. This is the number of examples that will be processed concurrently. A large number may lead to rate limiting issues.", type=int, default=5 +) +def run_eval_command(use_existing_preds, dataset, length, instance_id, local, repo, num_workers): print(f"Repo: {repo}") - asyncio.run(run_eval(use_existing_preds=use_existing_preds, dataset=dataset, length=length, instance_id=instance_id, codebases=None, local=local, repo=repo)) + asyncio.run(run_eval(use_existing_preds=use_existing_preds, dataset=dataset, length=length, instance_id=instance_id, codebases=None, local=local, repo=repo, num_workers=num_workers)) if __name__ == "__main__":