Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions nemo_skills/evaluation/metrics/hleaa_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging

from nemo_skills.evaluation.metrics.math_metrics import MathMetrics
from nemo_skills.utils import get_logger_name

LOG = logging.getLogger(get_logger_name(__file__))


class HLEAAMetrics(MathMetrics):
    """Metrics for HLE with judge structured output for AA-compatibility.

    The judge emits a JSON structured output (see
    nemo_skills/inference/structured_outputs.py) instead of the plain
    "Judgement: ..." string MathMetrics expects; this subclass converts
    each prediction before delegating to MathMetrics.
    """

    def _postprocess_judgement(self, prediction: dict) -> dict:
        """Convert a JSON judgement into the "Judgement: <yes/no>" string format.

        Returns a shallow copy of ``prediction``; the input dict is not mutated.
        On any parse failure the judgement is replaced with a sentinel value so
        that downstream accuracy computation counts it as unparseable rather
        than crashing.
        """
        prediction = prediction.copy()
        try:
            judgement = json.loads(prediction["judgement"])
            # "correct" is "yes"/"no" per HLEJudgeAAResponseFormat.
            prediction["judgement"] = "Judgement: {}".format(judgement["correct"])
        except (json.JSONDecodeError, KeyError, TypeError) as e:
            # TypeError covers a None / non-string judgement field (e.g. a
            # failed judge generation) that json.loads cannot parse; without
            # it such entries would crash instead of being marked as failed.
            LOG.debug(f"Failed to parse structured output judgement: {e}")
            prediction["judgement"] = "Judgement: FAILED_TO_POSTPROCESS"
        return prediction

    def update(self, predictions):
        """Postprocess all judgements to MathMetrics' string format, then delegate."""
        preprocessed_predictions = [self._postprocess_judgement(pred) for pred in predictions]
        super().update(preprocessed_predictions)
2 changes: 2 additions & 0 deletions nemo_skills/evaluation/metrics/map_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
SweBenchMetrics,
)
from nemo_skills.evaluation.metrics.gradingbench_metrics import GradingBenchMetrics
from nemo_skills.evaluation.metrics.hleaa_metrics import HLEAAMetrics
from nemo_skills.evaluation.metrics.icpc_metrics import ICPCMetrics
from nemo_skills.evaluation.metrics.if_metrics import IFMetrics
from nemo_skills.evaluation.metrics.ioi_metrics import IOIMetrics
Expand All @@ -47,6 +48,7 @@
METRICS_MAP = {
"math": MathMetrics,
"hle": functools.partial(MathMetrics, compute_no_answer=False, answer_key="generation"),
"hle-aa": functools.partial(HLEAAMetrics, compute_no_answer=False, answer_key="generation"),
"frontierscience-olympiad": functools.partial(
MathMetrics, compute_no_answer=False, question_key="question", answer_key="generation"
),
Expand Down
6 changes: 6 additions & 0 deletions nemo_skills/inference/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
server_params,
)
from nemo_skills.inference.model.base import EndpointType
from nemo_skills.inference.structured_outputs import STRUCTURED_OUTPUTS
from nemo_skills.prompt.utils import get_prompt, get_token_count
from nemo_skills.utils import (
chunk_data,
Expand Down Expand Up @@ -218,6 +219,8 @@ class GenerationTaskConfig:
eval_type: str | None = None # "lean4-proof", "math", etc.
eval_config: dict = field(default_factory=dict) # Config for the evaluator

structured_output: str | None = None

def __post_init__(self):
self._post_init_validate_data()
self._post_init_validate_server()
Expand Down Expand Up @@ -681,6 +684,9 @@ async def process_single_datapoint(self, data_point, all_data):
"stop_phrases": [self.cfg.stop_phrase] if self.cfg.stop_phrase else None,
}

if self.cfg.structured_output in STRUCTURED_OUTPUTS:
generation_params["response_format"] = STRUCTURED_OUTPUTS[self.cfg.structured_output]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Consider validating structured_output against registry early.

If a user specifies a structured_output value that's not in STRUCTURED_OUTPUTS, the code silently ignores it without injecting response_format. This could lead to unexpected behavior. Per coding guidelines, the code should fail if a user specifies an unsupported argument.

Proposed fix in `__post_init__` or `process_single_datapoint`

Add validation in GenerationTaskConfig.__post_init__:

def _post_init_validate_params(self):
    # ... existing validations ...
    if self.structured_output is not None and self.structured_output not in STRUCTURED_OUTPUTS:
        raise ValueError(
            f"Unknown structured_output '{self.structured_output}'. "
            f"Valid options: {list(STRUCTURED_OUTPUTS.keys())}"
        )
