diff --git a/codegen-examples/examples/swebench_agent_run/entry_point.py b/codegen-examples/examples/swebench_agent_run/entry_point.py index bd7b6ec2b..a364aaa19 100644 --- a/codegen-examples/examples/swebench_agent_run/entry_point.py +++ b/codegen-examples/examples/swebench_agent_run/entry_point.py @@ -14,6 +14,6 @@ @app.function(timeout=43200) -async def run_agent_modal(entry: SweBenchExample, run_id: str): +async def run_agent_modal(entry: SweBenchExample, run_id: str, model: str): """Modal function to process a single example from the SWE-bench dataset.""" - return run_agent_on_entry(entry, run_id=run_id) + return run_agent_on_entry(entry, run_id=run_id, model=model) diff --git a/codegen-examples/examples/swebench_agent_run/run_eval.py b/codegen-examples/examples/swebench_agent_run/run_eval.py index 0c0e04f65..9a86135d8 100644 --- a/codegen-examples/examples/swebench_agent_run/run_eval.py +++ b/codegen-examples/examples/swebench_agent_run/run_eval.py @@ -17,7 +17,7 @@ run_agent_modal = modal.Function.from_name(app_name="swebench-agent-run", name="run_agent_modal") -async def process_batch_modal(examples: list[SweBenchExample], run_id: str, num_workers=5, min_workers=1, max_retries=3): +async def process_batch_modal(examples: list[SweBenchExample], run_id: str, model: str, num_workers=5, min_workers=1, max_retries=3): """Process a batch of examples concurrently using a queue system with incremental worker scaling. Args: @@ -110,7 +110,7 @@ async def is_rate_limit_error(error): async def process_example(example, attempt, current_task): try: - result = await run_agent_modal.remote.aio(example, run_id=run_id) + result = await run_agent_modal.remote.aio(example, run_id=run_id, model=model) if result is None: print(f"Warning: Null result for {example.instance_id}") @@ -222,7 +222,7 @@ async def worker(): return [results.get(example.instance_id, {"instance_id": example.instance_id, "status": "missing"}) for example in examples] -def process_batch_local(examples: list[SweBenchExample], num_workers=5, codebases: dict[str, Codebase] = {}, run_id: str | None = None): +def process_batch_local(examples: list[SweBenchExample], model: str, num_workers=5, codebases: dict[str, Codebase] = {}, run_id: str | None = None): """Process a batch of examples synchronously. Args: @@ -242,9 +242,9 @@ def process_batch_local(examples: list[SweBenchExample], num_workers=5, codebase try: # Run the agent locally instead of using modal if codebases and example.instance_id in codebases: - result = run_agent_on_entry(example, codebase=codebases[example.instance_id], run_id=run_id) + result = run_agent_on_entry(example, model=model, codebase=codebases[example.instance_id], run_id=run_id) else: - result = run_agent_on_entry(example, run_id=run_id) + result = run_agent_on_entry(example, model=model, run_id=run_id) results.append(result) except Exception as e: @@ -267,7 +267,15 @@ def process_batch_local(examples: list[SweBenchExample], num_workers=5, codebase async def run_eval( - use_existing_preds: str | None, dataset: str, length: int, instance_id: str | None = None, local: bool = False, codebases: dict[str, Codebase] = {}, repo: str | None = None, num_workers: int = 5 + use_existing_preds: str | None, + dataset: str, + length: int, + instance_id: str | None = None, + local: bool = False, + codebases: dict[str, Codebase] = {}, + repo: str | None = None, + num_workers: int = 5, + model: str = "claude-3-7-sonnet-latest", ): run_id = use_existing_preds or str(uuid.uuid4()) print(f"Run ID: {run_id}") @@ -294,9 +302,9 @@ async def run_eval( # Process all examples in parallel batches if local: - results = process_batch_local(examples, codebases=codebases, run_id=run_id) + results = process_batch_local(examples, model=model, codebases=codebases, run_id=run_id) else: - results = await process_batch_modal(examples, num_workers=num_workers, run_id=run_id) + results = await process_batch_modal(examples, model=model, run_id=run_id, num_workers=num_workers) # Save individual results for result in results: @@ -355,9 +363,11 @@ async def run_eval( @click.option( "--num-workers", help="The number of workers to use. This is the number of examples that will be processed concurrently. A large number may lead to rate limiting issues.", type=int, default=5 ) -def run_eval_command(use_existing_preds, dataset, length, instance_id, local, repo, num_workers): +@click.option("--model", help="The model to use.", type=str, default="claude-3-7-sonnet-latest") +def run_eval_command(use_existing_preds, dataset, length, instance_id, local, repo, num_workers, model): print(f"Repo: {repo}") - asyncio.run(run_eval(use_existing_preds=use_existing_preds, dataset=dataset, length=length, instance_id=instance_id, codebases=None, local=local, repo=repo, num_workers=num_workers)) + print(f"Model: {model}") + asyncio.run(run_eval(use_existing_preds=use_existing_preds, dataset=dataset, length=length, instance_id=instance_id, codebases=None, local=local, repo=repo, num_workers=num_workers, model=model)) if __name__ == "__main__": diff --git a/src/codegen/extensions/swebench/harness.py b/src/codegen/extensions/swebench/harness.py index 5316b109a..90cc9e017 100644 --- a/src/codegen/extensions/swebench/harness.py +++ b/src/codegen/extensions/swebench/harness.py @@ -49,7 +49,7 @@ def show_problems(dataset): print(f"{inst}: {problem}") -def run_agent_on_entry(entry: SweBenchExample, codebase: Codebase | None = None, run_id: str | None = None): +def run_agent_on_entry(entry: SweBenchExample, model: str, codebase: Codebase | None = None, run_id: str | None = None): """Process one `entry` from SWE Bench using the LLM `models` at the given `temperature`. Set `model_name_or_path` in the result json. """ @@ -70,7 +70,7 @@ def run_agent_on_entry(entry: SweBenchExample, codebase: Codebase | None = None, ) codebase = Codebase.from_repo(repo_full_name=entry.repo, commit=base_commit, language="python", config=config) # check out the repo - agent = CodeAgent(codebase=codebase, run_id=run_id, instance_id=instance_id) + agent = CodeAgent(codebase=codebase, run_id=run_id, instance_id=instance_id, model_name=model) pprint.pprint(instance_id) pprint.pprint(gold_files)