Add arena-hard v2 #1205
base: main
Changes from all commits: 9ab83b6, d55bf4a, 9509a8c, 9f8a345, 6b4f9dc, ca1ad87, ddd2da4, b26c858, 443c6aa, 0e8b35f, a6886fe, d4aff02
New file (`@@ -0,0 +1,27 @@`) — default evaluation settings for the new dataset:

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# settings that define how evaluation should be done by default (all can be changed from cmdline)
DATASET_GROUP = "chat"
METRICS_TYPE = "arena"
# using judgement directly in metrics, no need for special evaluation
GENERATION_ARGS = "++prompt_config=generic/default"

JUDGE_PIPELINE_ARGS = {
    "generation_module": "nemo_skills.inference.eval.arena_judge",
    "model": "gpt-4.1",
    "server_type": "openai",
    "server_address": "https://api.openai.com/v1",
}
```
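For orientation, the judge settings above point at an OpenAI-compatible endpoint. The sketch below is only an illustration of roughly what they resolve to, assuming the `openai` Python client and an `OPENAI_API_KEY` in the environment; the actual request is issued by the `nemo_skills.inference.eval.arena_judge` module, not by this snippet.

```python
from openai import OpenAI

# server_address and model come from JUDGE_PIPELINE_ARGS above; the prompt is a placeholder.
client = OpenAI(base_url="https://api.openai.com/v1")
response = client.chat.completions.create(
    model="gpt-4.1",
    messages=[{"role": "user", "content": "<arena judge prompt comparing a model answer against the baseline>"}],
)
print(response.choices[0].message.content)
```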
New file (`@@ -0,0 +1,71 @@`) — data preparation script that downloads the arena-hard-v2 questions and category-specific baseline answers:

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import urllib.request
from pathlib import Path

URL_QUESTIONS = "https://raw.githubusercontent.com/lmarena/arena-hard-auto/main/data/arena-hard-v2.0/question.jsonl"
# Category-specific baselines as per official arena-hard-auto implementation
URL_BASELINE_HARD_PROMPT = "https://raw.githubusercontent.com/lmarena/arena-hard-auto/main/data/arena-hard-v2.0/model_answer/o3-mini-2025-01-31.jsonl"
URL_BASELINE_CREATIVE_WRITING = "https://raw.githubusercontent.com/lmarena/arena-hard-auto/main/data/arena-hard-v2.0/model_answer/gemini-2.0-flash-001.jsonl"

# Mapping of category to baseline URL
CATEGORY_BASELINES = {
    "hard_prompt": URL_BASELINE_HARD_PROMPT,
    "creative_writing": URL_BASELINE_CREATIVE_WRITING,
}


def extract_answer_text(data):
    """Extract the answer text from the baseline model's response format."""
    messages = data["messages"]
    for msg in messages:
        if msg["role"] == "assistant":
            content = msg["content"]
            return content["answer"] if isinstance(content, dict) else content
    raise ValueError("No assistant message found in the data.")


if __name__ == "__main__":
    data_dir = Path(__file__).absolute().parent
    data_dir.mkdir(exist_ok=True)
    questions_file = str(data_dir / "question.jsonl")
    output_file = str(data_dir / "test.jsonl")

    # Download questions
    urllib.request.urlretrieve(URL_QUESTIONS, questions_file)

    # Download and process all baseline files
    baseline_answers = {}
    for category, url in CATEGORY_BASELINES.items():
        baseline_file = str(data_dir / f"baseline_{category}.jsonl")
        urllib.request.urlretrieve(url, baseline_file)

        with open(baseline_file, "rt", encoding="utf-8") as fin:
            for line in fin:
                data = json.loads(line)
                uid = data["uid"]
                if uid not in baseline_answers:
                    baseline_answers[uid] = {}
                baseline_answers[uid][category] = extract_answer_text(data)

    # Create test.jsonl with category-specific baseline answers
    with open(questions_file, "rt", encoding="utf-8") as fin, open(output_file, "wt", encoding="utf-8") as fout:
        for line in fin:
            data = json.loads(line)
            data["question"] = data.pop("prompt")
            category = data["category"]
            data["baseline_answer"] = baseline_answers[data["uid"]][category]
            fout.write(json.dumps(data) + "\n")
```
Comment on lines +64 to +71 (Contributor):

`data["baseline_answer"] = baseline_answers[data["uid"]][category]` assumes every `uid` has a baseline for its `category` and will raise a bare `KeyError` if the baseline data is missing or partial. Given that v2 explicitly has multiple baselines by category, it would be safer to fail with a clearer error that prints the missing `uid` and `category`.
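A minimal sketch of that clearer failure, reusing the names from the script above; this illustrates the reviewer's point and is not code from the PR:

```python
# Fail with an explicit message naming the missing uid/category instead of a bare KeyError.
uid = data["uid"]
category = data["category"]
if category not in baseline_answers.get(uid, {}):
    raise ValueError(
        f"No baseline answer for uid={uid!r} and category={category!r}; "
        f"available categories for this uid: {sorted(baseline_answers.get(uid, {}))}"
    )
data["baseline_answer"] = baseline_answers[uid][category]
```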
Changes to the existing arena metrics class:

```diff
@@ -13,6 +13,7 @@
 # limitations under the License.

 import re
+from collections import defaultdict

 from nemo_skills.evaluation.metrics.base import BaseMetrics
```

```diff
@@ -51,6 +52,11 @@ def update(self, predictions):
         super().update(predictions)
         self.scores.append([])
         self.agg_mode = f"pass@{len(predictions)}"
+
+        # Track category for per-category scoring (defaults to None for v1 compatibility)
+        category = predictions[0].get("category")
+        self.categories.append(category)
+
```
Comment on lines 52 to +59 (Contributor):

A `None` category shows up when evaluating older/partially-prepared datasets or when a pipeline forgets to propagate `category`. A safer approach is to only create per-category buckets for non-`None` categories.
(hunk continued)

```diff
         if len(predictions) > 1:
             judge_scores = [self._get_judge_score(elem["judgement-gen-base"]) for elem in predictions]
             # adding the best score out of all the generations
```
```diff
@@ -86,16 +92,34 @@ def update(self, predictions):
     def get_metrics(self):
         from nemo_skills.evaluation.evaluator.arena import get_aggregate_score

-        metrics = {"num_entries": self.total}
-        metrics.update(get_aggregate_score(self.scores))
-        metrics_dict = {self.agg_mode: metrics}
-        self.update_common_metrics(metrics_dict[self.agg_mode])
+        metrics_dict = {}
+
+        # Compute overall metrics
+        overall_metrics = {"num_entries": self.total}
+        overall_metrics.update(get_aggregate_score(self.scores))
+        self.update_common_metrics(overall_metrics)
+
+        # Group scores by category for per-category metrics
+        category_scores = defaultdict(list)
+        for score, category in zip(self.scores, self.categories, strict=True):
+            category_scores[category].append(score)
+
+        # If we have multiple categories, compute per-category metrics
+        unique_categories = set(self.categories)
+        if len(unique_categories) > 1:
+            for category, scores in category_scores.items():
+                cat_metrics = {"num_entries": len(scores)}
+                cat_metrics.update(get_aggregate_score(scores))
+                overall_metrics[f"category_{category}"] = cat_metrics
```
Comment on lines +109 to +113 (Contributor):

A `category_None` bucket shows up when mixed data includes entries without a `category`: if arena-hard-v1 data (no category) is mixed with v2 data (with categories), the per-category loop emits metrics under `category_None`. Consider only creating per-category buckets for non-`None` categories.
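A sketch of the non-`None` filter the reviewer describes, using the names from the hunk above; an illustration, not the reviewer's literal suggestion:

```python
# Only build per-category buckets for real categories, so v1-style data without a
# "category" field never produces a "category_None" entry in the metrics.
unique_categories = set(self.categories) - {None}
if len(unique_categories) > 1:
    for category, scores in category_scores.items():
        if category is None:
            continue
        cat_metrics = {"num_entries": len(scores)}
        cat_metrics.update(get_aggregate_score(scores))
        overall_metrics[f"category_{category}"] = cat_metrics
```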
(hunk continued)

```diff
+
+        metrics_dict[self.agg_mode] = overall_metrics
         # arena metrics have their own confidence estimation, so not doing std metrics here
         return metrics_dict

     def reset(self):
         super().reset()
         self.scores = []  # list of lists
+        self.categories = []  # list of category strings
         self.lengths = 0
         # TODO: the class should support pass@k, but this forces it to report as pass@1.
         # There is some error here for k>1
```
Changes to the arena judge generation task (`ArenaJudgeConfig` / `ArenaJudgeTask`):

```diff
@@ -26,6 +26,8 @@
     InferenceConfig,
 )
 from nemo_skills.inference.model import server_params
+from nemo_skills.inference.model.base import EndpointType
+from nemo_skills.prompt.utils import get_prompt
 from nemo_skills.utils import (
     get_help_message,
     get_logger_name,
```

```diff
@@ -48,9 +50,15 @@ class ArenaJudgeConfig(GenerationTaskConfig):
     server: dict = field(default_factory=dict)

     # Override the default Generation config here
+    # prompt_config is used as the default for any category not explicitly mapped below
     prompt_config: str = "judge/arena"
     generation_key: str = "judgement"

+    # Category-specific prompt config overrides (arena-hard-v2 uses different prompts per category)
+    # Set to None to use the default prompt_config for that category
+    # creative_writing uses a prompt that doesn't ask the judge to generate its own answer first
+    prompt_config_creative: str = "judge/arena_creative"
+

 cs = hydra.core.config_store.ConfigStore.instance()
 cs.store(name="base_arena_judge_config", node=ArenaJudgeConfig)
```

```diff
@@ -60,6 +68,63 @@ class ArenaJudgeTask(GenerationTask):
     def __init__(self, cfg: ArenaJudgeConfig):
         super().__init__(cfg)

+    def setup_prompt(self):
+        if self.cfg.prompt_format == "openai":
+            return None
+
```
Comment on lines +71 to +74 (Contributor):

If category-specific judging is required for v2, this likely needs an equivalent branch for the OpenAI prompt path (e.g., selecting different message templates/configs per category) or explicitly disallowing the openai format for v2.
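For the second option, a hedged sketch against the config fields visible in this diff (`prompt_format`, `prompt_config_creative`); an illustration only, not code from the PR:

```python
def setup_prompt(self):
    if self.cfg.prompt_format == "openai":
        # Category-specific judge prompts are not wired into the openai prompt path, so refuse
        # the combination instead of silently judging every category with the same template.
        if self.cfg.prompt_config_creative:
            raise ValueError(
                "prompt_format=openai does not support category-specific judge prompts; "
                "unset prompt_config_creative or use the default prompt path."
            )
        return None
    ...  # rest of the method unchanged: load the default and category-specific prompts as in the PR
```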
(hunk continued)

```diff
+        # Load the default prompt (used for most categories including hard_prompt, arena-hard-v0.1, etc.)
+        default_prompt = get_prompt(
+            prompt_config=self.cfg.prompt_config,
+            tokenizer=self.tokenizer,
+            code_tags=self.cfg.code_tags,
+            examples_type=self.cfg.examples_type,
+            system_message=self.cfg.system_message,
+        )
+
+        # Load category-specific prompt overrides
+        self.category_prompts = {}
+        if self.cfg.prompt_config_creative:
+            self.category_prompts["creative_writing"] = get_prompt(
+                prompt_config=self.cfg.prompt_config_creative,
+                tokenizer=self.tokenizer,
+                code_tags=self.cfg.code_tags,
+                examples_type=self.cfg.examples_type,
+                system_message=self.cfg.system_message,
+            )
+            LOG.info("Prompt used (creative_writing): %s", self.category_prompts["creative_writing"])
+        # registering default prompt explicitly for hard_prompt
+        self.category_prompts["hard_prompt"] = default_prompt
+
+        LOG.info("Prompt used (default): %s", default_prompt)
+        return default_prompt
+
+    def fill_prompt(self, data_point, data):
+        """Fill prompt with category-specific prompt config."""
+        if self.cfg.prompt_format == "openai":
+            return super().fill_prompt(data_point, data)
+
+        # Select the appropriate prompt based on category. If not defined, forcing fall-back to default prompt
+        category = data_point.get("category")
+        if not category:
+            prompt = self.prompt
+        else:
+            # will fail if category not in category_prompts as this is unexpected
+            prompt = self.category_prompts[category]
```
Comment on lines +111 to +112 (Contributor):

`prompt = self.category_prompts[category]` will raise a bare `KeyError` for any category that has no configured prompt.

Comment on lines +106 to +112 (Contributor):

Bug: the current logic falls back to the default prompt only when `category` is missing entirely. Based on the comment at lines 75-76 ("default prompt used for most categories including hard_prompt"), the intent is to fall back to the default for unmapped categories as well. Proposed fix:

```diff
         # Select the appropriate prompt based on category. If not defined, forcing fall-back to default prompt
         category = data_point.get("category")
-        if not category:
-            prompt = self.prompt
-        else:
-            # will fail if category not in category_prompts as this is unexpected
-            prompt = self.category_prompts[category]
+        if category and category in self.category_prompts:
+            prompt = self.category_prompts[category]
+        else:
+            prompt = self.prompt
```
Contributor:

This will raise a bare `KeyError`. Per CONTRIBUTING.md guidelines ("Don't be overly defensive"), letting it fail is fine, but the error message should indicate which category is missing.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
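A hedged sketch of what that clear-but-informative failure could look like, reusing `category_prompts` from the diff; not the reviewer's literal suggestion:

```python
# Let an unexpected category fail loudly, but name the category and the configured options.
category = data_point.get("category")
if not category:
    prompt = self.prompt
else:
    try:
        prompt = self.category_prompts[category]
    except KeyError:
        raise ValueError(
            f"No judge prompt configured for category '{category}'; "
            f"known categories: {sorted(self.category_prompts)}"
        ) from None
```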
(hunk continued)

```diff
+
+        data_point = deepcopy(data_point)
+        filled_prompt = prompt.fill(
+            data_point,
+            start_assistant_response_key=self.cfg.start_assistant_response_key,
+            chat_template_kwargs=self.cfg.chat_template_kwargs,
+            format_as_string=(self.cfg.inference.endpoint_type == EndpointType.text),
+        )
+        if self.cfg.prompt_suffix:
+            if isinstance(filled_prompt, list):
+                filled_prompt[-1]["content"] += self.cfg.prompt_suffix
+            else:
+                filled_prompt += self.cfg.prompt_suffix
+        return filled_prompt
+
     def log_example_prompt(self, all_data):
         data_point = deepcopy(all_data[0])
```