Skip to content

Commit 2ad0679

Browse files
committed
Minor fixes, add OpenAI Harmony to requirements
1 parent bca1b16 commit 2ad0679

File tree

4 files changed

+65
-2
lines changed

4 files changed

+65
-2
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
import argparse
18+
from pathlib import Path
19+
20+
from inference_endpoint.dataset_manager.predefined.gpqa import GPQA
21+
from inference_endpoint.evaluation.extractor import ABCDExtractor
22+
from inference_endpoint.evaluation.scoring import PassAt1Scorer
23+
24+
25+
def main(args):
    """Score the GPQA dataset with ``PassAt1Scorer`` and print the result.

    Args:
        args: Parsed command-line namespace. ``args.dataset_path`` is handed
            to ``GPQA.load_from_file`` and ``args.report_dir`` is handed to
            the scorer.
    """
    # Build the dataset object from the given file, then load it.
    gpqa_dataset = GPQA.load_from_file(args.dataset_path)
    gpqa_dataset.load()

    # Construct the scorer (ABCDExtractor is passed as the extractor) and
    # run it in one expression; the resulting score is only printed.
    pass_at_1 = PassAt1Scorer(
        GPQA.DATASET_ID,
        gpqa_dataset,
        args.report_dir,
        extractor=ABCDExtractor,
    ).score()
    print(f"Pass@1 Score: {pass_at_1}")
41+
42+
43+
if __name__ == "__main__":
    # Script entry point: parse the two path options and pass them to main().
    cli = argparse.ArgumentParser(
        description="Evaluate accuracy of the SGLang endpoint on the GPQA dataset",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    # (flag, help text, default) for each option, registered in this order.
    # Defaults are plain strings; argparse runs string defaults through `type`.
    option_specs = (
        (
            "--dataset-path",
            "Path to the dataset",
            "datasets/gpqa/diamond/gpqa_diamond.parquet",
        ),
        (
            "--report-dir",
            "Path to the report directory",
            "gpqa_sglang_report",
        ),
    )
    for flag, help_text, default_value in option_specs:
        cli.add_argument(flag, type=Path, help=help_text, default=default_value)

    main(cli.parse_args())

requirements/base.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ numpy==2.3.4
3030
datasets==4.1.1
3131
sentencepiece==0.2.1
3232
protobuf==6.33.0
33+
openai_harmony==0.0.8
3334

3435
# Color support for cross-platform terminals
3536
colorama==0.4.6

src/inference_endpoint/evaluation/extractor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ class Extractor(ABC):
2525
numeric value plain or inside a LaTeX block.
2626
"""
2727

28-
@abstractmethod
2928
@classmethod
29+
@abstractmethod
3030
def extract(cls, text: str) -> str | None:
3131
raise NotImplementedError
3232

src/inference_endpoint/evaluation/scoring.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def get_outputs(self):
7171
for line in f:
7272
event = orjson.loads(line.strip())
7373
if event["event_type"] == SampleEvent.COMPLETE.value:
74-
outputs.append(event["data"])
74+
outputs.append(event)
7575
df = pd.DataFrame(outputs, columns=["sample_uuid", "value"])
7676
df.rename(columns={"value": "output"}, inplace=True)
7777
return df

0 commit comments

Comments
 (0)