🤖 Prompt for AI Agents
In `@nemo_skills/inference/generate.py` around lines 695 - 696, The code silently
ignores unknown structured_output values; add a validation in
GenerationTaskConfig.__post_init__ (or call a helper _post_init_validate_params
from __post_init__) that checks if self.structured_output is not None and not in
STRUCTURED_OUTPUTS and raise a ValueError listing the invalid value and valid
keys (referencing STRUCTURED_OUTPUTS and the attribute structured_output); this
ensures process_single_datapoint/generation_params population logic (where
generation_params["response_format"] is set) never silently drops an unsupported
structured_output.


Comment on lines +694 to +696
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unhandled invalid key

When structured_output is set to any non-None value that is not present in STRUCTURED_OUTPUTS, process_single_datapoint will throw a KeyError at STRUCTURED_OUTPUTS[self.cfg.structured_output]. Since this is a user-provided config value (Hydra/CLI via ++structured_output=...), this becomes an unhelpful crash path. Consider validating structured_output in GenerationTaskConfig.__post_init__ (or using .get() with an explicit ValueError listing allowed keys) so users get a clear error message.

if self.cfg.code_execution:
if self.cfg.override_max_code_executions and self.cfg.total_code_executions_in_prompt is not None:
generation_params["max_code_executions"] = data_point["total_code_executions"]
Expand Down
2 changes: 2 additions & 0 deletions nemo_skills/inference/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ async def generate_async(
tools: list[dict] | None = None,
include_response: bool = False,
extra_body: dict = None,
response_format=None,
) -> dict:
if endpoint_type is None:
# Infering completion type from prompt
Expand All @@ -261,6 +262,7 @@ async def generate_async(
"reasoning_effort": reasoning_effort,
"tools": tools,
"extra_body": extra_body,
"response_format": response_format,
}

