-
Notifications
You must be signed in to change notification settings - Fork 155
support structured outputs in hle judge for optional AA compatibility #1186
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
0379e3c
c7a5c6a
493a793
ff96906
f86bc95
54c8bc0
534a6c0
adcff37
8442962
8509918
ba33ee1
cf9725b
52d59b4
6539eae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import json | ||
| import logging | ||
|
|
||
| from nemo_skills.evaluation.metrics.math_metrics import MathMetrics | ||
| from nemo_skills.utils import get_logger_name | ||
|
|
||
| LOG = logging.getLogger(get_logger_name(__file__)) | ||
|
|
||
|
|
||
class HLEAAMetrics(MathMetrics):
    """Metrics for HLE with judge structured output for AA-compatibility."""

    def _postprocess_judgement(self, prediction: dict) -> dict:
        """Rewrite a JSON structured-output judgement into the plain-text form MathMetrics parses.

        The judge emits a JSON object whose ``correct`` field is "yes"/"no";
        ``MathMetrics`` expects a ``"Judgement: ..."`` string, so we convert it here.
        Any malformed judgement is mapped to ``"Judgement: FAILED_TO_POSTPROCESS"``
        instead of raising, so one bad sample cannot abort metric aggregation.

        Args:
            prediction: A single prediction dict containing a ``judgement`` entry.

        Returns:
            A shallow copy of ``prediction`` with ``judgement`` rewritten.
        """
        prediction = prediction.copy()  # never mutate the caller's dict
        try:
            judgement = json.loads(prediction["judgement"])
            prediction["judgement"] = "Judgement: {}".format(judgement["correct"])
        # TypeError: json.loads rejects non-str/bytes input (e.g. a None judgement or
        # an already-decoded dict); previously this propagated and crashed aggregation.
        # KeyError: missing "judgement" entry or missing "correct" field in the JSON.
        except (json.JSONDecodeError, KeyError, TypeError) as e:
            LOG.debug(f"Failed to parse structured output judgement: {e}")
            prediction["judgement"] = "Judgement: FAILED_TO_POSTPROCESS"
        return prediction

    def update(self, predictions):
        """Postprocess every judgement, then delegate accumulation to MathMetrics."""
        postprocessed_predictions = [self._postprocess_judgement(pred) for pred in predictions]
        super().update(postprocessed_predictions)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -46,6 +46,7 @@ | |
| server_params, | ||
| ) | ||
| from nemo_skills.inference.model.base import EndpointType | ||
| from nemo_skills.inference.structured_outputs import STRUCTURED_OUTPUTS | ||
| from nemo_skills.prompt.utils import get_prompt, get_token_count | ||
| from nemo_skills.utils import ( | ||
| chunk_data, | ||
|
|
@@ -218,6 +219,8 @@ class GenerationTaskConfig: | |
| eval_type: str | None = None # "lean4-proof", "math", etc. | ||
| eval_config: dict = field(default_factory=dict) # Config for the evaluator | ||
|
|
||
| structured_output: str | None = None | ||
|
|
||
| def __post_init__(self): | ||
| self._post_init_validate_data() | ||
| self._post_init_validate_server() | ||
|
|
@@ -681,6 +684,9 @@ async def process_single_datapoint(self, data_point, all_data): | |
| "stop_phrases": [self.cfg.stop_phrase] if self.cfg.stop_phrase else None, | ||
| } | ||
|
|
||
| if self.cfg.structured_output in STRUCTURED_OUTPUTS: | ||
| generation_params["response_format"] = STRUCTURED_OUTPUTS[self.cfg.structured_output] | ||
|
|
||
|
Comment on lines
+694
to
+696
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Unhandled invalid key: when |
||
| if self.cfg.code_execution: | ||
| if self.cfg.override_max_code_executions and self.cfg.total_code_executions_in_prompt is not None: | ||
| generation_params["max_code_executions"] = data_point["total_code_executions"] | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,29 @@ | ||||||
| # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. | ||||||
| # | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
| # you may not use this file except in compliance with the License. | ||||||
| # You may obtain a copy of the License at | ||||||
| # | ||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||
| # | ||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | ||||||
|
|
||||||
| from typing import Literal | ||||||
|
|
||||||
| from pydantic import BaseModel | ||||||
anowaczynski-nvidia marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
|
||||||
|
|
||||||
class HLEJudgeAAResponseFormat(BaseModel):
    """Structured-output schema enforced on the HLE judge for AA-compatibility.

    Passed as the ``response_format`` of the judge request so its reply is a
    JSON object with exactly these fields, which downstream metrics can parse
    deterministically.
    """

    # Final answer the judge extracted from the model's response.
    extracted_final_answer: str
    # Judge's free-form explanation for its verdict.
    reasoning: str
    # Verdict; constrained to "yes"/"no" so metric parsing never sees other values.
    correct: Literal["yes", "no"]
    # Judge's self-reported confidence — presumably a 0-100 percentage; TODO confirm
    # against the judge prompt.
    confidence: int
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time! |
||||||
|
|
||||||
|
|
||||||
# Registry mapping the user-facing ``structured_output`` config value to the
# pydantic model injected as ``response_format`` into generation requests.
STRUCTURED_OUTPUTS = {
    "HLE_JUDGE_AA": HLEJudgeAAResponseFormat,
}
anowaczynski-nvidia marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider validating
structured_output against the registry early. If a user specifies a
structured_output value that's not in STRUCTURED_OUTPUTS, the code silently ignores it without injecting response_format. This could lead to unexpected behavior. Per coding guidelines, the code should fail if a user specifies an unsupported argument. Proposed fix in `__post_init__` or `process_single_datapoint`
Add validation in
GenerationTaskConfig.__post_init__:🤖 Prompt for AI Agents