Skip to content

Commit 0a17fea

Browse files
committed
Fixup integration
Signed-off-by: Rashid Kaleem <230885705+arekay-nv@users.noreply.github.com>
1 parent 0682950 commit 0a17fea

File tree

5 files changed

+96
-85
lines changed

5 files changed

+96
-85
lines changed

examples/04_GPTOSS120B_Example/sglang_gptoss_120b_example.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@ datasets:
2222
eval_method: "pass_at_1"
2323
ground_truth: "answer"
2424
extractor: "boxed_math_extractor"
25+
num_repeats: 8
2526
- name: "gpqa_gptoss_sglang"
2627
type: "accuracy"
2728
accuracy_config:
2829
eval_method: "pass_at_1"
2930
extractor: "abcd_extractor"
30-
31+
num_repeats: 5
32+
# LCB - 3
3133
settings:
3234
runtime:
3335
min_duration_ms: 300
@@ -55,4 +57,4 @@ endpoint_config:
5557
api_key: null
5658
api_type: "sglang"
5759

58-
report_dir: "results/sglang_gptoss_120b_benchmark_mlperf_13_JAN_26/"
60+
report_dir: "results/sglang_gptoss_120b_benchmark_mlperf_13_JAN_26_DP4_FULL/"

src/inference_endpoint/commands/benchmark.py

Lines changed: 44 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121
import argparse
2222
import json
2323
import logging
24+
import os
2425
import shutil
2526
import signal
2627
import tempfile
2728
import uuid
29+
from dataclasses import dataclass
2830
from pathlib import Path
2931
from urllib.parse import urljoin
3032

