Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@ repos:
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
hooks:
- id: mypy
args: [--config-file=./pyproject.toml, src, tests]
pass_filenames: false
verbose: true
entry: bash -c 'mypy "$@" || true' --
additional_dependencies:
- types-PyYAML==6.0.12
- types-requests>=2.32.4

- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.1.0
hooks:
Expand Down
14 changes: 14 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,17 @@ exclude_lines = [
"class .*\\\\bProtocol\\):",
"@(abc\\\\.)?abstractmethod",
]

[tool.mypy]
ignore_missing_imports = true
exclude = [
"openai_types_gen\\.py$",
"metrics/reporter\\.py$",
]

[[tool.mypy.overrides]]
module = [
"inference_endpoint.openai.openai_types_gen",
"inference_endpoint.metrics.reporter",
]
ignore_errors = true
47 changes: 32 additions & 15 deletions src/inference_endpoint/commands/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from urllib.parse import urljoin

from tqdm import tqdm
Expand All @@ -40,7 +41,6 @@
APIType,
BenchmarkConfig,
ClientSettings,
Dataset,
DatasetType,
EndpointConfig,
LoadPattern,
Expand All @@ -54,8 +54,12 @@
TestMode,
TestType,
)
from inference_endpoint.config.schema import (
Dataset as DatasetConfig,
)
from inference_endpoint.config.yaml_loader import ConfigError, ConfigLoader
from inference_endpoint.core.types import QueryResult
from inference_endpoint.dataset_manager.dataset import Dataset
from inference_endpoint.dataset_manager.factory import DataLoaderFactory
from inference_endpoint.endpoint_client.config import HTTPClientConfig
from inference_endpoint.endpoint_client.cpu_affinity import pin_loadgen
Expand Down Expand Up @@ -134,7 +138,7 @@ def on_complete_hook(self, result: QueryResult):
if self.pbar:
self.pbar.set_postfix(refresh=True, errors=len(self.errors))
elif self.collect_responses:
self.responses[result.id] = result.response_output
self.responses[result.id] = result.get_response_output_string()

