-
Notifications
You must be signed in to change notification settings - Fork 144
Add ComputeEval Dataset Support #1124
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
98e9dbf
e9ce660
d1f8e8a
1510216
572f770
b2308f7
0832f7d
eb4969d
0211bea
980a574
52059dd
920e174
1530926
cd88e99
dc9c920
9e1a951
ef71245
a676e34
417ef07
a90cc44
8aba153
c576b91
7225f97
1cbe0b1
4b49e26
5df7599
7f405cb
73624d6
94f2208
544199b
b7cb62c
7457060
a7aa94a
9a4f608
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| EVAL_SPLIT = "eval" | ||
| DATASET_GROUP = "code" | ||
| METRICS_TYPE = "compute-eval" | ||
| GENERATION_MODULE = "nemo_skills.inference.eval.compute_eval" | ||
| GENERATION_ARGS = "++prompt_config=compute-eval/baseline ++eval_type=compute-eval" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import argparse | ||
| import json | ||
| import os | ||
| from pathlib import Path | ||
|
|
||
| from datasets import load_dataset | ||
|
|
||
| _CONTEXT_FILES_BLOCK_TEMPLATE = """ | ||
| --- file: {path} | ||
| ```{fence} | ||
| {content} | ||
| ``` | ||
| """ | ||
|
|
||
|
|
||
| def _fence_for_path(path: str) -> str: | ||
| p = path.lower() | ||
| if p.endswith((".cu", ".cuh")): | ||
| return "cuda" | ||
| if p.endswith((".cc", ".cpp", ".cxx")): | ||
| return "cpp" | ||
| if p.endswith(".c"): | ||
| return "c" | ||
| if p.endswith(".h") or p.endswith(".hpp"): | ||
| return "h" | ||
| # Default to plaintext if unknown | ||
| return "" | ||
|
|
||
|
|
||
| def _format_context_files_block(context_files: list[dict[str, str]]) -> str: | ||
| blocks: list[str] = [] | ||
| for source in context_files: | ||
| if "path" not in source or "content" not in source: | ||
| continue | ||
|
|
||
| fence = _fence_for_path(source["path"]) | ||
| blocks.append( | ||
| _CONTEXT_FILES_BLOCK_TEMPLATE.format(path=source["path"], fence=fence, content=source["content"]) | ||
| ) | ||
| return "".join(blocks) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser(description="Download and prepare nvidia/compute-eval dataset") | ||
| parser.add_argument( | ||
| "--release", | ||
| type=str, | ||
| default=None, | ||
| help="Release to download (e.g., '2025-1', '2025-2'). If not specified, downloads default release.", | ||
| ) | ||
|
|
||
| args = parser.parse_args() | ||
|
|
||
| token = os.getenv("HF_TOKEN", None) | ||
| if not token: | ||
| print("Error: HF_TOKEN environment variable not set. Please set it to access the dataset.") | ||
| exit(1) | ||
|
|
||
| dataset = load_dataset("nvidia/compute-eval", args.release, token=token) | ||
| data_dir = Path(__file__).absolute().parent | ||
| data_dir.mkdir(exist_ok=True) | ||
|
|
||
| with open(data_dir / "eval.jsonl", "wt", encoding="utf-8") as f: | ||
| for item in dataset["eval"]: | ||
| record = { | ||
| "problem": item, | ||
| "task_id": item["task_id"], | ||
| "problem_prompt": item["prompt"], | ||
| "build_command": item["build_command"], | ||
| "context_files_block": _format_context_files_block(item["context_files"]), | ||
| } | ||
|
|
||
| # Dumping using default=str to handle datetime serialization from the problem records | ||
| f.write(json.dumps(record, default=str) + "\n") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. using
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time! |
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,85 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # you may not use this file except in compliance with the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # You may obtain a copy of the License at | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Annotated, Any | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from compute_eval.data.data_model import CudaCppProblem, CudaPythonProblem, FileSolution, PatchSolution | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from compute_eval.execution import evaluate_solution | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from compute_eval.utils.eval_utils import get_nvcc_version, parse_semver | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from pydantic import Field, TypeAdapter | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nemo_skills.evaluation.evaluator import BaseEvaluator | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nemo_skills.utils import get_logger_name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _LOG = logging.getLogger(get_logger_name(__file__)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _PROBLEM_ADAPTER = TypeAdapter(Annotated[CudaCppProblem | CudaPythonProblem, Field(discriminator="type")]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _SOLUTION_ADAPTER = TypeAdapter(Annotated[FileSolution | PatchSolution, Field(discriminator="type")]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class ComputeEvalEvaluator(BaseEvaluator): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _installed_ctk_major: int | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _installed_ctk_minor: int | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, config: dict, num_parallel_requests=10): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__(config, num_parallel_requests) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| nvcc_version = get_nvcc_version() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not nvcc_version: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise RuntimeError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "NVCC not found. Please ensure that the CUDA Toolkit is installed and nvcc is in your PATH." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._installed_ctk_major, self._installed_ctk_minor, _ = parse_semver(nvcc_version) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def eval_single(self, data_point: dict[str, Any]) -> dict[str, Any]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # noinspection PyBroadException | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+46
to
+47
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| problem = _PROBLEM_ADAPTER.validate_python(data_point["problem"]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| solution = _SOLUTION_ADAPTER.validate_python(data_point["solution"]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+48
to
+49
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. missing error handling for missing keys if |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| graded = await asyncio.to_thread( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| evaluate_solution, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| installed_ctk_major=self._installed_ctk_major, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| installed_ctk_minor=self._installed_ctk_minor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| problem=problem, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| solution=solution, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "passed": graded.passed, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "skipped": graded.skipped, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "elapsed_time": graded.elapsed_time, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "build_output": graded.build_output, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "test_output": graded.test_output, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except KeyError as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _LOG.error(f"Missing required field in data_point: {e}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "passed": False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "skipped": False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "elapsed_time": 0.0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "build_output": "", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "test_output": "", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "error": f"Missing required field: {e}", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _LOG.error(f"Error during evaluation: {e}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "passed": False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "skipped": False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "elapsed_time": 0.0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "build_output": "", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "test_output": "", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "error": str(e), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
76
to
85
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. catching all exceptions masks critical errors and returns inconsistent fields - successful evaluation returns
Suggested change
Comment on lines
76
to
85
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. overly broad exception handling loses error context catching all exceptions with
consider catching specific exception types or at least including exception type in the error field |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import logging | ||
| import sys | ||
|
|
||
| import hydra | ||
| from compute_eval.data.data_model import FileSolution | ||
|
|
||
| # noinspection PyProtectedMember | ||
| from compute_eval.generate_completions import _parse_solution | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style: accessing protected member If Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. accessing protected member
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Uses protected member Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Comment on lines
+20
to
+21
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. accessing protected member
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. using protected member accessing
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. accessing protected member
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. importing private function Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time! |
||
|
|
||
| from nemo_skills.inference.generate import GenerateSolutionsConfig, GenerationTask | ||
| from nemo_skills.inference.model import server_params | ||
| from nemo_skills.utils import ( | ||
| get_help_message, | ||
| get_logger_name, | ||
| setup_logging, | ||
| ) | ||
|
|
||
| _LOG = logging.getLogger(get_logger_name(__file__)) | ||
|
|
||
|
|
||
| class ComputeEvalGenerationTask(GenerationTask): | ||
| def __init__(self, cfg: GenerateSolutionsConfig): | ||
| super().__init__(cfg) | ||
|
|
||
| async def process_single_datapoint(self, data_point, data): | ||
| res = await super().process_single_datapoint(data_point, data) | ||
| try: | ||
| solution = FileSolution( | ||
| task_id=data_point["task_id"], | ||
| files=_parse_solution(res["generation"]), | ||
| ) | ||
| return { | ||
| "solution": solution.model_dump(), | ||
| "generation": res["generation"], | ||
| } | ||
| except KeyError as e: | ||
| _LOG.error(f"Missing required field: {e}") | ||
| raise | ||
| except Exception as e: | ||
| _LOG.error(f"Failed to parse solution: {e}") | ||
| raise | ||
|
|
||
|
|
||
| GENERATION_TASK_CLASS = ComputeEvalGenerationTask | ||
|
|
||
|
|
||
| @hydra.main(version_base=None, config_name="base_generation_config") | ||
| def run_compute_eval(cfg: GenerateSolutionsConfig): | ||
| _LOG.info("Config used: %s", cfg) | ||
|
|
||
| task = ComputeEvalGenerationTask(cfg) | ||
| task.generate() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| if "--help" in sys.argv or "-h" in sys.argv: | ||
| print(get_help_message(GenerateSolutionsConfig, server_params=server_params())) | ||
| else: | ||
| setup_logging() | ||
| run_compute_eval() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no validation that
context_fileslist items have requiredpathandcontentkeys - will raise KeyError if dataset schema changes