39 changes: 39 additions & 0 deletions nemo_skills/evaluation/metrics/hleaa_metrics.py
@@ -0,0 +1,39 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging

from nemo_skills.evaluation.metrics.math_metrics import MathMetrics
from nemo_skills.utils import get_logger_name

LOG = logging.getLogger(get_logger_name(__file__))


class HLEAAMetrics(MathMetrics):
"""Metrics for HLE with judge structured output for AA-compatibility."""

def _postprocess_judgement(self, prediction: dict) -> dict:
prediction = prediction.copy()
try:
judgement = json.loads(prediction["judgement"])
prediction["judgement"] = "Judgement: {}".format(judgement["correct"])
except (json.JSONDecodeError, KeyError) as e:
LOG.debug(f"Failed to parse structured output judgement: {e}")
prediction["judgement"] = "Judgement: FAILED_TO_POSTPROCESS"
return prediction

def update(self, predictions):
preprocessed_predictions = [self._postprocess_judgement(pred) for pred in predictions]
super().update(preprocessed_predictions)
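
For reference, a minimal standalone sketch (not part of the diff) of the transformation `_postprocess_judgement` applies before delegating to `MathMetrics.update`; the sample judgement values are illustrative only:

import json

# A structured judge output in the HLE_JUDGE_AA format (illustrative values).
structured_judgement = json.dumps(
    {
        "extracted_final_answer": "42",
        "reasoning": "The extracted answer matches the reference answer.",
        "correct": "yes",
        "confidence": 90,
    }
)
prediction = {"generation": "42", "judgement": structured_judgement}

# Equivalent to HLEAAMetrics._postprocess_judgement(prediction):
parsed = json.loads(prediction["judgement"])
prediction["judgement"] = "Judgement: {}".format(parsed["correct"])
print(prediction["judgement"])  # -> Judgement: yes

Unparsable or incomplete judgements are mapped to "Judgement: FAILED_TO_POSTPROCESS" instead of raising, as the except branch above shows.
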
2 changes: 2 additions & 0 deletions nemo_skills/evaluation/metrics/map_metrics.py
@@ -32,6 +32,7 @@
SweBenchMetrics,
)
from nemo_skills.evaluation.metrics.gradingbench_metrics import GradingBenchMetrics
from nemo_skills.evaluation.metrics.hleaa_metrics import HLEAAMetrics
from nemo_skills.evaluation.metrics.icpc_metrics import ICPCMetrics
from nemo_skills.evaluation.metrics.if_metrics import IFMetrics
from nemo_skills.evaluation.metrics.ioi_metrics import IOIMetrics
@@ -47,6 +48,7 @@
METRICS_MAP = {
"math": MathMetrics,
"hle": functools.partial(MathMetrics, compute_no_answer=False, answer_key="generation"),
"hle-aa": functools.partial(HLEAAMetrics, compute_no_answer=False, answer_key="generation"),
"frontierscience-olympiad": functools.partial(
MathMetrics, compute_no_answer=False, question_key="question", answer_key="generation"
),
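
A small sketch (not part of the diff) of what the new registry key resolves to; how the framework instantiates the calculator is unchanged from the existing `hle` entry:

from nemo_skills.evaluation.metrics.map_metrics import METRICS_MAP

metric_factory = METRICS_MAP["hle-aa"]
# functools.partial exposes the bound class and keyword defaults:
assert metric_factory.func.__name__ == "HLEAAMetrics"
assert metric_factory.keywords == {"compute_no_answer": False, "answer_key": "generation"}
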
6 changes: 6 additions & 0 deletions nemo_skills/inference/generate.py
@@ -46,6 +46,7 @@
server_params,
)
from nemo_skills.inference.model.base import EndpointType
from nemo_skills.inference.structured_outputs import STRUCTURED_OUTPUTS
from nemo_skills.prompt.utils import get_prompt, get_token_count
from nemo_skills.utils import (
chunk_data,
@@ -218,6 +219,8 @@ class GenerationTaskConfig:
eval_type: str | None = None # "lean4-proof", "math", etc.
eval_config: dict = field(default_factory=dict) # Config for the evaluator

structured_output: str | None = None

def __post_init__(self):
self._post_init_validate_data()
self._post_init_validate_server()
@@ -681,6 +684,9 @@ async def process_single_datapoint(self, data_point, all_data):
"stop_phrases": [self.cfg.stop_phrase] if self.cfg.stop_phrase else None,
}

if self.cfg.structured_output in STRUCTURED_OUTPUTS:
generation_params["response_format"] = STRUCTURED_OUTPUTS[self.cfg.structured_output]
Comment on lines +687 to +688
⚠️ Potential issue | 🟡 Minor

Consider validating structured_output against registry early.

If a user specifies a structured_output value that's not in STRUCTURED_OUTPUTS, the code silently ignores it without injecting response_format. This could lead to unexpected behavior. Per coding guidelines, the code should fail if a user specifies an unsupported argument.

Proposed fix in `__post_init__` or `process_single_datapoint`

Add validation in GenerationTaskConfig.__post_init__:

def _post_init_validate_params(self):
    # ... existing validations ...
    if self.structured_output is not None and self.structured_output not in STRUCTURED_OUTPUTS:
        raise ValueError(
            f"Unknown structured_output '{self.structured_output}'. "
            f"Valid options: {list(STRUCTURED_OUTPUTS.keys())}"
        )
🤖 Prompt for AI Agents
In `@nemo_skills/inference/generate.py` around lines 695 - 696, the code silently
ignores unknown structured_output values; add a validation in
GenerationTaskConfig.__post_init__ (or call a helper _post_init_validate_params
from __post_init__) that checks if self.structured_output is not None and not in
STRUCTURED_OUTPUTS and raise a ValueError listing the invalid value and valid
keys (referencing STRUCTURED_OUTPUTS and the attribute structured_output); this
ensures process_single_datapoint/generation_params population logic (where
generation_params["response_format"] is set) never silently drops an unsupported
structured_output.


if self.cfg.code_execution:
if self.cfg.override_max_code_executions and self.cfg.total_code_executions_in_prompt is not None:
generation_params["max_code_executions"] = data_point["total_code_executions"]
2 changes: 2 additions & 0 deletions nemo_skills/inference/model/base.py
@@ -236,6 +236,7 @@ async def generate_async(
tools: list[dict] | None = None,
include_response: bool = False,
extra_body: dict = None,
response_format=None,
) -> dict:
if endpoint_type is None:
# Infering completion type from prompt
@@ -261,6 +262,7 @@
"reasoning_effort": reasoning_effort,
"tools": tools,
"extra_body": extra_body,
"response_format": response_format,
}

# TODO: remove this after we no longer use gpt-oss or it's fixed in vllm
3 changes: 3 additions & 0 deletions nemo_skills/inference/model/gemini.py
@@ -57,6 +57,7 @@ def _build_chat_request_params(
reasoning_effort: str | None,
extra_body: dict = None,
tools: list[dict] | None = None,
response_format=None,
) -> dict:
"""
https://github.com/BerriAI/litellm/blob/v1.75.0-nightly/litellm/constants.py#L45-L56
@@ -72,6 +73,8 @@
"`repetition_penalty` is not supported by Gemini API, please set it to default value `1.0`."
)
assert not extra_body, "`extra_body` is not supported by Gemini API, please set it to None or empty dict"
if response_format is not None:
raise NotImplementedError()

# Vertext AI params: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
# litellm default params: https://github.com/BerriAI/litellm/blob/v1.75.0-nightly/litellm/llms/gemini/chat/transformation.py#L73-L90
4 changes: 4 additions & 0 deletions nemo_skills/inference/model/megatron.py
@@ -36,6 +36,7 @@ def _build_chat_request_params(
stop_phrases: list[str] | None = None,
timeout: int | None = None,
top_logprobs: int | None = None,
response_format=None,
**kwargs,
) -> dict:
# Validations
@@ -48,6 +49,7 @@
if top_k != -1:
raise NotImplementedError("Megatron server does not support top_k parameter.")
assert kwargs.get("tools") is None, "Megatron server does not support tools parameter."
assert response_format is None, "Megatron server does not support response_format parameter."

params = {
"messages": messages,
@@ -81,6 +83,7 @@ def _build_completion_request_params(
stop_phrases: list[str] | None = None,
timeout: int | None = None,
top_logprobs: int | None = None,
response_format=None,
**kwargs,
) -> dict:
# Parameter validation specific to Megatron
@@ -93,6 +96,7 @@
if top_k != -1:
raise NotImplementedError("Megatron server does not support top_k parameter.")
assert kwargs.get("tools") is None, "Megatron server does not support tools parameter."
assert response_format is None, "Megatron server does not support response_format parameter."

return {
"prompt": prompt,
3 changes: 3 additions & 0 deletions nemo_skills/inference/model/openai.py
@@ -69,6 +69,7 @@ def _build_completion_request_params(self, **kwargs) -> dict:
assert kwargs.pop("reasoning_effort", None) is None, (
"reasoning_effort is not supported by completion requests."
)
assert kwargs.pop("response_format", None) is None, "response_format is not supported by completion requests."
assert kwargs.pop("top_k", -1) == -1, "`top_k` is not supported by OpenAI API, please set it to -1."
assert kwargs.pop("min_p", 0.0) == 0.0, "`min_p` is not supported by OpenAI API, please set it to 0.0."
assert kwargs.pop("repetition_penalty", 1.0) == 1.0, (
@@ -100,6 +101,7 @@ def _build_chat_request_params(
reasoning_effort: str | None,
extra_body: dict = None,
tools: list[dict] | None = None,
response_format=None,
) -> dict:
# Validations
if top_k != -1:
@@ -116,6 +118,7 @@
"timeout": timeout,
"stream": stream,
"tools": tools,
"response_format": response_format,
}

if self._is_reasoning_model(self.model):
2 changes: 2 additions & 0 deletions nemo_skills/inference/model/sglang.py
@@ -39,6 +39,7 @@ def _build_chat_request_params(
reasoning_effort: str | None = None,
tools: list[dict] | None = None,
extra_body: dict = None,
response_format=None,
) -> dict:
request = super()._build_chat_request_params(
messages=messages,
@@ -56,6 +57,7 @@
reasoning_effort=reasoning_effort,
tools=tools,
extra_body=extra_body,
response_format=response_format,
)
# SGLang requires tool_choice in the request body when tools are provided
if tools is not None:
4 changes: 4 additions & 0 deletions nemo_skills/inference/model/vllm.py
@@ -143,9 +143,11 @@ def _build_completion_request_params(
reasoning_effort: str | None = None,
extra_body: dict = None,
tools: list[dict] | None = None,
response_format=None,
) -> dict:
assert reasoning_effort is None, "reasoning_effort is not supported for text completion requests"
assert tools is None, "tools are not supported for text completion requests"
assert response_format is None, "response_format is not supported for text completion requests"
return {
"prompt": prompt,
"max_tokens": tokens_to_generate,
@@ -182,6 +184,7 @@ def _build_chat_request_params(
reasoning_effort: str | None = None,
tools: list[dict] | None = None,
extra_body: dict = None,
response_format=None,
) -> dict:
# Process messages to handle image content (VLM support)
processed_messages = []
@@ -207,6 +210,7 @@
"timeout": timeout,
"extra_body": self._build_request_body(top_k, min_p, repetition_penalty, extra_body=extra_body),
"tools": tools,
"response_format": response_format,
}
if reasoning_effort:
request["allowed_openai_params"] = ["reasoning_effort"]
29 changes: 29 additions & 0 deletions nemo_skills/inference/structured_outputs.py
@@ -0,0 +1,29 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Literal

from pydantic import BaseModel


class HLEJudgeAAResponseFormat(BaseModel):
extracted_final_answer: str
reasoning: str
correct: Literal["yes", "no"]
confidence: int


STRUCTURED_OUTPUTS = {
"HLE_JUDGE_AA": HLEJudgeAAResponseFormat,
}
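
A minimal sketch (not part of the diff) of the contract this model defines: the class registered under "HLE_JUDGE_AA" is what gets passed through as `response_format`, and a conforming judge response can be validated against it directly. Sample values are illustrative, and exact server-side enforcement depends on the backend:

from pydantic import ValidationError

from nemo_skills.inference.structured_outputs import STRUCTURED_OUTPUTS

schema_cls = STRUCTURED_OUTPUTS["HLE_JUDGE_AA"]  # HLEJudgeAAResponseFormat

# A well-formed judge response (illustrative values) parses cleanly.
judgement = schema_cls(
    extracted_final_answer="42",
    reasoning="The extracted answer matches the reference answer exactly.",
    correct="yes",
    confidence=95,
)
assert judgement.correct in ("yes", "no")

# Anything outside the Literal["yes", "no"] constraint is rejected.
try:
    schema_cls(extracted_final_answer="42", reasoning="...", correct="maybe", confidence=95)
except ValidationError:
    pass
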
8 changes: 7 additions & 1 deletion nemo_skills/pipeline/eval.py
@@ -17,7 +17,7 @@
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from typing import List
from typing import List, Optional

import typer

@@ -458,6 +458,10 @@ def eval(
"",
help="Additional sbatch kwargs to pass to the job scheduler. Values should be provided as a JSON string or as a `dict` if invoking from code.",
),
metric_type: Optional[str] = typer.Option(
None,
help="Specify metric type to use a specific metric calculator.",
),
metrics_kwargs: str = typer.Option(
"",
help="Additional kwargs to pass to the metrics calculator. Values should be provided as a JSON string or as a `dict` if invoking from code.",
@@ -773,6 +777,8 @@ def eval(
command += f" --wandb_project={wandb_project} "
if data_dir:
command += f" --data_dir={data_dir} "
if metric_type:
command += f" --metric_type={metric_type} "
if metrics_kwargs:
command += f" --metrics_kwargs='{kwargs_to_string(metrics_kwargs)}' "

28 changes: 28 additions & 0 deletions tests/test_generation.py
@@ -192,3 +192,31 @@ def test_server_metadata_from_num_tasks(tmp_path):
assert server_cmd.script.num_gpus == server_config["num_gpus"]
assert groups[0].hardware.num_gpus == server_config["num_gpus"]
assert groups[0].hardware.num_tasks == server_cmd.script.num_tasks


def test_judge_generations_with_structured_output(tmp_path):
cmd = (
f"ns eval "
f" --server_type=openai "
f" --model=nvidia/nemotron-3-nano-30b-a3b "
f" --server_address=https://integrate.api.nvidia.com/v1 "
f" --benchmarks=hle "
f" --output_dir={tmp_path} "
f" --judge_model=nvidia/nemotron-3-nano-30b-a3b "
f" --judge_server_address=https://integrate.api.nvidia.com/v1 "
f" --judge_server_type=openai "
f" --metric_type=hle-aa "
f' --extra_judge_args="++structured_output=HLE_JUDGE_AA"'
f" ++max_samples=2 "
)
subprocess.run(cmd, shell=True, check=True)

# checking that output exists and has the expected format
with open(f"{tmp_path}/eval-results/hle/output.jsonl") as fin:
data = [json.loads(line) for line in fin.readlines()]
judgements = [json.loads(data[i]["judgement"]) for i in range(len(data))]
expected_keys = {"extracted_final_answer", "reasoning", "correct", "confidence"}
assert set(judgements[0].keys()) == expected_keys
assert set(judgements[1].keys()) == expected_keys
assert judgements[0]["correct"] in {"yes", "no"}
assert judgements[1]["correct"] in {"yes", "no"}