156 changes: 87 additions & 69 deletions codegen-examples/examples/swebench_agent_run/run_eval.py
@@ -17,93 +17,106 @@
run_agent_modal = modal.Function.from_name(app_name="swebench-agent-run", name="run_agent_modal")


async def process_batch_modal(examples: list[SweBenchExample], batch_size=10):
    """Process a batch of examples concurrently.
async def process_batch_modal(examples: list[SweBenchExample], num_workers=10, max_retries=3):
    """Process a batch of examples concurrently using a queue system.

    Args:
        examples: List of SweBenchExample objects to process
        batch_size: Number of examples to process concurrently.
            Default is 50 which provides good parallelization
            while staying well within Modal's limits.
        num_workers: Number of examples to process concurrently
        max_retries: Maximum number of retries for failed requests
    """
    results = []

    # Process examples in batches
    for i in range(0, len(examples), batch_size):
        batch = examples[i : i + batch_size]
    results = {}
    queue = asyncio.Queue()

        # Create tasks for this batch
        batch_tasks = [run_agent_modal.remote.aio(example) for example in batch]

        # Wait for all tasks in this batch to complete
        print(f"Processing batch {i // batch_size + 1}/{len(examples) // batch_size + 1} (examples {i + 1}-{min(i + batch_size, len(examples))})")
    # Initialize the queue with (example, attempt) tuples
    for example in examples:
        await queue.put((example, 0))  # 0 represents first attempt

    async def process_example(example, attempt):
        try:
            batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)

            # Store results
            for example, result in zip(batch, batch_results):
                error_info = None

                if isinstance(result, Exception):
                    error_type = type(result).__name__
                    error_info = {
                        "error_type": error_type,
                        "error_message": str(result),
                        "traceback": traceback.format_exception(type(result), result, result.__traceback__),
                    }

                    if isinstance(result, modal.exception.Error):
                        error_info["modal_error_code"] = getattr(result, "code", None)
                        error_info["modal_error_details"] = getattr(result, "details", None)

                    print(f"Error processing {example.instance_id}:")
                    print(f"Type: {error_type}")
                    print(f"Message: {str(result)}")
                    print("Traceback:")
                    print("".join(error_info["traceback"]))

                    results.append({"instance_id": example.instance_id, "status": "error", "error_info": error_info})
                else:
                    if result is None:
                        print(f"Warning: Null result for {example.instance_id}")
                        results.append({"instance_id": example.instance_id, "status": "error", "error_info": {"error_type": "NullResult", "error_message": "Process returned None"}})
                    else:
                        results.append(result)
            result = await run_agent_modal.remote.aio(example)

            if result is None:
                print(f"Warning: Null result for {example.instance_id}")
                return {"instance_id": example.instance_id, "status": "error", "error_info": {"error_type": "NullResult", "error_message": "Process returned None"}}
            return result

        except Exception as e:
            print("Batch processing error:")
            print(f"Type: {type(e).__name__}")
            error_type = type(e).__name__
            error_info = {
                "error_type": error_type,
                "error_message": str(e),
                "traceback": traceback.format_exception(type(e), e, e.__traceback__),
            }

            if isinstance(e, modal.exception.Error):
                error_info["modal_error_code"] = getattr(e, "code", None)
                error_info["modal_error_details"] = getattr(e, "details", None)

            print(f"Error processing {example.instance_id} (attempt {attempt + 1}):")
            print(f"Type: {error_type}")
            print(f"Message: {str(e)}")
            traceback.print_exc()

            # Mark all examples in the batch as failed
            for example in batch:
                results.append(
                    {
                        "instance_id": example.instance_id,
                        "status": "error",
                        "error_info": {"error_type": type(e).__name__, "error_message": str(e), "traceback": traceback.format_exc(), "batch_failure": True},
                    }
                )
            print("Traceback:")
            print("".join(error_info["traceback"]))

    return results
            if attempt < max_retries:
                await queue.put((example, attempt + 1))
                return None

            return {"instance_id": example.instance_id, "status": "error", "error_info": error_info}

    async def worker():
        while True:
            try:
                example, attempt = await queue.get()

                if example.instance_id in results:
                    queue.task_done()
                    continue

                result = await process_example(example, attempt)

                if result is not None:
                    results[example.instance_id] = result

                queue.task_done()

            except Exception as e:
                print(f"Worker error: {str(e)}")
                traceback.print_exc()
                queue.task_done()

    # Start workers
    workers = [asyncio.create_task(worker()) for _ in range(num_workers)]

    # Wait for queue to be fully processed
    await queue.join()

    # Cancel workers
    for w in workers:
        w.cancel()

    # Wait for all workers to be cancelled
    await asyncio.gather(*workers, return_exceptions=True)

    # Return results in the same order as input examples
    return [results[example.instance_id] for example in examples]
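
For reference, a minimal self-contained sketch of the queue-plus-workers retry pattern used above, with a hypothetical process() coroutine standing in for run_agent_modal.remote.aio (it fails at random so the retry path is exercised). This is a sketch of the technique, not code from this PR:

import asyncio
import random


async def run_with_retries(items, num_workers=3, max_retries=2):
    results = {}
    queue = asyncio.Queue()
    for item in items:
        queue.put_nowait((item, 0))  # 0 = first attempt

    async def process(item):
        # Hypothetical stand-in for the remote call; fails randomly.
        if random.random() < 0.3:
            raise RuntimeError(f"transient failure on {item!r}")
        return {"item": item, "status": "ok"}

    async def worker():
        while True:
            item, attempt = await queue.get()
            try:
                if item not in results:
                    results[item] = await process(item)
            except Exception as exc:
                if attempt < max_retries:
                    # Re-enqueue before task_done() so queue.join() keeps waiting.
                    queue.put_nowait((item, attempt + 1))
                else:
                    results[item] = {"item": item, "status": "error", "error": str(exc)}
            finally:
                queue.task_done()

    workers = [asyncio.create_task(worker()) for _ in range(num_workers)]
    await queue.join()  # returns once every queued attempt is marked done
    for w in workers:
        w.cancel()  # the workers loop forever, so cancel them explicitly
    await asyncio.gather(*workers, return_exceptions=True)
    return [results[item] for item in items]  # preserve input order


print(asyncio.run(run_with_retries(["a", "b", "c"])))

Re-enqueueing a failed attempt before calling task_done() is what keeps queue.join() from returning while retries are still pending.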


def process_batch_local(examples: list[SweBenchExample], batch_size=10, codebases: dict[str, Codebase] = {}):
def process_batch_local(examples: list[SweBenchExample], num_workers=5, codebases: dict[str, Codebase] = {}):
    """Process a batch of examples synchronously.

    Args:
        examples: List of SweBenchExample objects to process
        batch_size: Number of examples to process in each batch.
        num_workers: Number of examples to process in each batch.
            Default is 5 to avoid overwhelming the system.
"""
results = []

    # Process examples in batches
    for i in range(0, len(examples), batch_size):
        batch = examples[i : i + batch_size]
        print(f"Processing batch {i // batch_size + 1}/{len(examples) // batch_size + 1} (examples {i + 1}-{min(i + batch_size, len(examples))})")
    for i in range(0, len(examples), num_workers):
        batch = examples[i : i + num_workers]
        print(f"Processing batch {i // num_workers + 1}/{(len(examples) + num_workers - 1) // num_workers} (examples {i + 1}-{min(i + num_workers, len(examples))})")

        # Process each example in the batch
        for example in batch:
@@ -134,7 +147,9 @@ def process_batch_local(examples: list[SweBenchExample], batch_size=10, codebase
    return results
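
As a quick sanity check of the chunking arithmetic above, a tiny standalone sketch with toy data (not from the repo); note the ceiling division for the total batch count:

num_workers = 5
examples = list(range(12))  # toy stand-ins for SweBenchExample objects
total_batches = (len(examples) + num_workers - 1) // num_workers  # ceiling division -> 3
for i in range(0, len(examples), num_workers):
    batch = examples[i : i + num_workers]
    print(f"batch {i // num_workers + 1}/{total_batches}: examples {i + 1}-{min(i + num_workers, len(examples))}")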


async def run_eval(use_existing_preds: str | None, dataset: str, length: int, instance_id: str | None = None, local: bool = False, codebases: dict[str, Codebase] = {}, repo: str | None = None):
async def run_eval(
    use_existing_preds: str | None, dataset: str, length: int, instance_id: str | None = None, local: bool = False, codebases: dict[str, Codebase] = {}, repo: str | None = None, num_workers: int = 5
):
    run_id = use_existing_preds or str(uuid.uuid4())
    print(f"Run ID: {run_id}")
    predictions_dir = PREDS_DNAME / f"results_{run_id}"
Expand Down Expand Up @@ -162,7 +177,7 @@ async def run_eval(use_existing_preds: str | None, dataset: str, length: int, in
    if local:
        results = process_batch_local(examples, codebases=codebases)
    else:
        results = await process_batch_modal(examples)
        results = await process_batch_modal(examples, num_workers=num_workers)

    # Save individual results
    for result in results:
Expand Down Expand Up @@ -218,9 +233,12 @@ async def run_eval(use_existing_preds: str | None, dataset: str, length: int, in
@click.option("--instance-id", help="The instance ID of the example to process.", type=str, default=None)
@click.option("--local", help="Run the evaluation locally.", is_flag=True, default=False)
@click.option("--repo", help="The repo to use.", type=str, default=None)
def run_eval_command(use_existing_preds, dataset, length, instance_id, local, repo):
@click.option(
    "--num-workers", help="The number of workers to use. This is the number of examples that will be processed concurrently. A large number may lead to rate limiting issues.", type=int, default=5
)
def run_eval_command(use_existing_preds, dataset, length, instance_id, local, repo, num_workers):
    print(f"Repo: {repo}")
    asyncio.run(run_eval(use_existing_preds=use_existing_preds, dataset=dataset, length=length, instance_id=instance_id, codebases=None, local=local, repo=repo))
    asyncio.run(run_eval(use_existing_preds=use_existing_preds, dataset=dataset, length=length, instance_id=instance_id, codebases=None, local=local, repo=repo, num_workers=num_workers))
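
For illustration, a hypothetical programmatic equivalent of the CLI call above. The dataset name and length are made-up example values, and the import assumes run_eval.py is importable as a module; a usage sketch, not part of the PR:

import asyncio

from run_eval import run_eval  # assumption: this file is on the import path

asyncio.run(
    run_eval(
        use_existing_preds=None,
        dataset="lite",  # hypothetical dataset name
        length=10,  # hypothetical number of examples
        instance_id=None,
        local=False,
        codebases=None,
        repo=None,
        num_workers=5,  # keep modest; high concurrency risks rate limiting
    )
)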


if __name__ == "__main__":