diff --git a/codegen-examples/examples/swebench_agent_run/run_eval.py b/codegen-examples/examples/swebench_agent_run/run_eval.py index e2c844254..0c2132694 100644 --- a/codegen-examples/examples/swebench_agent_run/run_eval.py +++ b/codegen-examples/examples/swebench_agent_run/run_eval.py @@ -6,7 +6,7 @@ import modal import click from datetime import datetime -from codegen.extensions.swebench.utils import SWEBenchDataset, SweBenchExample, get_swe_bench_example, get_swe_bench_examples +from codegen.extensions.swebench.utils import SWEBenchDataset, SweBenchExample, get_swe_bench_examples from codegen.extensions.swebench.report import generate_report PREDS_DNAME = Path(__file__).parent / "predictions" @@ -92,10 +92,7 @@ async def run_eval(use_existing_preds: str | None, dataset: str, length: int, in run_id = use_existing_preds or str(uuid.uuid4()) predictions_dir = PREDS_DNAME / f"results_{run_id}" dataset = SWEBenchDataset(dataset) - if instance_id: - examples = [get_swe_bench_example(instance_id, dataset=dataset)] - else: - examples = get_swe_bench_examples(dataset=dataset, length=length) + examples = get_swe_bench_examples(dataset=dataset, length=length, instance_id=instance_id) try: if use_existing_preds is None: diff --git a/pyproject.toml b/pyproject.toml index 5f1fbe6b1..a72f9ca8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ dependencies = [ "httpx>=0.28.1", "docker>=6.1.3", "urllib3>=2.0.0", + "datasets", ] license = { text = "Apache-2.0" } diff --git a/src/codegen/extensions/swebench/utils.py b/src/codegen/extensions/swebench/utils.py index 91e42c464..05b4c4617 100644 --- a/src/codegen/extensions/swebench/utils.py +++ b/src/codegen/extensions/swebench/utils.py @@ -5,7 +5,10 @@ from pprint import pprint from typing import Literal, Optional -import requests +from datasets import load_dataset + +# Add constant for cache directory +CACHE_DIR = Path.home() / ".cache" / "swebench" class SWEBenchDataset(Enum): @@ -64,93 +67,66 @@ def load_predictions(paths): return predictions -def get_swe_bench_examples(dataset: SWEBenchDataset = SWEBenchDataset.LITE, split: Literal["train", "dev", "test"] = "test", offset: int = 0, length: int = 100) -> list[SweBenchExample]: - """Fetch examples from the SWE-bench dataset. +def get_swe_bench_examples( + dataset: SWEBenchDataset = SWEBenchDataset.LITE, + split: Literal["train", "dev", "test"] = "test", + offset: int = 0, + length: int = 100, + instance_id: str | None = None, +) -> list[SweBenchExample]: + """Fetch examples from the SWE-bench dataset using the datasets library. + + Args: + dataset: The dataset to use (LITE, FULL, or VERIFIED) + split: The dataset split to use + offset: Starting index for examples + length: Number of examples to fetch Returns: List of SweBenchExample objects - - Raises: - requests.RequestException: If the API request fails """ - url = "https://datasets-server.huggingface.co/rows" - params = { - "dataset": dataset.value, - "config": "default", - "split": split, - "offset": offset, - "length": length, - } - - response = requests.get(url, params=params) - response.raise_for_status() - data = response.json() - + # Ensure cache directory exists + CACHE_DIR.mkdir(parents=True, exist_ok=True) + + # Load the dataset with caching enabled + dataset_name = dataset.value + swe_bench_dataset = load_dataset(dataset_name, cache_dir=str(CACHE_DIR), download_mode="reuse_dataset_if_exists") + + # Get the requested split + split_data = swe_bench_dataset[split] + + # Apply offset and length + if instance_id: + offset = 0 + end_idx = len(split_data) + else: + end_idx = min(offset + length, len(split_data)) + if offset >= len(split_data): + return [] + + # Use the select method instead of slicing + # This ensures we get dictionary-like objects + selected_rows = split_data.select(range(offset, end_idx)) + + # Convert to SweBenchExample objects examples = [] - for row in data["rows"]: + for row in selected_rows: + if instance_id and row["instance_id"] != instance_id: + continue example = SweBenchExample( - repo=row["row"]["repo"], - instance_id=row["row"]["instance_id"], - base_commit=row["row"]["base_commit"], - patch=row["row"]["patch"], - test_patch=row["row"]["test_patch"], - problem_statement=row["row"]["problem_statement"], - hints_text=row["row"].get("hints_text"), - created_at=row["row"]["created_at"], - version=row["row"]["version"], - fail_to_pass=row["row"]["FAIL_TO_PASS"], - pass_to_pass=row["row"].get("PASS_TO_PASS"), - environment_setup_commit=row["row"].get("environment_setup_commit"), + repo=row["repo"], + instance_id=row["instance_id"], + base_commit=row["base_commit"], + patch=row["patch"], + test_patch=row["test_patch"], + problem_statement=row["problem_statement"], + hints_text=row.get("hints_text"), + created_at=row["created_at"], + version=row["version"], + fail_to_pass=row["FAIL_TO_PASS"], + pass_to_pass=row.get("PASS_TO_PASS"), + environment_setup_commit=row.get("environment_setup_commit"), ) examples.append(example) return examples - - -def get_swe_bench_example( - instance_id: str, - dataset: SWEBenchDataset = SWEBenchDataset.LITE, -) -> SweBenchExample: - """Fetch a single example from the SWE-bench dataset by its instance ID. - - Args: - instance_id: The unique identifier of the example to fetch - - Returns: - SweBenchExample object - - Raises: - ValueError: If no example found with the given ID - requests.RequestException: If the API request fails - """ - url = "https://datasets-server.huggingface.co/filter" - params = { - "dataset": dataset.value, - "config": "default", - "split": "dev", - "where": f"instance_id='{instance_id}'", - } - - response = requests.get(url, params=params) - response.raise_for_status() - data = response.json() - - if not data["rows"]: - msg = f"No example found with instance_id: {instance_id}" - raise ValueError(msg) - - row = data["rows"][0]["row"] - return SweBenchExample( - repo=row["repo"], - instance_id=row["instance_id"], - base_commit=row["base_commit"], - patch=row["patch"], - test_patch=row["test_patch"], - problem_statement=row["problem_statement"], - hints_text=row.get("hints_text"), - created_at=row["created_at"], - version=row["version"], - fail_to_pass=row["FAIL_TO_PASS"], - pass_to_pass=row.get("PASS_TO_PASS"), - environment_setup_commit=row.get("environment_setup_commit"), - )