if self.pbar:
self.pbar.update(1)
Expand Down Expand Up @@ -227,9 +231,7 @@ async def run_benchmark_command(args: argparse.Namespace) -> None:
elif benchmark_mode_str in ("offline", "online"):
# ===== CLI MODE - Build config from CLI params =====
benchmark_mode = TestType(benchmark_mode_str) # TestType values are lowercase
effective_config: BenchmarkConfig = _build_config_from_cli(
args, benchmark_mode_str
)
effective_config = _build_config_from_cli(args, benchmark_mode_str)
test_mode = (
TestMode(args.mode) if getattr(args, "mode", None) else TestMode.PERF
)
Expand Down Expand Up @@ -294,7 +296,7 @@ def _build_config_from_cli(
version="1.0",
type=TestType.OFFLINE if benchmark_mode == "offline" else TestType.ONLINE,
datasets=[
Dataset(
DatasetConfig(
name=args.dataset.stem,
type=DatasetType.PERFORMANCE,
path=str(args.dataset),
Expand Down Expand Up @@ -343,7 +345,6 @@ def _build_config_from_cli(
api_type=api_type,
),
metrics=Metrics(),
baseline=None, # CLI mode doesn't use baseline
report_dir=report_dir,
timeout=timeout,
verbose=verbose_level > 0,
Expand Down Expand Up @@ -464,6 +465,21 @@ def _run_benchmark(
if len(accuracy_configs) > 0:
# Pack the evaluation parameters for each accuracy dataset
for acc_config in accuracy_configs:
# Type narrowing: ensure accuracy_config is not None
assert (
acc_config.accuracy_config is not None
), f"accuracy_config must be set for dataset {acc_config.name}"
# Type narrowing: ensure required fields are not None
assert (
acc_config.accuracy_config.eval_method is not None
), f"eval_method must be set for dataset {acc_config.name}"
assert (
acc_config.accuracy_config.extractor is not None
), f"extractor must be set for dataset {acc_config.name}"
assert (
acc_config.accuracy_config.ground_truth is not None
), f"ground_truth must be set for dataset {acc_config.name}"

dataset = DataLoaderFactory.create_loader(
acc_config, num_repeats=acc_config.accuracy_config.num_repeats
)
Expand All @@ -475,7 +491,7 @@ def _run_benchmark(
Extractor.get(acc_config.accuracy_config.extractor),
acc_config.name,
dataset,
config.report_dir,
report_dir,
acc_config.accuracy_config.ground_truth,
acc_config.accuracy_config.num_repeats,
)
Expand Down Expand Up @@ -554,18 +570,18 @@ def _run_benchmark(

# Create endpoint client
endpoints = config.endpoint_config.endpoints
assert endpoints is not None
num_workers = config.settings.client.workers

logger.info(f"Connecting: {endpoints}")
tmp_dir = tempfile.mkdtemp(prefix="inference_endpoint_")

try:
api_type: APIType = config.endpoint_config.api_type
assert api_type is not None
http_config = HTTPClientConfig(
endpoint_urls=[
urljoin(e, config.endpoint_config.api_type.default_route())
for e in endpoints
],
api_type=config.endpoint_config.api_type,
endpoint_urls=[urljoin(e, api_type.default_route()) for e in endpoints],
api_type=api_type,
num_workers=num_workers,
record_worker_events=config.settings.client.record_worker_events,
event_logs_dir=report_dir,
Expand Down Expand Up @@ -595,7 +611,7 @@ def _run_benchmark(
report_dir=report_dir,
tokenizer_override=tokenizer,
accuracy_datasets=accuracy_datasets,
max_shutdown_timeout_s=config.timeout if config.timeout else None,
max_shutdown_timeout_s=config.timeout,
dump_events_log=True,
)

Expand All @@ -622,6 +638,7 @@ def signal_handler(signum, frame):
ground_truth_column=eval_config.ground_truth_column,
)
score, n_repeats = scorer_instance.score()
assert eval_config.dataset.data is not None
accuracy_scores[eval_config.dataset_name] = {
"dataset_name": eval_config.dataset_name,
"num_samples": len(eval_config.dataset.data),
Expand Down Expand Up @@ -665,7 +682,7 @@ def signal_handler(signum, frame):
logger.warning(f" ... +{len(response_collector.errors) - 3} more")

try:
results = {
results: dict[str, Any] = {
"config": {
"endpoint": endpoints,
"mode": test_mode,
Expand Down
14 changes: 10 additions & 4 deletions src/inference_endpoint/commands/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ async def run_probe_command(args: argparse.Namespace) -> None:
):
try:
# Schedule receive on client's event loop and await the result
assert client.loop is not None, "Client loop should be initialized"
future = asyncio.run_coroutine_threadsafe(client.recv(), client.loop)
result = await asyncio.wrap_future(future)

Expand Down Expand Up @@ -163,21 +164,26 @@ async def run_probe_command(args: argparse.Namespace) -> None:
continue
latency_ms = (time.time() - start_times[query_id]) * 1000

# Normalize response_output for logging
response_output = result.get_response_output_string()
if response_output is None:
response_output = ""

if result.error:
errors.append(f"{query_id}: {result.error}")
else:
latencies.append(latency_ms)
responses.append((query_id, result.response_output))
responses.append(
(query_id, response_output if response_output else "<EMPTY>")
)

# Simple progress indicator
if (
len(received_ids) % max(1, num_expected // 10) == 0
or len(received_ids) == num_expected
):
output_preview = (
result.response_output[:100]
if result.response_output
else "(no output)"
response_output[:100] if response_output else "(no output)"
)
logger.info(
f" Processed {len(received_ids)}/{num_expected} responses : {query_id} : {output_preview}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,4 @@ def _disallow_instantiation(cls, *args, **kwargs):
)


_Dataset.__new__ = _disallow_instantiation
_Dataset.__new__ = _disallow_instantiation # type: ignore[method-assign]
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,4 @@ def _disallow_instantiation(cls, *args, **kwargs):
)


_Model.__new__ = _disallow_instantiation
_Model.__new__ = _disallow_instantiation # type: ignore[method-assign]
35 changes: 23 additions & 12 deletions src/inference_endpoint/config/rulesets/mlcommons/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,12 @@ class PerModelRuleset:
min_duration_ms_valid: int = (
10 * 60 * 1000
) # Minimum duration in milliseconds required for a valid run
max_duration_ms_valid: int = None # Maximum duration in milliseconds. Used as a timeout / kill for a benchmark run. Set to None for no timeout.
min_sample_count_valid: int = None # Minimum number of samples required to be sent to the SUT for a valid run, if None, no minimum is enforced
max_duration_ms_valid: int | None = (
None # Maximum duration in milliseconds. Used as a timeout / kill for a benchmark run. Set to None for no timeout.
)
min_sample_count_valid: int | None = (
None # Minimum number of samples required to be sent to the SUT for a valid run, if None, no minimum is enforced
)
metric: type[metrics.Metric] | None = (
None # any subclass of Metric. Used as metric to evaluate the performance of the benchmark.
)
Expand All @@ -55,7 +59,9 @@ class PerQueryRuleset(PerModelRuleset):
target_latency_percentile: float = (
99.0 # Percentile of per-query latencies to use for metric comparison
)
max_latency_threshold_ms: int = None # Maximum latency threshold in milliseconds for the specified percentile latency allowed for a valid run.
max_latency_threshold_ms: int | None = (
None # Maximum latency threshold in milliseconds for the specified percentile latency allowed for a valid run.
)
reported_metrics: list[type[metrics.Metric]] = field(
default_factory=lambda: [metrics.Throughput]
)
Expand All @@ -65,10 +71,10 @@ class PerQueryRuleset(PerModelRuleset):
class TokenBasedRuleset(PerModelRuleset):
min_sample_count_valid: int = 270336
metric: type[metrics.Metric] = metrics.Throughput
max_ttft_latency_ms: int = (
max_ttft_latency_ms: int | None = (
None # Maximum TTFT latency in milliseconds allowed for a valid run
)
max_tpot_latency_ms: int = (
max_tpot_latency_ms: int | None = (
None # Maximum TPoT latency in milliseconds allowed for a valid run
)
reported_metrics: list[type[metrics.Metric]] = field(
Expand Down Expand Up @@ -148,24 +154,27 @@ def apply_user_config(

ruleset = self.benchmark_rulesets[model][opt_prio]

metric_target = None
if user_config.user_metric_target:
metric_target: metrics.Metric | None = None
if user_config.user_metric_target and ruleset.metric:
metric_target = ruleset.metric(user_config.user_metric_target)

reported_metrics = []
reported_metrics: list[metrics.Metric] = []
for mtype in ruleset.reported_metrics:
if mtype == ruleset.metric:
if mtype == ruleset.metric and metric_target:
reported_metrics.append(metric_target)
elif mtype == metrics.TTFT:
assert isinstance(ruleset, TokenBasedRuleset)
reported_metrics.append(metrics.TTFT(ruleset.max_ttft_latency_ms))
elif mtype == metrics.TPOT:
assert isinstance(ruleset, TokenBasedRuleset)
reported_metrics.append(metrics.TPOT(ruleset.max_tpot_latency_ms))
elif mtype == metrics.QueryLatency and ruleset.metric == metrics.Throughput:
# If we specify throughput and want to also report per query latency, infer latency from inverting qps.
reported_metrics.append(
metrics.QueryLatency(target_qps=user_config.user_metric_target)
)
elif mtype == metrics.Throughput and ruleset.metric == metrics.QueryLatency:
assert user_config.user_metric_target is not None
# If we specify per query latency, infer qps by inverting
target_qps = 1000 / user_config.user_metric_target
reported_metrics.append(metrics.Throughput(target_qps=target_qps))
Expand All @@ -182,7 +191,7 @@ def apply_user_config(
if user_config.max_duration_ms is not None:
max_duration_ms = user_config.max_duration_ms
assert (
max_duration_ms >= min_duration_ms
max_duration_ms is not None and max_duration_ms >= min_duration_ms
), "Max duration must be greater than or equal to min duration"

n_samples_from_dataset = model.dataset.size
Expand All @@ -198,13 +207,15 @@ def apply_user_config(
min_sample_count = user_config.min_sample_count

return _RuntimeSettings(
metric_target=metric_target,
metric_target=metric_target
if metric_target is not None
else metrics.Throughput(0.0),
reported_metrics=reported_metrics,
min_duration_ms=min_duration_ms,
max_duration_ms=max_duration_ms,
n_samples_from_dataset=n_samples_from_dataset,
n_samples_to_issue=total_sample_count,
min_sample_count=min_sample_count,
min_sample_count=min_sample_count if min_sample_count is not None else 1,
rng_sched=random.Random(self.scheduler_rng_seed),
rng_sample_index=random.Random(self.sample_index_rng_seed),
load_pattern=None, # not part of user config
Expand Down
2 changes: 1 addition & 1 deletion src/inference_endpoint/config/runtime_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def _from_config_default(
# Apply overrides
kwargs.update(overrides)

return cls(**kwargs)
return cls(**kwargs) # type: ignore[arg-type]

def total_samples_to_issue(
self, padding_factor: float = 1.1, align_to_dataset_size: bool = True
Expand Down
18 changes: 12 additions & 6 deletions src/inference_endpoint/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,16 +198,22 @@ class Dataset(BaseModel):
format: str | None = None
samples: int | None = None
eval_method: EvalMethod | None = None
parser: dict | None = None
parser: dict[str, str] | None = None
accuracy_config: AccuracyConfig | None = None


class AccuracyConfig(BaseModel):
"""Accuracy configuration.
The eval_method is the method to use to evaluate the accuracy of the model. Currently only "pass_at_1" is supported.
The ground_truth is the column in the dataset that contains the ground truth. Defaults to "ground_truth" if not specified.
The extractor is the extractor to use to extract the ground truth from the output. Currently "boxed_math_extractor" and "abcd_extractor" are supported.
The num_repeats is the number of times to repeat the dataset for evaluation. Defaults to 1 if not specified.

The eval_method is the method to use to evaluate the accuracy of the model.
Currently only "pass_at_1" is supported.
The ground_truth is the column in the dataset that contains the ground truth.
Defaults to "ground_truth" if not specified.
The extractor is the extractor to use to extract the answer from the model output (for comparison against the ground truth).
Currently "boxed_math_extractor" and "abcd_extractor" are supported.
The num_repeats is the number of times to repeat the dataset for evaluation.
Defaults to 1 if not specified.

Example:
accuracy_config:
eval_method: "pass_at_1"
Expand Down Expand Up @@ -373,7 +379,7 @@ class BenchmarkConfig(BaseModel):
# workers are assigned endpoints in a round-robin manner
endpoint_config: EndpointConfig = Field(default_factory=EndpointConfig)
report_dir: Path | None = None
timeout: int | None = None
timeout: float = 300.0 # default timeout of 300 seconds
verbose: bool = False
# CPU affinity for loadgen and worker processes:
# - True = auto (compute optimal NUMA-aware plan)
Expand Down
Loading
Loading