Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/inference_endpoint/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,10 @@ def _add_online_specific_args(parser):
)
parser.add_argument(
"--concurrency",
type=int,
help="Max concurrent requests (required when --load-pattern=concurrency)",
type=str,
help="Max concurrent requests (required when --load-pattern=concurrency). "
"Can be a single value (e.g., '10') or comma-separated list (e.g., '10,20,30') "
"to run multiple benchmarks sequentially.",
)


Expand Down
152 changes: 142 additions & 10 deletions src/inference_endpoint/commands/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,14 +273,36 @@ def _build_config_from_cli(
Raises:
InputValidationError: If required params missing
"""
# Parse concurrency argument (can be single value or comma-separated list)
concurrency_value: int | list[int] | None = None
if concurrency_str := getattr(args, "concurrency", None):
if isinstance(concurrency_str, int):
concurrency_value = concurrency_str
elif "," in concurrency_str:
Comment on lines +279 to +281
Copy link

Copilot AI Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

--concurrency is defined as type=str in argparse, so concurrency_str will never be an int here. Dropping the dead isinstance(concurrency_str, int) branch will simplify the logic and avoid misleading future readers.

Suggested change
if isinstance(concurrency_str, int):
concurrency_value = concurrency_str
elif "," in concurrency_str:
if "," in concurrency_str:

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems valid?

# Parse comma-separated list
try:
concurrency_value = [int(c.strip()) for c in concurrency_str.split(",")]
except ValueError as e:
raise InputValidationError(
f"Invalid concurrency value '{concurrency_str}': all values must be integers"
) from e
else:
# Parse single integer
try:
concurrency_value = int(concurrency_str)
except ValueError as e:
raise InputValidationError(
f"Invalid concurrency value '{concurrency_str}': must be an integer"
) from e
Comment on lines +278 to +296
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for parsing the concurrency string has separate try-except blocks for handling a single value and a comma-separated list, which is a bit repetitive. This can be simplified by always splitting the string by commas and then checking the length of the resulting list. This would reduce code duplication and make the logic more concise.

Suggested change
if concurrency_str := getattr(args, "concurrency", None):
if "," in concurrency_str:
# Parse comma-separated list
try:
concurrency_value = [int(c.strip()) for c in concurrency_str.split(",")]
except ValueError as e:
raise InputValidationError(
f"Invalid concurrency value '{concurrency_str}': all values must be integers"
) from e
else:
# Parse single integer
try:
concurrency_value = int(concurrency_str)
except ValueError as e:
raise InputValidationError(
f"Invalid concurrency value '{concurrency_str}': must be an integer"
) from e
if concurrency_str := getattr(args, "concurrency", None):
try:
concurrencies = [int(c.strip()) for c in concurrency_str.split(",")]
if len(concurrencies) == 1:
concurrency_value = concurrencies[0]
else:
concurrency_value = concurrencies
except ValueError as e:
raise InputValidationError(
f"Invalid concurrency value '{concurrency_str}': all values must be integers"
) from e


# Determine load pattern (CLI override or mode default)
if load_pattern_arg := getattr(args, "load_pattern", None):
load_pattern_type = LoadPatternType(load_pattern_arg)
else:
match benchmark_mode:
case "offline":
load_pattern_type = LoadPatternType.MAX_THROUGHPUT
case "online" if getattr(args, "concurrency", None):
case "online" if concurrency_value is not None:
load_pattern_type = LoadPatternType.CONCURRENCY
case "online":
load_pattern_type = LoadPatternType.POISSON
Expand Down Expand Up @@ -309,7 +331,7 @@ def _build_config_from_cli(
load_pattern=LoadPattern(
type=load_pattern_type,
target_qps=getattr(args, "target_qps", None),
target_concurrency=getattr(args, "concurrency", None),
target_concurrency=concurrency_value,
),
runtime=RuntimeConfig(
min_duration_ms=args.duration * 1000
Expand Down Expand Up @@ -359,9 +381,90 @@ def _run_benchmark(
test_mode: TestMode,
benchmark_mode: TestType | None,
) -> None:
"""Execute the actual benchmark with full lifecycle management.
"""Execute benchmark(s) - either single run or multiple runs for different concurrency values.

This function handles the top-level orchestration:
1. If target_concurrency is a single value or not applicable, run one benchmark
2. If target_concurrency is a list, run multiple benchmarks sequentially
3. Each benchmark run gets its own subdirectory (concurrency_{value})
4. Resources are fully cleaned up between runs to ensure isolation

Args:
config: Validated BenchmarkConfig (immutable Pydantic model).
collect_responses: Whether to store full response text.
test_mode: What to collect - PERF, ACC, or BOTH.
benchmark_mode: Execution mode - OFFLINE or ONLINE.
"""
# Determine base report directory
if config.report_dir:
base_report_dir = Path(config.report_dir)
else:
base_report_dir = get_default_report_path()
Comment on lines +399 to +402
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The report_dir parameter, which can be controlled by a user, is used to construct file paths for benchmark reports without proper sanitization. This creates a path traversal vulnerability, allowing an attacker to write files outside of the intended directory. This could lead to overwriting arbitrary files, denial of service, or placing malicious files in sensitive locations.

Suggested change
if config.report_dir:
base_report_dir = Path(config.report_dir)
else:
base_report_dir = get_default_report_path()
if config.report_dir:
user_path = Path(config.report_dir)
# Resolve the user-provided path to an absolute path to prevent traversal attacks
base_report_dir = user_path.resolve()
# Define a safe base directory (e.g., current working directory) and ensure the path is within it
safe_base_dir = Path.cwd().resolve()
if not str(base_report_dir).startswith(str(safe_base_dir)):
raise InputValidationError(f"Invalid report_dir: '{config.report_dir}'. Path must be within '{safe_base_dir}'.")
else:
base_report_dir = get_default_report_path()


# Check if we need to run multiple benchmarks for different concurrency values
load_pattern = config.settings.load_pattern
target_concurrency = load_pattern.target_concurrency

# If concurrency mode with list of values, run multiple benchmarks
if load_pattern.type == LoadPatternType.CONCURRENCY and isinstance(
target_concurrency, list
):
logger.info(
f"Running {len(target_concurrency)} benchmarks with concurrency values: {target_concurrency}"
)

for i, concurrency_val in enumerate(target_concurrency):
logger.info(
f"\n{'='*80}\nBenchmark {i+1}/{len(target_concurrency)}: Concurrency = {concurrency_val}\n{'='*80}"
)

# Create subdirectory for this concurrency value
concurrency_report_dir = base_report_dir / f"concurrency_{concurrency_val}"
Comment on lines +421 to +422
Copy link

Copilot AI Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If target_concurrency contains duplicate values (e.g., [10, 10, 20]), runs will write into the same concurrency_{value} directory and overwrite/mix artifacts (including config.yaml). Consider either validating that the list is unique, or making the directory name unambiguous (e.g., prefix with the index like run_{i+1}_concurrency_{value}).

Suggested change
# Create subdirectory for this concurrency value
concurrency_report_dir = base_report_dir / f"concurrency_{concurrency_val}"
# Create subdirectory for this concurrency value (include run index to avoid collisions)
concurrency_report_dir = base_report_dir / f"run_{i+1}_concurrency_{concurrency_val}"

Copilot uses AI. Check for mistakes.
Comment on lines +421 to +422

try:
_run_single_benchmark(
config=config,
collect_responses=collect_responses,
test_mode=test_mode,
benchmark_mode=benchmark_mode,
report_dir=concurrency_report_dir,
concurrency_value=concurrency_val,
)
except Exception as e:
logger.error(f"Benchmark failed for concurrency={concurrency_val}: {e}")
raise

logger.info(
f"Completed benchmark {i+1}/{len(target_concurrency)} (concurrency={concurrency_val})"
)

logger.info(
f"\n{'='*80}\nAll benchmarks completed successfully!\nResults saved to: {base_report_dir}\n{'='*80}"
)

else:
# Single benchmark run
_run_single_benchmark(
config=config,
collect_responses=collect_responses,
test_mode=test_mode,
benchmark_mode=benchmark_mode,
report_dir=base_report_dir,
concurrency_value=None,
)


def _run_single_benchmark(
config: BenchmarkConfig,
collect_responses: bool,
test_mode: TestMode,
benchmark_mode: TestType | None,
report_dir: Path,
concurrency_value: int | None = None,
) -> None:
"""Execute a single benchmark run with full lifecycle management.

This function orchestrates a complete benchmark execution:
1. Load tokenizer for the target model
2. Load and validate dataset using DataLoaderFactory
3. Setup runtime settings and scheduler
Expand All @@ -370,6 +473,11 @@ def _run_benchmark(
6. Collect and report results
7. Clean up resources (always, even on error)

When called as part of multiple concurrency runs:
- Creates fresh worker processes for isolation
- Clears event hooks to prevent cross-contamination
- Uses dedicated subdirectory for this run's results

Architecture notes:
- This is a SYNCHRONOUS function (not async) because HTTPEndpointClient
manages its own event loop in a separate thread
Expand All @@ -382,7 +490,6 @@ def _run_benchmark(
- Disabled for offline mode (max throughput focus)

Args:
args: Command arguments containing output paths, verbosity, etc.
config: Validated BenchmarkConfig (immutable Pydantic model).
Contains all benchmark parameters from CLI or YAML.
collect_responses: Whether to store full response text.
Expand All @@ -391,6 +498,9 @@ def _run_benchmark(
or BOTH (metrics + responses).
benchmark_mode: Execution mode - OFFLINE (max throughput) or
ONLINE (sustained QPS). Affects streaming and scheduling.
report_dir: Directory to write reports and results for this run.
concurrency_value: If set, overrides config's target_concurrency for this run.
Used when iterating through multiple concurrency values.

Raises:
InputValidationError: If model/dataset cannot be loaded or validated.
Expand All @@ -411,12 +521,29 @@ def _run_benchmark(
model_name = config.submission_ref.model
config.model_params.name = model_name

if config.report_dir:
report_dir = Path(config.report_dir)
else:
report_dir = get_default_report_path()

# Ensure report directory exists
report_dir.mkdir(parents=True, exist_ok=True)

# If concurrency_value is provided, create a modified config with single concurrency
if concurrency_value is not None:
# Create a new config with the specific concurrency value
config = BenchmarkConfig(
**{
**config.model_dump(),
"settings": Settings(
**{
**config.settings.model_dump(),
"load_pattern": LoadPattern(
**{
**config.settings.load_pattern.model_dump(),
"target_concurrency": concurrency_value,
}
),
}
),
Comment on lines +530 to +543
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reconstructing nested Pydantic models via model_dump() + manual re-wrapping is hard to read and easy to get wrong as the schema evolves. Prefer using Pydantic’s copy/update utilities (e.g., model_copy(update=...) on BenchmarkConfig, Settings, and LoadPattern) to express “same config except target_concurrency=…” more directly and reduce the risk of losing defaults/aliases or introducing subtle serialization differences.

Suggested change
config = BenchmarkConfig(
**{
**config.model_dump(),
"settings": Settings(
**{
**config.settings.model_dump(),
"load_pattern": LoadPattern(
**{
**config.settings.load_pattern.model_dump(),
"target_concurrency": concurrency_value,
}
),
}
),
config = config.model_copy(
update={
"settings": config.settings.model_copy(
update={
"load_pattern": config.settings.load_pattern.model_copy(
update={"target_concurrency": concurrency_value}
)
}
)

Copilot uses AI. Check for mistakes.
}
Comment on lines +529 to +544
Copy link

Copilot AI Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This manual rebuild of BenchmarkConfig is quite verbose/error-prone. Since this is a Pydantic v2 model, it’s typically simpler (and less likely to miss fields) to use config.model_copy(update=..., deep=True) to override only settings.load_pattern.target_concurrency.

Suggested change
# Create a new config with the specific concurrency value
config = BenchmarkConfig(
**{
**config.model_dump(),
"settings": Settings(
**{
**config.settings.model_dump(),
"load_pattern": LoadPattern(
**{
**config.settings.load_pattern.model_dump(),
"target_concurrency": concurrency_value,
}
),
}
),
}
# Create a new config with the specific concurrency value by copying and updating
config = config.model_copy(
update={
"settings": config.settings.model_copy(
update={
"load_pattern": config.settings.load_pattern.model_copy(
update={"target_concurrency": concurrency_value},
deep=True,
)
},
deep=True,
)
},
deep=True,

Copilot uses AI. Check for mistakes.
)
Comment on lines +530 to +545
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The method used to update the BenchmarkConfig with a new concurrency_value is quite complex and hard to maintain. It involves deeply nested dictionary unpacking and model reconstruction.

Pydantic's model_copy(update=...) method provides a much cleaner and more robust way to create a modified copy of a frozen model. This would make the code more readable and less prone to errors if the model structure changes.

        config = config.model_copy(
            update={
                "settings": config.settings.model_copy(
                    update={
                        "load_pattern": config.settings.load_pattern.model_copy(
                            update={"target_concurrency": concurrency_value}
                        )
                    }
                )
            }
        )


config.to_yaml_file(report_dir / "config.yaml")

if model_name:
Expand Down Expand Up @@ -736,6 +863,11 @@ def signal_handler(signum, frame):
sample_issuer.shutdown()
http_client.shutdown()
shutil.rmtree(tmp_dir, ignore_errors=True)

# Clear event hooks to ensure clean state for next benchmark run
# This is critical when running multiple concurrency benchmarks
SampleEventHandler.clear_hooks()

except Exception as e:
if config.verbose:
logger.warning(f"Cleanup error: {e}")
48 changes: 43 additions & 5 deletions src/inference_endpoint/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from typing import ClassVar

import yaml
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator

from .. import metrics
from .ruleset_base import BenchmarkSuiteRuleset
Expand Down Expand Up @@ -266,13 +266,36 @@ class LoadPattern(BaseModel):
- max_throughput: target_qps used for calculating total queries (offline, optional with default)
- poisson: target_qps sets scheduler rate (online, required - validated)
- concurrency: issue at fixed target_concurrency (online, required - validated)

target_concurrency can be either:
- Single int: Run one benchmark with that concurrency level
- List of ints: Run multiple benchmarks sequentially, one per concurrency level
"""

type: LoadPatternType = LoadPatternType.MAX_THROUGHPUT
target_qps: float | None = (
None # Target QPS - required for poisson pattern, optional otherwise
)
target_concurrency: int | None = None # For concurrency mode, ignored otherwise
target_concurrency: int | list[int] | None = (
None # For concurrency mode, ignored otherwise
)

@field_validator("target_concurrency", mode="before")
@classmethod
def validate_target_concurrency(cls, v):
"""Validate target_concurrency accepts int or list of ints."""
if v is None:
return v
# Accept single int
if isinstance(v, int):
return v
# Accept list of ints
if isinstance(v, list):
return v
# Try to convert if it's something else (shouldn't happen with proper YAML)
Copy link

Copilot AI Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment says the validator will 'Try to convert' other types, but the implementation immediately raises. Either remove/update the comment, or implement the described conversion (e.g., accepting numeric strings) so the code and documentation match.

Suggested change
# Try to convert if it's something else (shouldn't happen with proper YAML)
# Reject any other types (this should not happen with proper YAML)

Copilot uses AI. Check for mistakes.
raise ValueError(
f"target_concurrency must be an integer or list of integers, got {type(v)}"
)
Comment on lines +283 to +298
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This validator’s docstring and comments claim it validates “list of ints” and “tries to convert”, but it currently returns any list unchanged (including list[str], list[float], etc.) and performs no conversion. Either (a) actually validate/convert list elements here (e.g., ensure each element is an int and/or coerce string ints if that’s desired), or (b) remove this validator and rely on the later validate_load_pattern logic (to avoid duplicated/partial validation paths).

Copilot uses AI. Check for mistakes.


class ClientSettings(BaseModel):
Expand Down Expand Up @@ -524,10 +547,25 @@ def validate_load_pattern(self, benchmark_mode: TestType) -> None:
)
elif load_pattern_type == LoadPatternType.CONCURRENCY:
# Concurrency pattern requires target_concurrency > 0
if not target_concurrency or target_concurrency <= 0:
# Can be single int or list of ints
if target_concurrency is None:
raise ValueError(
"Concurrency load pattern requires target_concurrency to be specified. "
"Specify number of concurrent requests (e.g., target_concurrency: 10 or target_concurrency: [10, 20, 30] in YAML or --concurrency 10 in CLI)"
)