@@ -63,7 +65,7 @@
6365
from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient
6466
from inference_endpoint.endpoint_client.http_sample_issuer import HttpClientSampleIssuer
6567
from inference_endpoint.evaluation import Extractor
66-
from inference_endpoint.evaluation.scoring import PassAt1Scorer
68+
from inference_endpoint.evaluation.scoring import Scorer
6769
from inference_endpoint.exceptions import (
6870
ExecutionError,
6971
InputValidationError,
@@ -141,6 +143,17 @@ def on_complete_hook(self, result: QueryResult):
141143
self.pbar.update(1)
142144

143145

146+
@dataclass
147+
class AccuracyConfiguration:
148+
scorer: Scorer
149+
extractor: Extractor
150+
dataset_name: str
151+
dataset: Dataset
152+
report_dir: os.PathLike
153+
ground_truth_column: str
154+
num_repeats: int
155+
156+
144157
async def run_benchmark_command(args: argparse.Namespace) -> None:
145158
"""Run performance benchmark in offline, online, or YAML-configured mode.
146159
@@ -336,49 +349,6 @@ def _build_config_from_cli(
336349
)
337350

338351

339-
def _get_dataset_path(args: argparse.Namespace, config: BenchmarkConfig) -> Path:
340-
"""Get dataset path from CLI args or config.
341-
342-
CURRENT LIMITATION: Only supports single dataset execution.
343-
Priority: CLI args > config datasets[0]
344-
345-
Args:
346-
args: Command arguments
347-
config: BenchmarkConfig
348-
349-
Returns:
350-
Path to dataset file
351-
352-
Raises:
353-
InputValidationError: If no dataset specified or file doesn't exist
354-
355-
TODO: Multi-dataset support
356-
When implemented, this should:
357-
1. Return list[Path] for multiple datasets
358-
2. Validate all dataset paths exist
359-
3. Support dataset interleaving strategies
360-
"""
361-
if hasattr(args, "dataset") and args.dataset:
362-
dataset_path = Path(args.dataset)
363-
else:
364-
# TODO: Multi-dataset - currently just picks single dataset
365-
single_dataset = config.get_single_dataset()
366-
if single_dataset:
367-
dataset_path = Path(single_dataset.path)
368-
else:
369-
logger.error("Dataset required: --dataset PATH or specify in config")
370-
raise InputValidationError(
371-
"Dataset required: --dataset PATH or specify in config"
372-
)
373-
374-
# Validate file exists
375-
if not dataset_path.exists():
376-
logger.error(f"Dataset not found: {dataset_path}")
377-
raise InputValidationError(f"Dataset not found: {dataset_path}")
378-
379-
return dataset_path
380-
381-
382352
def _run_benchmark(
383353
config: BenchmarkConfig,
384354
collect_responses: bool,
@@ -498,33 +468,31 @@ def _run_benchmark(
498468
"top_k": config.model_params.top_k,
499469
"repetition_penalty": config.model_params.repetition_penalty,
500470
}
501-
accuracy_datasets = [
502-
DataLoaderFactory.create_loader(dataset, metadata=metadata)
503-
for dataset in accuracy_configs
504-
]
505471

506472
# Pack the evaluation parameters for each accuracy dataset
507-
for i in range(len(accuracy_configs)):
508-
dataset = accuracy_configs[i]
509-
extractor = Extractor.get(dataset.accuracy_config.extractor)
510-
ground_truth_column = dataset.accuracy_config.ground_truth
511-
scorer = PassAt1Scorer # currently only PassAt1Scorer is supported
512-
# TODO add support for other scorers
473+
for acc_config in accuracy_configs:
474+
extractor = Extractor.get(acc_config.accuracy_config.extractor)
475+
ground_truth_column = acc_config.accuracy_config.ground_truth
476+
scorer = Scorer.get(acc_config.accuracy_config.eval_method)
477+
num_repeats = acc_config.accuracy_config.num_repeats
478+
dataset = DataLoaderFactory.create_loader(
479+
acc_config, metadata=metadata, num_repeats=num_repeats
480+
)
481+
accuracy_datasets.append(dataset)
513482
# TODO add tests and defaults
514483
eval_configs.append(
515-
(
484+
AccuracyConfiguration(
516485
scorer,
517486
extractor,
518-
dataset.name,
519-
accuracy_datasets[i],
487+
acc_config.name,
488+
dataset,
520489
config.report_dir,
521490
ground_truth_column,
491+
num_repeats,
522492
)
523493
)
524-
accuracy_datasets[i].load()
525-
logger.info(
526-
f"Loaded {accuracy_datasets[i]} - {accuracy_datasets[i].num_samples()} samples"
527-
)
494+
dataset.load()
495+
logger.info(f"Loaded {dataset} - {dataset.num_samples()} samples")
528496

529497
else:
530498
logger.info("No accuracy datasets provided")
@@ -659,31 +627,26 @@ def signal_handler(signum, frame):
659627
# Always restore original handler
660628
signal.signal(signal.SIGINT, old_handler)
661629
accuracy_scores = {}
662-
for (
663-
scorer,
664-
extractor,
665-
dataset_id,
666-
dataset,
667-
report_dir,
668-
ground_truth_column,
669-
) in eval_configs:
670-
scorer_instance = scorer(
671-
dataset_id,
672-
dataset,
673-
report_dir,
674-
extractor=extractor,
675-
ground_truth_column=ground_truth_column,
630+
for eval_config in eval_configs:
631+
scorer_instance = eval_config.scorer(
632+
eval_config.dataset_name,
633+
eval_config.dataset,
634+
eval_config.report_dir,
635+
extractor=eval_config.extractor,
636+
ground_truth_column=eval_config.ground_truth_column,
676637
)
677638
score, n_repeats = scorer_instance.score()
678-
accuracy_scores[dataset_id] = {
679-
"dataset_id": dataset_id,
680-
"num_samples": len(dataset.data),
681-
"extractor": extractor.__name__,
682-
"ground_truth_column": ground_truth_column,
639+
accuracy_scores[eval_config.dataset_name] = {
640+
"dataset_name": eval_config.dataset_name,
641+
"num_samples": len(eval_config.dataset.data),
642+
"extractor": eval_config.extractor.__name__,
643+
"ground_truth_column": eval_config.ground_truth_column,
683644
"score": score,
684645
"n_repeats": n_repeats,
685646
}
686-
logger.info(f"Score for {dataset_id}: {score} ({n_repeats} repeats)")
647+
logger.info(
648+
f"Score for {eval_config.dataset_name}: {score} ({n_repeats} repeats)"
649+
)
687650

688651
# Prefer authoritative metrics from the session report
689652
report = getattr(sess, "report", None)

src/inference_endpoint/config/schema.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,16 +207,19 @@ class AccuracyConfig(BaseModel):
207207
The eval_method is the method to use to evaluate the accuracy of the model. Currently only "pass_at_1" is supported.
208208
The ground_truth is the column in the dataset that contains the ground truth. Defaults to "ground_truth" if not specified.
209209
The extractor is the extractor used to pull the answer out of the model output so that it can be compared against the ground truth. Currently "boxed_math_extractor" and "abcd_extractor" are supported.
210+
The num_repeats is the number of times to repeat the dataset for evaluation. Defaults to 1 if not specified.
210211
Example:
211212
accuracy_config:
212213
eval_method: "pass_at_1"
213214
ground_truth: "answer"
214215
extractor: "boxed_math_extractor"
216+
num_repeats: 5
215217
"""
216218

217219
eval_method: str | None = None
218220
ground_truth: str = "ground_truth"
219221
extractor: str | None = None
222+
num_repeats: int = 1
220223

221224

222225
class RuntimeConfig(BaseModel):

src/inference_endpoint/dataset_manager/factory.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class DataLoaderFactory:
4242
def create_loader(
4343
config: Dataset,
4444
metadata: dict | None = None,
45+
**kwargs,
4546
) -> Dataset:
4647
"""Create appropriate dataset loader based on format.
4748
@@ -54,7 +55,7 @@ def create_loader(
5455
remap = config.parser
5556
name = config.name
5657
if name in Dataset.PREDEFINED:
57-
return Dataset.PREDEFINED[name].get_dataloader()
58+
return Dataset.PREDEFINED[name].get_dataloader(**kwargs)
5859
if format is not None:
5960
format = DatasetFormat(format)
6061

src/inference_endpoint/evaluation/scoring.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
# limitations under the License.
1515

1616

17+
import inspect
1718
import os
1819
from abc import ABC, abstractmethod
1920
from pathlib import Path
21+
from typing import ClassVar
2022

2123
import numpy as np
2224
import orjson
@@ -33,6 +35,46 @@ class Scorer(ABC):
3335
can be compared against the ground truth.
3436
"""
3537

38+
PREDEFINED: ClassVar[dict[str, type["Scorer"]]] = {}
39+
40+
def __init_subclass__(
41+
cls,
42+
scorer_id: str | None = None,
43+
**kwargs,
44+
):
45+
super().__init_subclass__(**kwargs)
46+
47+
if not inspect.isabstract(cls):
48+
if scorer_id is None:
49+
scorer_id = cls.__name__
50+
cls.SCORER_ID = scorer_id
51+
Scorer.PREDEFINED[scorer_id] = cls
52+
53+
@classmethod
54+
def get(cls, name: str) -> type["Scorer"]:
55+
"""Look up a Scorer subclass by its registered name.
56+
57+
Args:
58+
name: str, the registered scorer name
59+
60+
Returns:
61+
Scorer subclass
62+
63+
Raises:
64+
KeyError: If no scorer with the given name is found
65+
"""
66+
try:
67+
return Scorer.PREDEFINED[name]
68+
except KeyError as e:
69+
raise KeyError(
70+
f"Scorer '{name}' is not registered - available scorers: {Scorer.available_scorers()}"
71+
) from e
72+
73+
@classmethod
74+
def available_scorers(cls) -> list[str]:
75+
"""Return the list of registered scorer names."""
76+
return list(Scorer.PREDEFINED.keys())
77+
3678
def __init__(
3779
self,
3880
dataset_name: str,
@@ -125,7 +167,7 @@ def score(self) -> tuple[float, int]:
125167
return np.mean(scores), n_repeats
126168

127169

128-
class PassAt1Scorer(Scorer):
170+
class PassAt1Scorer(Scorer, scorer_id="pass_at_1"):
129171
"""Implements pass@1 scoring as defined by Artificial Analysis.
130172
pass@1 means the model gets exactly one attempt to produce the correct answer.
131173
The score is 1 if the output matches the ground truth exactly, 0 otherwise.

0 commit comments

Comments
 (0)