# TODO: remove this after we no longer use gpt-oss or it's fixed in vllm
Expand Down
3 changes: 3 additions & 0 deletions nemo_skills/inference/model/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def _build_chat_request_params(
reasoning_effort: str | None,
extra_body: dict = None,
tools: list[dict] | None = None,
response_format=None,
) -> dict:
"""
https://github.com/BerriAI/litellm/blob/v1.75.0-nightly/litellm/constants.py#L45-L56
Expand All @@ -72,6 +73,8 @@ def _build_chat_request_params(
"`repetition_penalty` is not supported by Gemini API, please set it to default value `1.0`."
)
assert not extra_body, "`extra_body` is not supported by Gemini API, please set it to None or empty dict"
if response_format is not None:
raise NotImplementedError()

# Vertext AI params: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
# litellm default params: https://github.com/BerriAI/litellm/blob/v1.75.0-nightly/litellm/llms/gemini/chat/transformation.py#L73-L90
Expand Down
4 changes: 4 additions & 0 deletions nemo_skills/inference/model/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def _build_chat_request_params(
stop_phrases: list[str] | None = None,
timeout: int | None = None,
top_logprobs: int | None = None,
response_format=None,
**kwargs,
) -> dict:
# Validations
Expand All @@ -48,6 +49,7 @@ def _build_chat_request_params(
if top_k != -1:
raise NotImplementedError("Megatron server does not support top_k parameter.")
assert kwargs.get("tools") is None, "Megatron server does not support tools parameter."
assert response_format is None, "Megatron server does not support response_format parameter."

params = {
"messages": messages,
Expand Down Expand Up @@ -81,6 +83,7 @@ def _build_completion_request_params(
stop_phrases: list[str] | None = None,
timeout: int | None = None,
top_logprobs: int | None = None,
response_format=None,
**kwargs,
) -> dict:
# Parameter validation specific to Megatron
Expand All @@ -93,6 +96,7 @@ def _build_completion_request_params(
if top_k != -1:
raise NotImplementedError("Megatron server does not support top_k parameter.")
assert kwargs.get("tools") is None, "Megatron server does not support tools parameter."
assert response_format is None, "Megatron server does not support response_format parameter."

return {
"prompt": prompt,
Expand Down
3 changes: 3 additions & 0 deletions nemo_skills/inference/model/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def _build_completion_request_params(self, **kwargs) -> dict:
assert kwargs.pop("reasoning_effort", None) is None, (
"reasoning_effort is not supported by completion requests."
)
assert kwargs.pop("response_format", None) is None, "response_format is not supported by completion requests."
assert kwargs.pop("top_k", -1) == -1, "`top_k` is not supported by OpenAI API, please set it to -1."
assert kwargs.pop("min_p", 0.0) == 0.0, "`min_p` is not supported by OpenAI API, please set it to 0.0."
assert kwargs.pop("repetition_penalty", 1.0) == 1.0, (
Expand Down Expand Up @@ -100,6 +101,7 @@ def _build_chat_request_params(
reasoning_effort: str | None,
extra_body: dict = None,
tools: list[dict] | None = None,
response_format=None,
) -> dict:
# Validations
if top_k != -1:
Expand All @@ -116,6 +118,7 @@ def _build_chat_request_params(
"timeout": timeout,
"stream": stream,
"tools": tools,
"response_format": response_format,
}

if self._is_reasoning_model(self.model):
Expand Down
2 changes: 2 additions & 0 deletions nemo_skills/inference/model/sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def _build_chat_request_params(
reasoning_effort: str | None = None,
tools: list[dict] | None = None,
extra_body: dict = None,
response_format=None,
) -> dict:
request = super()._build_chat_request_params(
messages=messages,
Expand All @@ -56,6 +57,7 @@ def _build_chat_request_params(
reasoning_effort=reasoning_effort,
tools=tools,
extra_body=extra_body,
response_format=response_format,
)
# SGLang requires tool_choice in the request body when tools are provided
if tools is not None:
Expand Down
4 changes: 4 additions & 0 deletions nemo_skills/inference/model/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,11 @@ def _build_completion_request_params(
reasoning_effort: str | None = None,
extra_body: dict = None,
tools: list[dict] | None = None,
response_format=None,
) -> dict:
assert reasoning_effort is None, "reasoning_effort is not supported for text completion requests"
assert tools is None, "tools are not supported for text completion requests"
assert response_format is None, "response_format is not supported for text completion requests"
return {
"prompt": prompt,
"max_tokens": tokens_to_generate,
Expand Down Expand Up @@ -182,6 +184,7 @@ def _build_chat_request_params(
reasoning_effort: str | None = None,
tools: list[dict] | None = None,
extra_body: dict = None,
response_format=None,
) -> dict:
# Process messages to handle image content (VLM support)
processed_messages = []
Expand All @@ -207,6 +210,7 @@ def _build_chat_request_params(
"timeout": timeout,
"extra_body": self._build_request_body(top_k, min_p, repetition_penalty, extra_body=extra_body),
"tools": tools,
"response_format": response_format,
}
if reasoning_effort:
request["allowed_openai_params"] = ["reasoning_effort"]
Expand Down
29 changes: 29 additions & 0 deletions nemo_skills/inference/structured_outputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Literal

from pydantic import BaseModel, Field


class HLEJudgeAAResponseFormat(BaseModel):
    """Structured judge output schema for AA-compatible HLE evaluation.

    Used as a ``response_format`` for judge generations so the judgement can
    be parsed as JSON by HLEAAMetrics.
    """

    # The final answer string the judge extracted from the model generation.
    extracted_final_answer: str
    # The judge's free-form reasoning behind the verdict.
    reasoning: str
    # Binary verdict on whether the extracted answer matches the reference.
    correct: Literal["yes", "no"]
    # Self-reported confidence; constrained to a valid 0-100 percentage so
    # out-of-range values fail validation instead of passing through silently.
    confidence: int = Field(ge=0, le=100)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

confidence field has no validation constraints. Should be confidence: int = Field(ge=0, le=100) or similar to ensure valid confidence values.

Suggested change
confidence: int
confidence: int = Field(ge=0, le=100, description="Confidence score from 0 to 100")

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!



# Registry mapping user-facing ``structured_output`` config names to the
# pydantic models passed to the server as ``response_format``.
STRUCTURED_OUTPUTS = {
    "HLE_JUDGE_AA": HLEJudgeAAResponseFormat,
}
8 changes: 7 additions & 1 deletion nemo_skills/pipeline/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from typing import List
from typing import List, Optional

import typer

Expand Down Expand Up @@ -458,6 +458,10 @@ def eval(
"",
help="Additional sbatch kwargs to pass to the job scheduler. Values should be provided as a JSON string or as a `dict` if invoking from code.",
),
metric_type: Optional[str] = typer.Option(
None,
help="Specify metric type to use a specific metric calculator.",
),
metrics_kwargs: str = typer.Option(
"",
help="Additional kwargs to pass to the metrics calculator. Values should be provided as a JSON string or as a `dict` if invoking from code.",
Expand Down Expand Up @@ -773,6 +777,8 @@ def eval(
command += f" --wandb_project={wandb_project} "
if data_dir:
command += f" --data_dir={data_dir} "
if metric_type:
command += f" --metric_type={metric_type} "
if metrics_kwargs:
command += f" --metrics_kwargs='{kwargs_to_string(metrics_kwargs)}' "

Expand Down