# Validate single int or list of ints
if isinstance(target_concurrency, list):
if len(target_concurrency) == 0:
raise ValueError("target_concurrency list cannot be empty")
for i, conc in enumerate(target_concurrency):
if not isinstance(conc, int) or conc <= 0:
raise ValueError(
f"target_concurrency[{i}] must be a positive integer, got {conc}"
)
elif not isinstance(target_concurrency, int) or target_concurrency <= 0:
raise ValueError(
"Concurrency load pattern requires target_concurrency > 0. "
"Specify number of concurrent requests (e.g., target_concurrency: 10 under load_pattern in YAML or --concurrency 10 in CLI)"
f"target_concurrency must be a positive integer or list of positive integers, got {target_concurrency}"
)

def validate_client_settings(self) -> None:
Expand Down
6 changes: 5 additions & 1 deletion src/inference_endpoint/load_generator/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,11 @@ def __init__(self, runtime_settings: RuntimeSettings, sample_order_cls):
super().__init__(runtime_settings, sample_order_cls)
assert runtime_settings.load_pattern is not None
target_concurrency = runtime_settings.load_pattern.target_concurrency
if target_concurrency is None or target_concurrency <= 0:
if (
target_concurrency is None
or (isinstance(target_concurrency, int) and target_concurrency <= 0)
or (isinstance(target_concurrency, list) and len(target_concurrency) == 0)
):
raise ValueError(
f"target_concurrency must be > 0 for CONCURRENCY load pattern, got {target_concurrency}"
)
Comment on lines +382 to 389
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message is now misleading for the list case: an empty list fails the check but the message says “must be > 0”. Update the message to reflect the accepted shapes (positive integer or non-empty list of positive integers), or split the exception messages by failure mode (e.g., None vs empty list vs non-positive int) to make debugging easier.

Copilot uses AI. Check for mistakes.
Comment on lines +382 to 389
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The validation for target_concurrency in ConcurrencyScheduler is incorrect. It allows target_concurrency to be a list, but the rest of the scheduler's logic expects it to be an integer. Specifically, the comparison self._inflight >= self._target_concurrency on line 421 will raise a TypeError if self._target_concurrency is a list.

Given the orchestration logic in _run_single_benchmark, this scheduler should only ever receive a single integer for target_concurrency. The validation should enforce this to prevent this runtime error.

        if not isinstance(target_concurrency, int) or target_concurrency <= 0:
            raise ValueError(
                f"target_concurrency must be a positive integer for CONCURRENCY load pattern, got {target_concurrency}"
            )

Expand Down
Loading