|
21 | 21 | import argparse |
22 | 22 | import json |
23 | 23 | import logging |
| 24 | +import os |
24 | 25 | import shutil |
25 | 26 | import signal |
26 | 27 | import tempfile |
27 | 28 | import uuid |
| 29 | +from dataclasses import dataclass |
28 | 30 | from pathlib import Path |
29 | 31 | from urllib.parse import urljoin |
30 | 32 |
|
|
63 | 65 | from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient |
64 | 66 | from inference_endpoint.endpoint_client.http_sample_issuer import HttpClientSampleIssuer |
65 | 67 | from inference_endpoint.evaluation import Extractor |
66 | | -from inference_endpoint.evaluation.scoring import PassAt1Scorer |
| 68 | +from inference_endpoint.evaluation.scoring import Scorer |
67 | 69 | from inference_endpoint.exceptions import ( |
68 | 70 | ExecutionError, |
69 | 71 | InputValidationError, |
@@ -141,6 +143,17 @@ def on_complete_hook(self, result: QueryResult): |
141 | 143 | self.pbar.update(1) |
142 | 144 |
|
143 | 145 |
|
@dataclass
class AccuracyConfiguration:
    """Bundle of everything needed to score one accuracy dataset after a run.

    Instances are built per-dataset in `_run_benchmark` (constructed
    positionally, so field order here is part of the interface) and consumed
    in the post-run scoring loop, which instantiates `scorer` with
    (dataset_name, dataset, report_dir, extractor=..., ground_truth_column=...).
    """

    # Scorer class (not instance) resolved via Scorer.get(eval_method);
    # instantiated later in the scoring loop.
    scorer: Scorer
    # Extractor class resolved via Extractor.get(...); passed through to the scorer.
    extractor: Extractor
    # Human-readable dataset identifier; also used as the key in the
    # accuracy_scores report dict.
    dataset_name: str
    # Loaded data loader for this dataset (created by DataLoaderFactory).
    dataset: Dataset
    # Directory where the scorer writes its report artifacts.
    report_dir: os.PathLike
    # Column in the dataset containing the ground-truth answers.
    ground_truth_column: str
    # Number of times each sample is repeated (forwarded to the loader;
    # also reported as n_repeats).
    num_repeats: int
| 156 | + |
144 | 157 | async def run_benchmark_command(args: argparse.Namespace) -> None: |
145 | 158 | """Run performance benchmark in offline, online, or YAML-configured mode. |
146 | 159 |
|
@@ -336,49 +349,6 @@ def _build_config_from_cli( |
336 | 349 | ) |
337 | 350 |
|
338 | 351 |
|
339 | | -def _get_dataset_path(args: argparse.Namespace, config: BenchmarkConfig) -> Path: |
340 | | - """Get dataset path from CLI args or config. |
341 | | -
|
342 | | - CURRENT LIMITATION: Only supports single dataset execution. |
343 | | - Priority: CLI args > config datasets[0] |
344 | | -
|
345 | | - Args: |
346 | | - args: Command arguments |
347 | | - config: BenchmarkConfig |
348 | | -
|
349 | | - Returns: |
350 | | - Path to dataset file |
351 | | -
|
352 | | - Raises: |
353 | | - InputValidationError: If no dataset specified or file doesn't exist |
354 | | -
|
355 | | - TODO: Multi-dataset support |
356 | | - When implemented, this should: |
357 | | - 1. Return list[Path] for multiple datasets |
358 | | - 2. Validate all dataset paths exist |
359 | | - 3. Support dataset interleaving strategies |
360 | | - """ |
361 | | - if hasattr(args, "dataset") and args.dataset: |
362 | | - dataset_path = Path(args.dataset) |
363 | | - else: |
364 | | - # TODO: Multi-dataset - currently just picks single dataset |
365 | | - single_dataset = config.get_single_dataset() |
366 | | - if single_dataset: |
367 | | - dataset_path = Path(single_dataset.path) |
368 | | - else: |
369 | | - logger.error("Dataset required: --dataset PATH or specify in config") |
370 | | - raise InputValidationError( |
371 | | - "Dataset required: --dataset PATH or specify in config" |
372 | | - ) |
373 | | - |
374 | | - # Validate file exists |
375 | | - if not dataset_path.exists(): |
376 | | - logger.error(f"Dataset not found: {dataset_path}") |
377 | | - raise InputValidationError(f"Dataset not found: {dataset_path}") |
378 | | - |
379 | | - return dataset_path |
380 | | - |
381 | | - |
382 | 352 | def _run_benchmark( |
383 | 353 | config: BenchmarkConfig, |
384 | 354 | collect_responses: bool, |
@@ -498,33 +468,31 @@ def _run_benchmark( |
498 | 468 | "top_k": config.model_params.top_k, |
499 | 469 | "repetition_penalty": config.model_params.repetition_penalty, |
500 | 470 | } |
501 | | - accuracy_datasets = [ |
502 | | - DataLoaderFactory.create_loader(dataset, metadata=metadata) |
503 | | - for dataset in accuracy_configs |
504 | | - ] |
505 | 471 |
|
506 | 472 | # Pack the evaluation parameters for each accuracy dataset |
507 | | - for i in range(len(accuracy_configs)): |
508 | | - dataset = accuracy_configs[i] |
509 | | - extractor = Extractor.get(dataset.accuracy_config.extractor) |
510 | | - ground_truth_column = dataset.accuracy_config.ground_truth |
511 | | - scorer = PassAt1Scorer # currently only PassAt1Scorer is supported |
512 | | - # TODO add support for other scorers |
| 473 | + for acc_config in accuracy_configs: |
| 474 | + extractor = Extractor.get(acc_config.accuracy_config.extractor) |
| 475 | + ground_truth_column = acc_config.accuracy_config.ground_truth |
| 476 | + scorer = Scorer.get(acc_config.accuracy_config.eval_method) |
| 477 | + num_repeats = acc_config.accuracy_config.num_repeats |
| 478 | + dataset = DataLoaderFactory.create_loader( |
| 479 | + acc_config, metadata=metadata, num_repeats=num_repeats |
| 480 | + ) |
| 481 | + accuracy_datasets.append(dataset) |
513 | 482 | # TODO add tests and defaults |
514 | 483 | eval_configs.append( |
515 | | - ( |
| 484 | + AccuracyConfiguration( |
516 | 485 | scorer, |
517 | 486 | extractor, |
518 | | - dataset.name, |
519 | | - accuracy_datasets[i], |
| 487 | + acc_config.name, |
| 488 | + dataset, |
520 | 489 | config.report_dir, |
521 | 490 | ground_truth_column, |
| 491 | + num_repeats, |
522 | 492 | ) |
523 | 493 | ) |
524 | | - accuracy_datasets[i].load() |
525 | | - logger.info( |
526 | | - f"Loaded {accuracy_datasets[i]} - {accuracy_datasets[i].num_samples()} samples" |
527 | | - ) |
| 494 | + dataset.load() |
| 495 | + logger.info(f"Loaded {dataset} - {dataset.num_samples()} samples") |
528 | 496 |
|
529 | 497 | else: |
530 | 498 | logger.info("No accuracy datasets provided") |
@@ -659,31 +627,26 @@ def signal_handler(signum, frame): |
659 | 627 | # Always restore original handler |
660 | 628 | signal.signal(signal.SIGINT, old_handler) |
661 | 629 | accuracy_scores = {} |
662 | | - for ( |
663 | | - scorer, |
664 | | - extractor, |
665 | | - dataset_id, |
666 | | - dataset, |
667 | | - report_dir, |
668 | | - ground_truth_column, |
669 | | - ) in eval_configs: |
670 | | - scorer_instance = scorer( |
671 | | - dataset_id, |
672 | | - dataset, |
673 | | - report_dir, |
674 | | - extractor=extractor, |
675 | | - ground_truth_column=ground_truth_column, |
| 630 | + for eval_config in eval_configs: |
| 631 | + scorer_instance = eval_config.scorer( |
| 632 | + eval_config.dataset_name, |
| 633 | + eval_config.dataset, |
| 634 | + eval_config.report_dir, |
| 635 | + extractor=eval_config.extractor, |
| 636 | + ground_truth_column=eval_config.ground_truth_column, |
676 | 637 | ) |
677 | 638 | score, n_repeats = scorer_instance.score() |
678 | | - accuracy_scores[dataset_id] = { |
679 | | - "dataset_id": dataset_id, |
680 | | - "num_samples": len(dataset.data), |
681 | | - "extractor": extractor.__name__, |
682 | | - "ground_truth_column": ground_truth_column, |
| 639 | + accuracy_scores[eval_config.dataset_name] = { |
| 640 | + "dataset_name": eval_config.dataset_name, |
| 641 | + "num_samples": len(eval_config.dataset.data), |
| 642 | + "extractor": eval_config.extractor.__name__, |
| 643 | + "ground_truth_column": eval_config.ground_truth_column, |
683 | 644 | "score": score, |
684 | 645 | "n_repeats": n_repeats, |
685 | 646 | } |
686 | | - logger.info(f"Score for {dataset_id}: {score} ({n_repeats} repeats)") |
| 647 | + logger.info( |
| 648 | + f"Score for {eval_config.dataset_name}: {score} ({n_repeats} repeats)" |
| 649 | + ) |
687 | 650 |
|
688 | 651 | # Prefer authoritative metrics from the session report |
689 | 652 | report = getattr(sess, "report", None) |
|
0 commit comments