diff --git a/docs/load_generator/design.md b/docs/load_generator/design.md new file mode 100644 index 00000000..77216a6d --- /dev/null +++ b/docs/load_generator/design.md @@ -0,0 +1,1086 @@ +# Async Load Generator Design + +## Overview + +The load generator is the central scheduling component that controls _when_ and _how_ +samples are issued to inference endpoints during benchmarking. It is fully async with a +single-thread, single-event-loop-per-process constraint. + +## Architecture + +A `BenchmarkSession` runs one or more **phases** sequentially. Each phase has its own +`RuntimeSettings`, `Dataset`, and `LoadStrategy`. Phases are categorized as either +**tracked** (produces a performance metrics report) or **untracked** (performance is not evaluated). + +Multiple performance phases allow testing different configurations (QPS targets, +concurrency levels, datasets) against the same server instance within a single session, +each producing an independent report. + +``` +BenchmarkSession.run(phases) + | + +-- STARTED + +-- [saturation] strategy.execute() → NO drain (keep in-flight saturated) + +-- [perf phase 1] START_PERFORMANCE_TRACKING → strategy.execute() → drain → STOP_PERFORMANCE_TRACKING → snapshot report + +-- [saturation] strategy.execute() → drain + +-- [perf phase 2] START_PERFORMANCE_TRACKING → strategy.execute() → drain → STOP_PERFORMANCE_TRACKING → snapshot report + +-- [accuracy x N] strategy.execute() → drain (uuid maps collected) + +-- ENDED + | + +-- return SessionResult { perf_results: [PhaseResult, ...], accuracy_results: [...] } +``` + +Each performance phase is bracketed by `START_PERFORMANCE_TRACKING` / +`STOP_PERFORMANCE_TRACKING` events, which the `MetricsAggregator` uses to +scope its tracked counters and duration. At the end of each perf phase, +metrics are snapshotted from the KVStoreReader and a `Report` is built. + +> **TODO:** The current `MetricsAggregator` does not support per-phase scoping. 
+> It maintains a single set of counters and series across all tracking windows. +> To support multiple perf phases with independent reports, the aggregator will +> need either: (a) a `RESET_METRICS` event that clears counters/series between +> phases, or (b) per-phase metric namespacing (e.g., prefix keys with phase name), +> or (c) the report builder computes deltas by snapshotting before and after each +> phase. This will be addressed in a future change to the `MetricsAggregator`. +> Option (b) is the most-likely planned change as it is the most robust. + +Saturation phases exist to bring the endpoint to steady-state before a +performance measurement. In-flight requests are **not drained** at the end +of a saturation phase — the next phase starts immediately with concurrency +already at the target level. Common uses: + +- Fill KV caches so perf phase measures warm inference, not cold start +- Ramp concurrency to target level before measuring at that level +- Warm connection pools and OS TCP buffers + +### Load Strategies + +Three load patterns, three implementations — each uses the optimal async primitive +for its scheduling semantics, validated by benchmarking: + +| LoadPatternType | Strategy | Mechanism | Best At | +| ----------------- | --------------------- | ---------------------------- | ------------------- | +| POISSON | `TimedIssueStrategy` | `loop.call_at` (default) | ≤50k QPS | +| POISSON (precise) | `TimedIssueStrategy` | `run_in_executor(busy_wait)` | Sub-100μs precision | +| MAX_THROUGHPUT | `BurstStrategy` | `loop.call_at` / sync batch | Max fire rate | +| CONCURRENCY | `ConcurrencyStrategy` | `asyncio.Semaphore` | Fixed concurrency | + +**Default for Poisson is `loop.call_at`:** Sub-millisecond timing precision (600–700μs) +with zero GIL contention and low response latency (0.6–1.4ms). No thread pool overhead. +Degrades above 100k+ QPS where the callback queue saturates. 
+ +`run_in_executor(busy_wait)` is available as an opt-in for workloads requiring sub-100μs +timing precision. It achieves 65–92μs but introduces GIL contention that adds 6ms +response latency at low QPS (<1k). At mid-range QPS (5k–50k), latency is comparable. + +### Optional: Separate Timer Process + +For workloads requiring both precise timing AND minimal response latency (e.g., edge +inference with tight TPOT budgets), the timer can run in a dedicated process: + +``` +Timer Process (dedicated): + - Owns a tight busy-wait loop, no GIL contention + - Sends (sample_index: int) via ZMQ PUSH at precise times + +Main Process: + - Receives indices via ZMQ PULL + - Loads data, builds Query, issues via HTTPEndpointClient + - Runs receiver coroutine — event loop is never blocked +``` + +This eliminates the GIL contention that causes `run_in_executor` to add approximately +6ms response latency at low QPS. However, it adds ZMQ IPC latency (10–50μs) to timing +precision. + +**Not suitable for ConcurrencyStrategy**: the timer process has no visibility into +completion events, so it cannot gate on in-flight count. Concurrency mode always runs +in-process. + +--- + +## Components + +### BenchmarkSession + +**File:** `src/inference_endpoint/load_generator/session.py` + +Async orchestrator. Runs phases sequentially on the shared event loop. 
+ +```python +class PhaseType(str, Enum): + """Phase types control tracking and reporting behavior.""" + PERFORMANCE = "performance" # Tracked, produces a report + ACCURACY = "accuracy" # Untracked, for eval scoring + SATURATION = "saturation" # Untracked, ramp up concurrency before perf phase + + +@dataclass(frozen=True, slots=True) +class PhaseConfig: + """Configuration for a single benchmark phase.""" + name: str + runtime_settings: RuntimeSettings + dataset: Dataset + phase_type: PhaseType = PhaseType.PERFORMANCE + + +class BenchmarkSession: + def __init__( + self, + issuer: SampleIssuer, + event_publisher: EventRecordPublisher, + on_sample_complete: Callable[[QueryResult], None] | None = None, + ): ... + + async def run(self, phases: list[PhaseConfig]) -> SessionResult: ... + def stop(self) -> None: ... +``` + +**`run(phases)`** lifecycle: + +1. Publish `SessionEventType.STARTED` +2. Start receiver coroutine (`_receive_responses`) +3. For each phase: + a. Create `SampleOrder` and `LoadStrategy` from phase settings + b. Set `self._current_dataset` to phase dataset + c. **SATURATION**: execute strategy, **do not drain** in-flight. No tracking + events, no report. Purpose: bring endpoint to steady-state concurrency + (e.g., fill KV caches, warm up connection pools). The next phase starts + immediately with concurrency already at the target level. + d. **PERFORMANCE**: publish `START_PERFORMANCE_TRACKING`, execute strategy, + drain in-flight, publish `STOP_PERFORMANCE_TRACKING`. Snapshot metrics + from KVStoreReader → build `PhaseResult`. + e. **ACCURACY**: execute strategy, drain in-flight. No tracking events. + UUID map collected for eval scoring. +4. Publish `SessionEventType.ENDED` +5. Return `SessionResult` (contains `PhaseResult` per perf phase + accuracy maps) + +**Saturation phases** are particularly important for concurrency-based benchmarks. 
+A common pattern: + +```python +phases = [ + # Ramp up to target concurrency, fill endpoint caches + PhaseConfig("warmup", warmup_settings, dataset, PhaseType.SATURATION), + # Measured performance run + PhaseConfig("perf", perf_settings, dataset, PhaseType.PERFORMANCE), + # Accuracy eval (uses same warmed endpoint) + PhaseConfig("accuracy", acc_settings, acc_dataset, PhaseType.ACCURACY), +] +``` + +Or multiple performance sweeps with saturation between each: + +```python +phases = [ + PhaseConfig("saturate_c32", sat_32, dataset, PhaseType.SATURATION), + PhaseConfig("perf_c32", perf_32, dataset, PhaseType.PERFORMANCE), + PhaseConfig("saturate_c64", sat_64, dataset, PhaseType.SATURATION), + PhaseConfig("perf_c64", perf_64, dataset, PhaseType.PERFORMANCE), + PhaseConfig("accuracy", acc_settings, acc_dataset, PhaseType.ACCURACY), +] +``` + +### PhaseIssuer + +**File:** `src/inference_endpoint/load_generator/session.py` (internal to session) + +Per-phase state holder that wraps the issue logic. Created fresh for each phase, +holds the phase-scoped `uuid_to_index` map and inflight counter. Passed to +strategies as a callable (`phase_issuer.issue`). + +Using an object instead of a closure makes per-phase state explicit, testable +independently, and avoids the awkward tuple return pattern. + +```python +class PhaseIssuer: + """Wraps sample issuance for a single benchmark phase.""" + + __slots__ = ("_dataset", "_issuer", "_publisher", "_stop_check", + "uuid_to_index", "inflight", "issued_count") + + def __init__( + self, + dataset: Dataset, + issuer: SampleIssuer, + publisher: EventRecordPublisher, + stop_check: Callable[[], bool], + ): + self._dataset = dataset + self._issuer = issuer + self._publisher = publisher + self._stop_check = stop_check + self.uuid_to_index: dict[str, int] = {} + self.inflight: int = 0 + self.issued_count: int = 0 + + def issue(self, sample_index: int) -> str | None: + """Load data, build Query, publish ISSUED, send to endpoint. 
+ + Returns query_id on success, None if session is stopping. + """ + if self._stop_check(): + return None + query_id = uuid.uuid4().hex + data = self._dataset.load_sample(sample_index) + query = Query(id=query_id, data=data) + self.uuid_to_index[query_id] = sample_index + ts = time.monotonic_ns() + self._publisher.publish(EventRecord( + event_type=SampleEventType.ISSUED, + timestamp_ns=ts, + sample_uuid=query_id, + data=PromptData(text=data.get("prompt")), + )) + self._issuer.issue(query) + self.inflight += 1 + self.issued_count += 1 + return query_id +``` + +The strategy calls `phase_issuer.issue(idx)`. After the phase completes, +the session reads `phase_issuer.uuid_to_index` and `phase_issuer.issued_count` +to build the `PhaseResult`. + +**UUID generation before Query construction** avoids the old `Sample` catch-22. +`Query` is a frozen `msgspec.Struct` — all fields set at construction, no mutation. + +**`_receive_responses()`** — concurrent coroutine, purely async: + +```python +async def _receive_responses(self): + while True: + resp = await self.issuer.recv() + if resp is None: + break + self._handle_response(resp) + if self._done and self._inflight <= 0: + break +``` + +Uses `recv()` exclusively — no `poll()` spin. The ZMQ fd is registered with +the event loop, so `recv()` wakes exactly when a response is available with +zero CPU overhead. Each `recv()` call yields to the event loop, ensuring +strategy coroutines (call_at callbacks, semaphore waiters) are never starved. + +For `ConcurrencyStrategy`, `_handle_response` calls `strategy.on_query_complete()` +which releases the semaphore. Since `recv()` returns as soon as the fd is readable +and `eager_task_factory` executes the woken semaphore waiter synchronously, there +is no added latency compared to a poll-based approach. 
**`_handle_response(resp)`**:

- `QueryResult`: publish COMPLETE event, decrement `_inflight`, call `on_sample_complete`,
  call `strategy.on_query_complete(query_id)` if strategy supports it
- `StreamChunk(first)`: publish RECV_FIRST event
- `StreamChunk(non-first)`: publish RECV_NON_FIRST event

**Timestamp fidelity:**

- ISSUED: `monotonic_ns()` taken immediately before `issuer.issue()`. The ZMQ push is
  sync and non-blocking, so this honestly represents when the query entered the transport.
- COMPLETE: `QueryResult.completed_at` is set via `force_setattr(monotonic_ns())` in
  `__post_init__`, regenerated on deserialization. Both ISSUED and COMPLETE timestamps
  share the same ZMQ transit bias. TTFT (`RECV_FIRST - ISSUED`) is still sensitive
  to this overhead since it spans the full ZMQ round-trip. TPOT avoids cross-process
  clock skew by computing time deltas between consecutive chunks within the same process.

### LoadStrategy (Protocol)

**File:** `src/inference_endpoint/load_generator/strategy.py`

```python
class LoadStrategy(Protocol):
    async def execute(
        self,
        phase_issuer: PhaseIssuer,
    ) -> int:
        """Drive sample issuance. Returns the count of samples issued.

        Call phase_issuer.issue(sample_index) for each sample.
        phase_issuer.issue() returns None when the session is stopping
        (max_duration, stop(), or all samples exhausted); the strategy
        should then stop issuing and return.
        """
        ...

    def on_query_complete(self, query_id: str) -> None:
        """Called by session on each QueryResult. Default: no-op."""
        ...
```

The strategy calls `phase_issuer.issue(idx)` which handles data loading, Query
construction, event publishing, and the actual send. The strategy only controls
_when_ and _which index_ to issue. Stop checking is internal to `PhaseIssuer.issue()`
— it returns `None` when the session should stop.

`on_query_complete` is the hook for `ConcurrencyStrategy` — other strategies ignore it.

### TimedIssueStrategy

Handles `LoadPatternType.POISSON`. 
Default uses `loop.call_at`; opt-in +`run_in_executor(busy_wait)` available for sub-100μs precision requirements. + +```python +class TimedIssueStrategy(LoadStrategy): + def __init__( + self, + delay_fn: Callable[[], int], + sample_order: Iterator[int], + loop: asyncio.AbstractEventLoop, + use_executor: bool = False, + ): ... + + async def execute(self, phase_issuer: PhaseIssuer) -> int: + if self.use_executor: + return await self._execute_executor(phase_issuer) + else: + return await self._execute_call_at(phase_issuer) +``` + +**call_at mode** (default): + +```python +async def _execute_call_at(self, phase_issuer): + done = asyncio.Event() + start_time = self._loop.time() + cumulative_s = 0.0 + + def schedule_next(): + nonlocal cumulative_s + idx = next(self.sample_order, None) + if idx is None: + done.set() + return + cumulative_s += self.delay_fn() / 1e9 + self._loop.call_at(start_time + cumulative_s, fire, idx) + + def fire(idx): + if phase_issuer.issue(idx) is None: + done.set() + return + schedule_next() + + schedule_next() + await done.wait() + return phase_issuer.issued_count +``` + +**Executor mode** (opt-in, `use_executor=True`): + +```python +async def _execute_executor(self, phase_issuer): + start = monotonic_ns() + cumulative = 0 + for idx in self.sample_order: + cumulative += self.delay_fn() + target = start + cumulative + now = monotonic_ns() + if target > now: + await self._loop.run_in_executor(None, _busy_wait_until, target) + if phase_issuer.issue(idx) is None: + break + return phase_issuer.issued_count +``` + +### BurstStrategy + +Handles `LoadPatternType.MAX_THROUGHPUT`. Issues all samples as fast as possible +using `loop.call_soon` to schedule each issue as an event loop callback. This +avoids starving the receiver — between each callback, the loop processes I/O +events (including ZMQ recv fd readiness). + +```python +class BurstStrategy(LoadStrategy): + def __init__(self, sample_order: Iterator[int], loop: asyncio.AbstractEventLoop): ... 
+ + async def execute(self, phase_issuer: PhaseIssuer) -> int: + done = asyncio.Event() + + def issue_next(): + idx = next(self.sample_order, None) + if idx is None or phase_issuer.issue(idx) is None: + done.set() + return + self._loop.call_soon(issue_next) + + self._loop.call_soon(issue_next) + await done.wait() + return phase_issuer.issued_count +``` + +Each `call_soon` yields to the event loop between issues, preventing receiver +starvation. Benchmark data shows `loop.call_at` (with zero delay, equivalent +to `call_soon`) achieves 104k QPS — the highest throughput of all strategies. + +### ConcurrencyStrategy + +Handles `LoadPatternType.CONCURRENCY`. Semaphore-gated by completions. + +```python +class ConcurrencyStrategy(LoadStrategy): + def __init__(self, target_concurrency: int, sample_order: Iterator[int]): ... + + async def execute(self, phase_issuer: PhaseIssuer) -> int: + for idx in self.sample_order: + await self._sem.acquire() + if phase_issuer.issue(idx) is None: + self._sem.release() + break + return phase_issuer.issued_count + + def on_query_complete(self, query_id: str) -> None: + self._sem.release() +``` + +### SampleIssuer (Protocol) + +```python +class SampleIssuer(Protocol): + def issue(self, query: Query) -> None: ... + def poll(self) -> QueryResult | StreamChunk | None: ... + async def recv(self) -> QueryResult | StreamChunk | None: ... + def shutdown(self) -> None: ... +``` + +`issue()` is sync (ZMQ push). `poll()` is non-blocking sync drain. `recv()` is +async blocking wait. This matches `HTTPEndpointClient`'s existing interface. + +### SampleOrder (unchanged) + +`SampleOrder` is an infinite iterator yielding dataset indices. Implementations: + +- `WithoutReplacementSampleOrder` — shuffle, exhaust, reshuffle +- `WithReplacementSampleOrder` — uniform random + +Termination is controlled by `BenchmarkSession._should_stop()`, not the iterator. 
+ +### SessionResult + +```python +@dataclass(frozen=True, slots=True) +class PhaseResult: + """Result of a single benchmark phase.""" + name: str + phase_type: PhaseType + uuid_to_index: dict[str, int] + report: Report | None # Only for PERFORMANCE phases + start_time_ns: int + end_time_ns: int + + +@dataclass(frozen=True, slots=True) +class SessionResult: + """Combined results from all phases in a session.""" + session_id: str + phase_results: list[PhaseResult] + start_time_ns: int + end_time_ns: int + + @property + def perf_results(self) -> list[PhaseResult]: + return [r for r in self.phase_results if r.phase_type == PhaseType.PERFORMANCE] + + @property + def accuracy_results(self) -> list[PhaseResult]: + return [r for r in self.phase_results if r.phase_type == PhaseType.ACCURACY] +``` + +--- + +## Data Flow + +### Happy Path: Issue → Response → Event + +```mermaid +sequenceDiagram + participant S as LoadStrategy + participant B as BenchmarkSession + participant D as Dataset + participant I as SampleIssuer + participant W as Worker Process + participant E as EventPublisher + participant M as MetricsAggregator + + S->>B: issue_fn(sample_index) + B->>D: load_sample(index) + D-->>B: sample_data + Note over B: Build Query(id=uuid4().hex, data=load_sample(idx)) + B->>E: publish(ISSUED, uuid, timestamp_ns) + E->>M: ZMQ PUB (EventRecord) + B->>I: issue(query) + I->>W: ZMQ PUSH (Query) + W->>W: HTTP request → endpoint + W-->>I: ZMQ PUSH (QueryResult) + I-->>B: poll() / recv() + B->>E: publish(COMPLETE, uuid, completed_at) + E->>M: ZMQ PUB (EventRecord) + B->>S: on_query_complete(uuid) + Note over S: ConcurrencyStrategy: sem.release() +``` + +### Multi-Phase Session Lifecycle + +```mermaid +sequenceDiagram + participant C as Caller (execute.py) + participant B as BenchmarkSession + participant E as EventPublisher + participant M as MetricsAggregator + participant K as KVStoreReader + + C->>B: run(phases) + B->>E: STARTED + + Note over B: === Saturation Phase === + B->>B: 
execute strategy (untracked, no drain) + + Note over B: === Perf Phase 1 (e.g. QPS=1000) === + B->>E: START_PERFORMANCE_TRACKING + B->>B: execute strategy + B->>B: drain in-flight + B->>E: STOP_PERFORMANCE_TRACKING + B->>K: snapshot metrics → Report + Note over B: PhaseResult("perf_qps1k", report) + + Note over B: === Perf Phase 2 (e.g. QPS=5000) === + B->>E: START_PERFORMANCE_TRACKING + B->>B: execute strategy + B->>B: drain in-flight + B->>E: STOP_PERFORMANCE_TRACKING + B->>K: snapshot metrics → Report + Note over B: PhaseResult("perf_qps5k", report) + + Note over B: === Accuracy Phase === + B->>B: execute strategy (untracked) + B->>B: drain in-flight + Note over B: PhaseResult("accuracy", uuid_map) + + B->>E: ENDED + B-->>C: SessionResult +``` + +### Separate Timer Process Data Flow + +```mermaid +sequenceDiagram + participant T as Timer Process + participant B as BenchmarkSession + participant D as Dataset + participant I as SampleIssuer + participant W as Worker Process + + Note over T: Busy-wait loop (no GIL contention) + T->>B: ZMQ PUSH (sample_index) + B->>D: load_sample(index) + D-->>B: sample_data + Note over B: Build Query, publish ISSUED + B->>I: issue(query) + I->>W: ZMQ PUSH + W-->>I: ZMQ PUSH (QueryResult) + I-->>B: poll() / recv() + Note over B: publish COMPLETE +``` + +--- + +## Event Loop Topology + +### Standard (single process) + +```mermaid +graph TD + subgraph "Main Process — LoopManager.default_loop (uvloop)" + A["BenchmarkSession.run()"] + B["LoadStrategy.execute()"] + C["_receive_responses() task"] + D["EventPublisher (ZMQ PUB)"] + E["HTTPEndpointClient (shared loop)"] + + A --> B + A --> C + B -->|"issue_fn → issuer.issue()"| E + C -->|"poll() / recv()"| E + B --> D + C --> D + end + + subgraph "Worker Process 1" + W1["HTTP → endpoint"] + end + subgraph "Worker Process N" + WN["HTTP → endpoint"] + end + subgraph "MetricsAggregator (subprocess)" + MA["ZmqEventRecordSubscriber"] + KB["KVStore (mmap)"] + MA --> KB + end + + E -->|"ZMQ IPC"| 
W1 + E -->|"ZMQ IPC"| WN + W1 -->|"ZMQ IPC"| E + WN -->|"ZMQ IPC"| E + D -->|"ZMQ PUB"| MA +``` + +### With Separate Timer Process + +```mermaid +graph TD + subgraph "Timer Process" + T["Busy-wait loop + ZMQ PUSH"] + end + + subgraph "Main Process — LoopManager.default_loop" + R["ZMQ PULL receiver"] + A["BenchmarkSession"] + C["_receive_responses() task"] + D["EventPublisher"] + E["HTTPEndpointClient"] + + R -->|"sample_index"| A + A -->|"issue()"| E + C -->|"poll/recv"| E + A --> D + C --> D + end + + subgraph "Worker Processes" + W["HTTP → endpoint"] + end + subgraph "MetricsAggregator" + MA["Subscriber → KVStore"] + end + + T -->|"ZMQ IPC"| R + E -->|"ZMQ IPC"| W + W -->|"ZMQ IPC"| E + D -->|"ZMQ PUB"| MA +``` + +--- + +## Load Pattern Mapping + +```python +def create_load_strategy( + runtime_settings: RuntimeSettings, + sample_order: SampleOrder, + loop: asyncio.AbstractEventLoop, + use_executor: bool = False, + use_timer_process: bool = False, +) -> LoadStrategy: + lp = runtime_settings.load_pattern + + match lp.type: + case LoadPatternType.MAX_THROUGHPUT: + return BurstStrategy(sample_order, loop) + + case LoadPatternType.POISSON: + delay_fn = poisson_delay_fn(lp.target_qps, runtime_settings.rng_sched) + if use_timer_process: + return TimerProcessStrategy(delay_fn, sample_order) + return TimedIssueStrategy(delay_fn, sample_order, loop, + use_executor=use_executor) + + case LoadPatternType.CONCURRENCY: + return ConcurrencyStrategy(lp.target_concurrency, sample_order) +``` + +--- + +## Benchmark Data Summary + +From `.cursor_artifacts/async_lg_benchmarks/` (MaxThroughputServer + real HTTPEndpointClient): + +### Poisson Mode — Strategy Comparison + +| QPS | `run_in_executor` precision | `loop.call_at` precision | `asyncio.sleep` precision | +| ------ | --------------------------- | ------------------------ | ------------------------- | +| 100 | 84 μs | 1,772 μs | 2,008 μs | +| 1,000 | 65 μs | 679 μs | 734 μs | +| 10,000 | 67 μs | 739 μs | 658 μs | +| 50,000 | 85 
μs | 586 μs | 291 μs | +| 100k | 126 μs | 1,043 μs | 65 μs | + +Response latency at 100 QPS: `run_in_executor` = 6.2ms, `loop.call_at` = 1.4ms. +The GIL contention from the executor busy-wait thread penalizes low-QPS latency. + +### Concurrency Mode + +| Strategy | QPS | Latency (mean) | +| --------- | ------ | -------------- | +| Semaphore | 80,631 | 0.73 ms | +| Callback | 77,488 | 0.81 ms | + +### Max Throughput + +| Strategy | QPS | Latency (mean) | +| -------------- | ------- | -------------- | +| `loop.call_at` | 104,039 | 1.47 ms | +| `run_in_exec` | 78,261 | 8.28 ms | + +--- + +## Removed Constructs + +| Removed | Reason | +| ----------------------------------------------- | --------------------------------------------- | +| `Sample` class | Replaced by `Query` (frozen `msgspec.Struct`) | +| `Sample.__setattr__` hack | UUID generated before `Query` construction | +| `SampleEventHandler` singleton | Events via `EventPublisher` ZMQ PUB/SUB | +| `IssuedSample` dataclass | `uuid_to_index` dict on session is sufficient | +| `Scheduler` class hierarchy | Replaced by `LoadStrategy` + factory function | +| `LoadGenerator` / `SchedulerBasedLoadGenerator` | Replaced by `LoadStrategy` | +| `threading.Thread` in `BenchmarkSession` | Fully async | +| `threading.Condition` in `ConcurrencyScheduler` | `asyncio.Semaphore` | +| `HttpClientSampleIssuer._handle_responses` | Session owns the receive loop | + +--- + +## File Structure + +``` +src/inference_endpoint/load_generator/ +├── __init__.py # Public exports +├── session.py # BenchmarkSession, SessionResult +├── strategy.py # LoadStrategy protocol, TimedIssueStrategy, +│ # BurstStrategy, ConcurrencyStrategy, +│ # create_load_strategy() +├── sample_order.py # SampleOrder, WithoutReplacement, WithReplacement +└── delay.py # poisson_delay_fn, uniform_delay_fn +``` + +Deleted: + +- `load_generator.py` (LoadGenerator, SchedulerBasedLoadGenerator) +- `scheduler.py` (Scheduler hierarchy) +- `sample.py` (Sample, 
SampleEventHandler, IssuedSample) + +--- + +## Integration Points + +### HTTPEndpointClient + +Pass `loop` to share the event loop. The client already supports this via the +`_owns_loop` flag. Two changes required: + +**Initialization deadlock:** `__init__` calls `run_coroutine_threadsafe().result()` +which deadlocks when the calling thread IS the event loop thread. Fix: add an +async classmethod factory: + +```python +@classmethod +async def create(cls, config: HTTPClientConfig, loop: asyncio.AbstractEventLoop) -> HTTPEndpointClient: + client = cls.__new__(cls) + client._setup_sync_fields(config, loop) + await client._initialize() + return client +``` + +**Shutdown deadlock:** Same pattern — `shutdown()` calls `run_coroutine_threadsafe().result()`. +Fix: expose `async shutdown_async()` as a public method. When `_owns_loop is False`, +`shutdown()` should raise if called from the event loop thread, directing callers +to use `await shutdown_async()`. + +### EventPublisher / MetricsAggregator + +Session publishes `EventRecord` instances via `ZmqEventRecordPublisher`. The publisher +uses non-blocking ZMQ send with fd-based writer fallback — safe to call from sync +callbacks (like `call_at` fire functions). + +> **Key fix:** `Report.from_kv_reader` currently reads counter keys (`n_samples_issued`, +> `duration_ns`) that don't match the `MetricCounterKey` enum written by the aggregator +> (`total_samples_issued`, `tracked_duration_ns`). Must update `from_kv_reader` to use +> the actual key names. Performance reports should use `tracked_*` counters. A +> `test_started_at` counter must be added to the aggregator (set on `SessionEventType.STARTED`). + +### HttpClientSampleIssuer Migration + +The current issuer takes `Sample` and constructs `Query` internally. 
In the new design, +`PhaseIssuer` constructs the `Query`, so the issuer just forwards it: + +```python +class HttpClientSampleIssuer: + def __init__(self, http_client: HTTPEndpointClient): + self.http_client = http_client + + def issue(self, query: Query) -> None: + self.http_client.issue(query) + + def poll(self) -> QueryResult | StreamChunk | None: + return self.http_client.poll() + + async def recv(self) -> QueryResult | StreamChunk | None: + return await self.http_client.recv() + + def shutdown(self) -> None: + pass # HTTPEndpointClient shutdown called separately +``` + +Removed from current issuer: `_handle_responses` coroutine, `SampleEventHandler` +routing, `run_coroutine_threadsafe` cross-loop dispatch. The session's +`_receive_responses` replaces all of this. + +### Query.id Format + +`Query.default_factory` uses `str(uuid.uuid4())` (36 chars with hyphens). +The design uses `uuid.uuid4().hex` (32 chars, no hyphens). Standardize on +`.hex` — shorter strings, no parsing overhead. Update `Query.default_factory` +to match. + +### Timestamp Fidelity + +- **ISSUED**: `monotonic_ns()` taken in `PhaseIssuer.issue()` immediately before + `issuer.issue(query)`. ZMQ push is sync/non-blocking — timestamp is honest. +- **COMPLETE**: `QueryResult.completed_at` set in `__post_init__` on deserialization + in the main process. Measures main-process receipt time, not worker-side completion. +- **TTFT**: `RECV_FIRST - ISSUED` includes full round-trip ZMQ overhead (outbound to + worker + return to main). This adds 20-100μs of systematic bias. Acceptable for + most benchmarks; document as a known measurement overhead. +- **Latency (COMPLETE - ISSUED)**: Both timestamps taken on the main process side. + ZMQ transit bias is symmetric and cancels. This is the most accurate measurement. + +### Stale Completions After Saturation + +After a saturation phase (no drain), in-flight responses arrive during the perf +phase. 
The receiver must distinguish stale vs current-phase completions: + +```python +def _handle_response(self, resp: QueryResult) -> None: + query_id = resp.id + # Always publish the event (aggregator tracks all samples) + self._publisher.publish(EventRecord( + event_type=SampleEventType.COMPLETE, + timestamp_ns=resp.completed_at, + sample_uuid=query_id, + )) + # Only route to current phase strategy if this is a current-phase query + if query_id in self._current_phase_issuer.uuid_to_index: + self._current_phase_issuer.inflight -= 1 + if self._current_strategy: + self._current_strategy.on_query_complete(query_id) + # Stale completions: event published but strategy/inflight not affected +``` + +Same guard applies to `StreamChunk` with `is_complete=True` — check +`uuid_to_index` membership before decrementing inflight. Non-final +StreamChunks don't affect inflight and can be published unconditionally. + +### Sync Per-Sample Work in Callbacks + +All three strategies call `PhaseIssuer.issue()` synchronously — from `call_at` +callbacks (Poisson), `call_soon` callbacks (Burst), or inline after `sem.acquire()` +(Concurrency). Each `issue()` call performs: `dataset.load_sample()`, `uuid4().hex`, +`Query` construction, `EventRecord` publish (ZMQ NOBLOCK), and `issuer.issue()` +(ZMQ NOBLOCK). The ZMQ sends are confirmed non-blocking with internal buffering. + +The dominant cost is `dataset.load_sample()`. **Requirement:** datasets must be +pre-loaded into memory before the benchmark starts. If `load_sample()` performs +disk I/O, it blocks the event loop and degrades both timing precision and response +processing. For lazy-loading or disk-backed datasets, either pre-materialize +during setup or use executor mode. + +At 100k+ QPS with `BurstStrategy`, the `call_soon` callback queue depth can +delay `recv()` wakeups. 
Benchmarking shows this is acceptable (104k QPS with +1.47ms mean response latency), but the recv latency is bounded by the queue +depth rather than being strictly real-time. + +--- + +## CLI / Logging / TUI Integration + +### CLI Integration + +The CLI entry point (`commands/benchmark/execute.py`) orchestrates setup, execution, +and finalization as three sync phases: + +```python +def run_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> None: + ctx = setup_benchmark(config, test_mode) # sync: datasets, tokenizer, config + result, collector = run_benchmark_async(ctx) # async: issuance on event loop + finalize_benchmark(ctx, result, collector) # sync: scoring, report, JSON output +``` + +`run_benchmark_async` is called via `loop.run_until_complete()` from the main thread. +Signal handling uses `loop.add_signal_handler(signal.SIGINT, session.stop)` which +only works on the main thread — this is enforced by the sync-to-async boundary. + +```python +def run_benchmark_async(ctx: BenchmarkContext) -> tuple[SessionResult, ResponseCollector]: + loop = LoopManager().default_loop + return loop.run_until_complete(_run_benchmark_async(ctx, loop)) + +async def _run_benchmark_async( + ctx: BenchmarkContext, + loop: asyncio.AbstractEventLoop, +) -> tuple[SessionResult, ResponseCollector]: + collector = ResponseCollector(collect_responses=ctx.collect_responses) + + # Setup: HTTP client, event publisher, service subprocesses + client = await HTTPEndpointClient.create(ctx.http_config, loop) + issuer = HttpClientSampleIssuer(client) + zmq_ctx = ManagedZMQContext() + publisher = ZmqEventRecordPublisher(pub_socket_name, zmq_ctx, loop=loop) + launcher = ServiceLauncher(zmq_ctx) + await launcher.launch([ + ServiceConfig( + module="inference_endpoint.async_utils.services.metrics_aggregator", + args=["--socket-dir", zmq_ctx.socket_dir, + "--socket-name", pub_socket_name, + "--metrics-dir", str(metrics_dir)], + ), + ]) + + session = BenchmarkSession( + issuer=issuer, + 
event_publisher=publisher, + on_sample_complete=collector.on_complete, + ) + + # Build phases from config + phases = _build_phases(ctx) + + loop.add_signal_handler(signal.SIGINT, session.stop) + try: + result = await session.run(phases) + except Exception as e: + raise ExecutionError(f"Benchmark execution failed: {e}") from e + finally: + loop.remove_signal_handler(signal.SIGINT) + await client.shutdown_async() + launcher.wait_for_exit() + publisher.close() + zmq_ctx.cleanup() + + return result, collector +``` + +### Logging + +Standard Python `logging` is used throughout. Key log points: + +- Phase transitions: `logger.info("Starting phase: %s (%s)", name, phase_type)` +- Sample counts: `logger.info("Phase %s complete: %d samples issued", name, count)` +- Errors: `logger.error("Failed to issue query %s: %s", query_id, error)` +- Shutdown: `logger.info("Benchmark session cancelled")` + +Log level is configurable via `RuntimeSettings` / CLI `--log-level`. + +### Progress Reporting (tqdm) + +`ResponseCollector.on_complete` drives the progress bar: + +```python +class ResponseCollector: + def __init__(self, collect_responses: bool, pbar: tqdm | None = None): + self.responses: dict[str, str] = {} + self.errors: list[str] = [] + self.pbar = pbar + + def on_complete(self, result: QueryResult) -> None: + if result.error: + self.errors.append(f"{result.id}: {result.error}") + elif self.responses is not None: + self.responses[result.id] = result.get_response_output_string() + if self.pbar: + self.pbar.update(1) +``` + +The session calls `on_sample_complete(result)` from `_handle_response`, which +fires from the `_receive_responses` coroutine on the event loop. 
+ +### Accuracy Phase Response Collection + +After `session.run()` returns, accuracy responses are partitioned using +`PhaseResult.uuid_to_index`: + +```python +for phase_result in result.accuracy_results: + phase_responses = { + uid: collector.responses[uid] + for uid in phase_result.uuid_to_index + if uid in collector.responses + } + score = scorer.score(phase_responses, phase_result.uuid_to_index, acc_dataset) +``` + +### Future: TUI Integration + +The planned TUI architecture moves the benchmark engine (HTTPClient + load generator) +to a child process, with the TUI as the foreground process reading periodic reports. + +``` +TUI Process (foreground): + - Renders live dashboard (throughput, latency, progress) + - Reads BasicKVStoreReader for real-time metrics from /dev/shm + - Receives SessionResult via IPC on completion + +Benchmark Process (child): + - Runs BenchmarkSession on its own event loop + - Writes metrics via EventPublisher -> MetricsAggregator -> KVStore + - Returns SessionResult to parent via IPC (pickle over pipe / ZMQ) +``` + +This architecture is enabled by the current design's clean separation: + +- **KVStore** is already cross-process readable (mmap on /dev/shm) +- **BenchmarkSession** has no UI dependencies — it takes callbacks +- **SessionResult** is a frozen dataclass, trivially serializable +- The `on_sample_complete` callback would not be used in TUI mode (no + cross-process callback). Instead, the TUI polls KVStoreReader for + `tracked_samples_completed` to update the progress display. + +The TUI process can also read the `Report` from the KVStore at any time for +live intermediate reports (current QPS, latency distribution so far), not +just the final report. + +> **Constraint:** The benchmark child process must be a **non-daemon** OS process +> (e.g., `subprocess.Popen` or `multiprocessing.Process(daemon=False)`). +> `HTTPEndpointClient` spawns worker processes via `WorkerManager`, and those +> workers are `daemon=True`. 
Python prohibits daemon processes from spawning +> children — if the benchmark process is itself a daemon, worker creation fails. + +> **TODO:** Signal forwarding: SIGINT from the terminal goes to the foreground +> process group (TUI). The TUI must forward a stop signal to the benchmark +> child process (e.g., via `process.terminate()` or a ZMQ control channel). +> Design the stop protocol during TUI implementation. + +--- + +## Multi-Perf Sweep Example + +Concurrency sweep against same endpoint: + +```python +phases = [ + PhaseConfig("sat_c16", sat_settings(16), ds, PhaseType.SATURATION), + PhaseConfig("perf_c16", perf_settings(16), ds, PhaseType.PERFORMANCE), + PhaseConfig("sat_c32", sat_settings(32), ds, PhaseType.SATURATION), + PhaseConfig("perf_c32", perf_settings(32), ds, PhaseType.PERFORMANCE), + PhaseConfig("sat_c64", sat_settings(64), ds, PhaseType.SATURATION), + PhaseConfig("perf_c64", perf_settings(64), ds, PhaseType.PERFORMANCE), + PhaseConfig("accuracy", acc_settings, acc_ds, PhaseType.ACCURACY), +] +result = await session.run(phases) + +for pr in result.perf_results: + print(f"{pr.name}: {pr.report.qps():.0f} QPS") +``` + +--- + +## Rejected Alternatives + +| Alternative | Rejected Because | +| --------------------------------------- | ------------------------------------------------------------------------------------------ | +| Unified strategy for all patterns | Benchmark data shows each pattern benefits from different async primitives | +| `asyncio.Semaphore` for all concurrency | Correct for CONCURRENCY mode, but overhead hurts MAX_THROUGHPUT | +| `run_in_executor` for all timing | GIL contention causes 6ms latency at low QPS | +| `asyncio.sleep` for all timing | 700μs precision at mid-range QPS, `run_in_executor` is 10x better | +| Direct busy-wait on event loop | Starves receiver — 26ms response latency vs 0.6ms | +| Callback-based concurrency | Semaphore is simpler and benchmarks slightly better with real ZMQ client | +| Shared `Scheduler` 
base class | Concurrency has no delay concept; forcing it conflates distinct semantics | +| Separate `BenchmarkOrchestrator` | Phase sequencing is simple enough to live in `BenchmarkSession.run()` | +| poll()-based receiver spin | Starves event loop during response bursts; pure recv() is fd-driven with zero CPU overhead | diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/fs_check.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/fs_check.py new file mode 100644 index 00000000..fea99811 --- /dev/null +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/fs_check.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Filesystem type detection for mmap ordering decisions. + +On tmpfs (/dev/shm), msync() is a no-op because there is no backing store. +On a real on-disk filesystem, msync() flushes dirty pages to the shared page +cache, which provides write ordering for cross-process mmap readers. + +On ARM (weak memory model), we need msync() to act as an ordering mechanism +between the value write and the count update in _SeriesItem.append(). This +only works on a real filesystem — not tmpfs. 
Detecting the filesystem type +lets us: + - Skip the useless msync() syscall on tmpfs (any architecture) + - Warn if ARM code is running on tmpfs (msync won't provide ordering) +""" + +from __future__ import annotations + +import ctypes +import ctypes.util +import logging +import platform +from pathlib import Path + +logger = logging.getLogger(__name__) + +_TMPFS_MAGIC = 0x01021994 +"""Special tmpfs filesystem header value.""" + + +def _is_tmpfs_via_statfs(path: str) -> bool | None: + """Check filesystem type via libc statfs(2). Returns None if unavailable.""" + try: + lib_name = ctypes.util.find_library("c") + if lib_name is None: + return None + libc = ctypes.CDLL(lib_name, use_errno=True) + + # Allocate a large buffer to account for differently sized statfs + # structs across architectures. f_type is always the first field + # (__SWORD_TYPE / long) at offset 0 on all Linux archs. + buf = ctypes.create_string_buffer(256) + if libc.statfs(path.encode(), buf) != 0: + return None + # f_type is a native-endian long at offset 0 + f_type = ctypes.c_long.from_buffer(buf, 0).value + return f_type == _TMPFS_MAGIC + except (OSError, AttributeError, ValueError): + return None + + +def _is_tmpfs_via_proc_mounts(path: str) -> bool | None: + """Check filesystem type via /proc/mounts. Returns None if unavailable.""" + try: + resolved = str(Path(path).resolve()) + best_match = "" + best_fstype = "" + with open("/proc/mounts") as f: + for line in f: + parts = line.split() + if len(parts) < 3: + continue + mount_point, fstype = parts[1], parts[2] + if resolved.startswith(mount_point) and len(mount_point) > len( + best_match + ): + best_match = mount_point + best_fstype = fstype + if not best_match: + return None + return best_fstype == "tmpfs" + except OSError: + return None + + +def is_tmpfs(path: str | Path) -> bool: + """Check if a path resides on a tmpfs filesystem. + + Tries statfs(2) via ctypes first, falls back to /proc/mounts. 
+ Returns False if detection fails (safe default — will call msync). + """ + path_str = str(path) + + result = _is_tmpfs_via_statfs(path_str) + if result is not None: + return result + + result = _is_tmpfs_via_proc_mounts(path_str) + if result is not None: + return result + + logger.warning( + "Could not determine filesystem type for %s " + "(statfs and /proc/mounts both unavailable). " + "Assuming non-tmpfs (msync will be called on every series append).", + path_str, + ) + return False + + +def needs_msync(path: str | Path) -> bool: + """Determine if msync() is needed for mmap write ordering at this path. + + Returns True if msync should be called between value write and count + update in series append. This is needed on ARM when the backing store + is a real filesystem (not tmpfs). + + On x86-64 (TSO), store ordering is guaranteed by hardware — msync is + never needed regardless of filesystem type. + + On ARM with tmpfs, msync is a no-op and won't help — log a warning + since the caller should use an on-disk directory for correct ordering. + """ + if platform.machine() == "x86_64": + return False + + on_tmpfs = is_tmpfs(path) + if on_tmpfs: + logger.warning( + "ARM platform with tmpfs-backed metrics at %s. " + "Python does not support memory fences. 
" + "Use an on-disk metrics directory for correct cross-process reads.", + path, + ) + return False + + return True diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/kv_store.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/kv_store.py index 0ef446c2..a5697007 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/kv_store.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/kv_store.py @@ -43,13 +43,14 @@ import math import mmap import os -import platform import shutil import struct from abc import ABC, abstractmethod from pathlib import Path from typing import Literal +from .fs_check import needs_msync + # --------------------------------------------------------------------------- # Series rollup stats (computed on read) # --------------------------------------------------------------------------- @@ -257,6 +258,7 @@ class _SeriesItem: "_dtype", "_char", "_fmt", + "_needs_msync", ) def __init__( @@ -272,6 +274,7 @@ def __init__( self._dtype = dtype self._char = _STRUCT_CHAR[dtype] self._fmt = f"{_ENDIAN}{self._char}" + self._needs_msync = needs_msync(path.parent) total = _HEADER_BYTES + capacity * _VALUE_BYTES fd = os.open(str(path), os.O_CREAT | os.O_RDWR, _DEFAULT_FILE_MODE) try: @@ -285,7 +288,7 @@ def append(self, value: int | float) -> None: if self._closed: logger.warning("append() called on closed series: %s", self._path) return - if type(value) != self._dtype: + if not isinstance(value, self._dtype): raise TypeError( f"Expected {self._dtype.__name__}, got {type(value).__name__}" ) @@ -293,25 +296,9 @@ def append(self, value: int | float) -> None: self._grow() offset = _HEADER_BYTES + self._count * _VALUE_BYTES struct.pack_into(self._fmt, self._mm, offset, value) - # NOTE: This flush() calls msync(), which is a no-op on tmpfs (/dev/shm) - # and does NOT act as a CPU memory barrier. 
On x86-64 (TSO), store ordering - # is guaranteed — the value write above is visible before the count update - # below without any explicit barrier. On ARM (weak memory model), a reader - # could observe the count update before the value write. To support ARM - # properly, Python's mmap doesn't expose memory fences; you would need - # ctypes to call libc's __sync_synchronize() or use atomic operations via - # a C extension. - # The primary safety guarantee is the single-writer protocol: - # readers only read up to the count they observed, and on the target - # platform (x86-64 Linux), TSO provides the required ordering. - # - # For ARM platforms: Prometheus integration is planned as a replacement - # for mmap-backed metrics. As a temporary workaround, an on-disk metrics - # directory can be used instead of tmpfs — msync will then act as a real - # flush, providing ordering at the cost of performance. - if platform.machine() != "x86_64": - # Do not flush on x86-64 to avoid a no-op syscall on every append() - # For ARM, flush() and use an on-disk metrics directory instead of tmpfs. + # Flush between value write and count update for cross-process ordering. + # See fs_check.needs_msync() for when this is needed and why. + if self._needs_msync: self._mm.flush() self._count += 1 struct.pack_into(" str: """ if scheme == "ipc": if self.socket_dir is None: - self._tmp_dir = tempfile.TemporaryDirectory(prefix="zmq_") + # Prefer /dev/shm for IPC sockets — overlayfs (common in + # containers for /tmp) does not support Unix sockets. 
+ shm = Path("/dev/shm") + self._tmp_dir = tempfile.TemporaryDirectory( + prefix="zmq_", + dir=str(shm) if shm.is_dir() else None, + ) self.socket_dir = self._tmp_dir.name else: Path(self.socket_dir).mkdir(parents=True, exist_ok=True) diff --git a/src/inference_endpoint/async_utils/transport/zmq/ready_check.py b/src/inference_endpoint/async_utils/transport/zmq/ready_check.py index 66275607..3e2c35a7 100644 --- a/src/inference_endpoint/async_utils/transport/zmq/ready_check.py +++ b/src/inference_endpoint/async_utils/transport/zmq/ready_check.py @@ -133,8 +133,11 @@ async def wait(self, timeout: float | None = None) -> list[int]: len(identities), self._count, ) + except TimeoutError: + # Don't close socket on timeout — caller may retry. + raise except BaseException: - # Clean up socket on any failure (timeout, cancellation, etc.) + # Clean up socket on non-retryable failures (cancellation, etc.) self.close() raise diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 4854272a..d6948853 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -13,18 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Benchmark execution — phased architecture for threaded and future async runners. +"""Benchmark execution — phased architecture. Phases: - 1. setup_benchmark() — load tokenizer, dataset, scheduler (no IO) - 2. run_benchmark_threaded() — HTTP client + BenchmarkSession (threaded IO) + 1. setup_benchmark() — load tokenizer, dataset, config (no IO) + 2. run_benchmark_async() — HTTP client + async BenchmarkSession 3. 
finalize_benchmark() — accuracy scoring, results JSON """ from __future__ import annotations +import asyncio import json import logging +import shutil import signal import tempfile import uuid @@ -34,15 +36,28 @@ from typing import Any from urllib.parse import urljoin +import msgspec.json +from huggingface_hub import model_info from tqdm import tqdm -from transformers import AutoTokenizer from transformers.utils import logging as transformers_logging +from inference_endpoint.async_utils.loop_manager import LoopManager +from inference_endpoint.async_utils.services.launcher import ( + ServiceConfig, + ServiceLauncher, +) +from inference_endpoint.async_utils.services.metrics_aggregator.kv_store import ( + BasicKVStoreReader, +) +from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext +from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqEventRecordPublisher from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import ( APIType, BenchmarkConfig, DatasetType, + LoadPattern, + LoadPatternType, StreamingMode, TestMode, TestType, @@ -60,13 +75,13 @@ InputValidationError, SetupError, ) -from inference_endpoint.load_generator import ( +from inference_endpoint.load_generator.session import ( BenchmarkSession, - SampleEvent, - SampleEventHandler, - WithoutReplacementSampleOrder, + PhaseConfig, + PhaseType, + SessionResult, ) -from inference_endpoint.load_generator.scheduler import Scheduler +from inference_endpoint.metrics.report import Report transformers_logging.set_verbosity_error() @@ -91,6 +106,7 @@ def __init__(self, collect_responses: bool = False, pbar: tqdm | None = None): self.pbar = pbar def on_complete_hook(self, result: QueryResult) -> None: + """Handle query completion (called once per query via QueryResult).""" self.count += 1 if result.error: self.errors.append(f"Sample {result.id}: {result.error}") @@ -102,6 +118,16 @@ def on_complete_hook(self, result: QueryResult) 
-> None: self.pbar.update(1) +@dataclass +class BenchmarkResult: + """Output of run_benchmark_async — all data needed for finalization.""" + + session: SessionResult + collector: ResponseCollector + report: Report | None + tmpfs_dir: Path + + @dataclass class AccuracyConfiguration: scorer: type[Scorer] @@ -123,10 +149,9 @@ class BenchmarkContext: config: BenchmarkConfig test_mode: TestMode report_dir: Path - tokenizer: AutoTokenizer | None + tokenizer_name: str | None dataloader: Dataset rt_settings: RuntimeSettings - scheduler: Scheduler total_samples: int accuracy_datasets: list[Dataset] = field(default_factory=list) eval_configs: list[AccuracyConfiguration] = field(default_factory=list) @@ -145,17 +170,37 @@ def enable_streaming(self) -> bool: return self.config.model_params.streaming == StreamingMode.ON -def _load_tokenizer(model_name: str) -> AutoTokenizer | None: - """Load HuggingFace tokenizer, warn on failure.""" +def _check_tokenizer_exists(model_name: str) -> bool: + """Check if a HuggingFace tokenizer exists for the model (API only, no download). + + Returns True if the model repo exists and has tokenizer files, False otherwise. + The actual tokenizer is loaded later by the MetricsAggregator subprocess and + by Harmony transforms (each loads their own instance as needed). 
+ """ try: - logger.info(f"Loading tokenizer for model: {model_name}") - tokenizer = AutoTokenizer.from_pretrained(model_name) - logger.info("Tokenizer loaded successfully") - return tokenizer + info = model_info(model_name) + # Check for tokenizer files in the repo + siblings = {s.rfilename for s in (info.siblings or [])} + has_tokenizer = ( + "tokenizer_config.json" in siblings or "tokenizer.json" in siblings + ) + if has_tokenizer: + logger.info(f"Tokenizer available for model: {model_name}") + else: + logger.warning(f"Model {model_name} found but has no tokenizer files") + return has_tokenizer + except ImportError: + # huggingface_hub not installed — fall back to assuming it works + logger.info( + f"huggingface_hub not installed, assuming tokenizer exists for {model_name}" + ) + return True except Exception as e: - logger.warning(f"Failed to load tokenizer for {model_name}: {e}") - logger.warning("Continuing without tokenizer (report metrics may be limited)") - return None + logger.warning(f"Could not verify tokenizer for {model_name}: {e}") + logger.warning( + "Continuing without tokenizer (ISL/OSL/TPOT metrics will be unavailable)" + ) + return False def _load_datasets( @@ -228,22 +273,6 @@ def _load_datasets( return dataloader, accuracy_datasets, eval_configs -def _create_scheduler( - config: BenchmarkConfig, rt_settings: RuntimeSettings -) -> Scheduler: - """Create scheduler using __init_subclass__ registry.""" - load_pattern_type = config.settings.load_pattern.type - try: - scheduler_class = Scheduler.get_implementation(load_pattern_type) - scheduler = scheduler_class(rt_settings, WithoutReplacementSampleOrder) - logger.info( - f"Scheduler: {scheduler_class.__name__} (pattern: {load_pattern_type.value})" - ) - return scheduler - except KeyError as e: - raise SetupError(str(e)) from e - - def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkContext: """Load tokenizer, dataset, create scheduler, setup report dir.""" # CPU affinity @@ 
-260,8 +289,9 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo report_dir.mkdir(parents=True, exist_ok=True) config.to_yaml_file(report_dir / "config.yaml") - # Tokenizer (model name validated by BenchmarkConfig._resolve_and_validate) - tokenizer = _load_tokenizer(config.model_params.name) + # Tokenizer check (light API call, no download) + model_name = config.model_params.name + tokenizer_name = model_name if _check_tokenizer_exists(model_name) else None # Streaming logger.info( @@ -288,16 +318,13 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo f"Min Duration: {rt_settings.min_duration_ms / 1000:.1f}s, Expected samples: {total_samples}" ) - scheduler = _create_scheduler(config, rt_settings) - return BenchmarkContext( config=config, test_mode=test_mode, report_dir=report_dir, - tokenizer=tokenizer, + tokenizer_name=tokenizer_name, dataloader=dataloader, rt_settings=rt_settings, - scheduler=scheduler, total_samples=total_samples, accuracy_datasets=accuracy_datasets, eval_configs=eval_configs, @@ -305,95 +332,292 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo ) -def run_benchmark_threaded(ctx: BenchmarkContext) -> tuple[Any, ResponseCollector]: - """Run benchmark session with threaded HTTP client. 
Returns (report, collector).""" +def _build_phases(ctx: BenchmarkContext) -> list[PhaseConfig]: + """Build the phase list from BenchmarkContext.""" + phases: list[PhaseConfig] = [] + + # Performance phase + phases.append( + PhaseConfig( + "performance", ctx.rt_settings, ctx.dataloader, PhaseType.PERFORMANCE + ) + ) + + # Accuracy phases — use eval_cfg.dataset_name as phase name so it matches + # what Scorer._load_sample_index_map() looks up in sample_idx_map.json + for eval_cfg in ctx.eval_configs: + acc_ds = eval_cfg.dataset + acc_settings = RuntimeSettings( + metric_target=ctx.rt_settings.metric_target, + reported_metrics=ctx.rt_settings.reported_metrics, + min_duration_ms=0, + max_duration_ms=None, + n_samples_from_dataset=acc_ds.num_samples(), + n_samples_to_issue=acc_ds.num_samples() * acc_ds.repeats, + min_sample_count=acc_ds.num_samples() * acc_ds.repeats, + rng_sched=ctx.rt_settings.rng_sched, + rng_sample_index=ctx.rt_settings.rng_sample_index, + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + phases.append( + PhaseConfig(eval_cfg.dataset_name, acc_settings, acc_ds, PhaseType.ACCURACY) + ) + + return phases + + +def _setup_kv_reader( + metrics_dir: Path, + streaming: bool, +) -> BasicKVStoreReader: + """Create a KVStoreReader pre-registered with all metric keys.""" + reader = BasicKVStoreReader(metrics_dir) + # Counter keys (from MetricCounterKey enum) + for key in [ + "total_samples_issued", + "total_samples_completed", + "total_samples_failed", + "tracked_samples_issued", + "tracked_samples_completed", + "tracked_duration_ns", + "total_duration_ns", + ]: + reader.register_key(key, "counter") + # Series keys (from MetricSeriesKey enum) + for key in ["isl", "osl", "sample_latency_ns"]: + reader.register_key(key, "series") + reader.register_key("tpot_ns", "series", dtype=float) + if streaming: + for key in ["ttft_ns", "chunk_delta_ns"]: + reader.register_key(key, "series") + return reader + + +async def _run_benchmark_async( + ctx: 
BenchmarkContext, + loop: asyncio.AbstractEventLoop, +) -> BenchmarkResult: + """Run async benchmark session.""" config = ctx.config + session_id = f"cli_benchmark_{uuid.uuid4().hex[:8]}" - # Setup response collector + # Progress bar + response collector pbar = tqdm( desc=f"{config.model_params.name} (Streaming: {ctx.enable_streaming})", total=ctx.total_samples, - smoothing=0, # smoothing=0 shows average instead of EMA + smoothing=0, ) collector = ResponseCollector(collect_responses=ctx.collect_responses, pbar=pbar) - SampleEventHandler.register_hook(SampleEvent.COMPLETE, collector.on_complete_hook) - # Create endpoint client - endpoints = config.endpoint_config.endpoints - logger.info(f"Connecting: {endpoints}") - try: - api_type: APIType = config.endpoint_config.api_type - http_config = config.settings.client.with_updates( - endpoint_urls=[urljoin(e, api_type.default_route()) for e in endpoints], - api_type=api_type, - api_key=config.endpoint_config.api_key, - event_logs_dir=ctx.report_dir, - cpu_affinity=ctx.affinity_plan, + # ZMQ context for event publishing + service launcher + with ManagedZMQContext.scoped(io_threads=2) as zmq_ctx: + # Event publisher + pub_socket_name = f"ev_pub_{session_id}" + publisher = ZmqEventRecordPublisher(pub_socket_name, zmq_ctx, loop=loop) + + # Tmpfs directories for high-frequency writes (metrics mmap + event log) + # These are memory-backed; copied to report_dir on disk during finalization. 
+ shm_base = ( + Path("/dev/shm") + if Path("/dev/shm").exists() + else Path(tempfile.gettempdir()) + ) + tmpfs_dir = shm_base / f"benchmark_{session_id}" + tmpfs_dir.mkdir(parents=True, exist_ok=True) + metrics_dir = tmpfs_dir / "metrics" + metrics_dir.mkdir(parents=True, exist_ok=True) + event_log_dir = tmpfs_dir / "events" + event_log_dir.mkdir(parents=True, exist_ok=True) + + # Launch service subprocesses + launcher = ServiceLauncher(zmq_ctx) + if zmq_ctx.socket_dir is None: + raise RuntimeError("ZMQ socket_dir must be set after publisher bind") + aggregator_args: list[str] = [ + "--socket-dir", + zmq_ctx.socket_dir, + "--socket-name", + pub_socket_name, + "--metrics-dir", + str(metrics_dir), + ] + if ctx.enable_streaming: + aggregator_args.append("--streaming") + if ctx.tokenizer_name is not None: + aggregator_args.extend(["--tokenizer", ctx.tokenizer_name]) + + # EventLoggerService writes events.jsonl to tmpfs (high-frequency writes) + event_logger_args: list[str] = [ + "--log-dir", + str(event_log_dir), + "--socket-dir", + zmq_ctx.socket_dir, + "--socket-name", + pub_socket_name, + "--writers", + "jsonl", + ] + + await launcher.launch( + [ + ServiceConfig( + module="inference_endpoint.async_utils.services.metrics_aggregator", + args=aggregator_args, + ), + ServiceConfig( + module="inference_endpoint.async_utils.services.event_logger", + args=event_logger_args, + ), + ], + timeout=30.0, ) - http_client = HTTPEndpointClient(http_config) - sample_issuer = HttpClientSampleIssuer(http_client) - except Exception as e: - raise SetupError(f"Failed to connect to endpoint: {e}") from e - # Run benchmark - logger.info("Running...") - sess = None - try: - sess = BenchmarkSession.start( - ctx.rt_settings, - ctx.dataloader, - sample_issuer, - ctx.scheduler, - name=f"cli_benchmark_{uuid.uuid4().hex[0:8]}", - report_dir=ctx.report_dir, - accuracy_datasets=ctx.accuracy_datasets, + # Create endpoint client on the shared loop + endpoints = config.endpoint_config.endpoints + 
logger.info(f"Connecting: {endpoints}") + http_client: HTTPEndpointClient | None = None + try: + api_type: APIType = config.endpoint_config.api_type + http_config = config.settings.client.with_updates( + endpoint_urls=[urljoin(e, api_type.default_route()) for e in endpoints], + api_type=api_type, + api_key=config.endpoint_config.api_key, + event_logs_dir=ctx.report_dir, + cpu_affinity=ctx.affinity_plan, + ) + http_client = await HTTPEndpointClient.create(http_config, loop) + issuer = HttpClientSampleIssuer(http_client) + except Exception as e: + pbar.close() + publisher.close() + launcher.kill_all() + raise SetupError(f"Failed to connect to endpoint: {e}") from e + + # Create session + session = BenchmarkSession( + issuer=issuer, + event_publisher=publisher, + loop=loop, + on_sample_complete=collector.on_complete_hook, + session_id=session_id, ) - # Wait for test end with ability to interrupt - def _raise_keyboard_interrupt(*_: object) -> None: - raise KeyboardInterrupt + phases = _build_phases(ctx) + report: Report | None = None - old_handler = signal.signal(signal.SIGINT, _raise_keyboard_interrupt) + loop.add_signal_handler(signal.SIGINT, session.stop) try: - sess.wait_for_test_end() + result = await session.run(phases) + except Exception as e: + raise ExecutionError(f"Benchmark execution failed: {e}") from e finally: - # Always restore original handler - signal.signal(signal.SIGINT, old_handler) - - # Prefer authoritative metrics from the session report - report = getattr(sess, "report", None) - if report is None: - raise ExecutionError("Session report missing — cannot produce results") - return report, collector + loop.remove_signal_handler(signal.SIGINT) + logger.info("Cleaning up...") + try: + if http_client: + await http_client.shutdown_async() + except Exception as e: + logger.warning(f"Client cleanup error: {e}") + publisher.close() + await asyncio.to_thread(launcher.wait_for_exit, 10.0) + + # Build report AFTER aggregator has exited — ensures all metrics 
+ # (TTFT, TPOT, OSL, latency) are fully written to KVStore. + try: + kv_reader = _setup_kv_reader(metrics_dir, ctx.enable_streaming) + report = Report.from_kv_reader(kv_reader) + kv_reader.close() + except Exception as e: + logger.warning(f"Failed to build report from metrics: {e}") - except KeyboardInterrupt: - logger.warning("Benchmark interrupted by user") - raise - except ExecutionError: - # Re-raise our own exceptions - raise - except Exception as e: - raise ExecutionError(f"Benchmark execution failed: {e}") from e - finally: - # Cleanup - always execute - logger.info("Cleaning up...") - try: - if sess is not None: - sess.stop() pbar.close() - sample_issuer.shutdown() - http_client.shutdown() - except Exception as e: - logger.debug(f"Cleanup error: {e}") + + return BenchmarkResult( + session=result, collector=collector, report=report, tmpfs_dir=tmpfs_dir + ) -def finalize_benchmark( +def run_benchmark_async(ctx: BenchmarkContext) -> BenchmarkResult: + """Run async benchmark. Sync entry point — drives the event loop.""" + loop = LoopManager().default_loop + return loop.run_until_complete(_run_benchmark_async(ctx, loop)) + + +def _write_scoring_artifacts( ctx: BenchmarkContext, - report: Any, - collector: ResponseCollector, + result: SessionResult, + tmpfs_dir: Path, ) -> None: + """Write sample_idx_map.json and copy events.jsonl for Scorer consumption. + + events.jsonl is written by EventLoggerService to tmpfs during the benchmark. + We copy it to report_dir (typically on disk) during finalization. 
+ """ + + # sample_idx_map.json — {dataset_name: {uuid: sample_index}} + sample_idx_map: dict[str, dict[str, int]] = {} + for phase_result in result.phase_results: + sample_idx_map[phase_result.name] = phase_result.uuid_to_index + + map_path = ctx.report_dir / "sample_idx_map.json" + with map_path.open("wb") as f: + f.write(msgspec.json.format(msgspec.json.encode(sample_idx_map), indent=2)) + logger.debug(f"Wrote {map_path}") + + # Copy events.jsonl from tmpfs to report_dir + _salvage_tmpfs(ctx.report_dir, tmpfs_dir) + + # Clean up tmpfs + shutil.rmtree(tmpfs_dir, ignore_errors=True) + + +def _salvage_tmpfs(report_dir: Path, tmpfs_dir: Path) -> None: + """Copy all salvageable artifacts from tmpfs to report_dir. + + Called during normal finalization and on interrupt/crash to preserve logs. + Safe to call multiple times (skips if already copied or tmpfs is gone). + """ + if not tmpfs_dir.exists(): + return + + # events.jsonl (from EventLoggerService) + src_events = tmpfs_dir / "events" / "events.jsonl" + if src_events.exists(): + dst_events = report_dir / "events.jsonl" + shutil.copy2(src_events, dst_events) + logger.debug(f"Copied {src_events} -> {dst_events}") + + # metrics mmap files (from MetricsAggregator KVStore) + src_metrics = tmpfs_dir / "metrics" + if src_metrics.exists(): + dst_metrics = report_dir / "metrics" + dst_metrics.mkdir(parents=True, exist_ok=True) + for f in src_metrics.iterdir(): + if f.is_file(): + shutil.copy2(f, dst_metrics / f.name) + logger.debug(f"Copied metrics from {src_metrics} -> {dst_metrics}") + + +def finalize_benchmark(ctx: BenchmarkContext, bench: BenchmarkResult) -> None: """Score accuracy, aggregate results, write JSON.""" config = ctx.config + result = bench.session + collector = bench.collector + report = bench.report + + # Display report if available (from MetricsAggregator KVStore) + if report is not None: + report.display(fn=lambda s: logger.info(s), summary_only=True) + report.to_json(save_to=ctx.report_dir / 
"result_summary.json") + + # Write human-readable report.txt + report_txt = ctx.report_dir / "report.txt" + with report_txt.open("w") as f: + report.display(fn=lambda s: print(s, file=f)) + logger.info(f"Report written to {report_txt}") + + # Write scoring artifacts + copy event log from tmpfs to disk + _write_scoring_artifacts(ctx, result, bench.tmpfs_dir) # Accuracy scoring accuracy_scores: dict[str, Any] = {} @@ -417,15 +641,27 @@ def finalize_benchmark( } logger.info(f"Score for {eval_cfg.dataset_name}: {score} ({n_repeats} repeats)") - # Report metrics - elapsed = report.duration_ns / 1e9 if report.duration_ns is not None else 0.0 - total_issued = report.n_samples_issued - success = total_issued - report.n_samples_failed - qps = report.qps() or 0.0 - - logger.info(f"Completed in {elapsed:.1f}s") - logger.info(f"Results: {success}/{total_issued} successful") - logger.info(f"Estimated QPS: {qps:.1f}") + # Report metrics: prefer Report from KVStore, fall back to SessionResult + if report is not None and report.duration_ns is not None: + perf_elapsed = report.duration_ns / 1e9 + total_issued = report.n_samples_issued + n_errors = report.n_samples_failed + qps = report.qps() or 0.0 + else: + perf = result.perf_results[0] if result.perf_results else None + if perf: + perf_elapsed = (perf.end_time_ns - perf.start_time_ns) / 1e9 + total_issued = perf.issued_count + else: + perf_elapsed = (result.end_time_ns - result.start_time_ns) / 1e9 + total_issued = 0 + n_errors = len(collector.errors) + qps = total_issued / perf_elapsed if perf_elapsed > 0 else 0.0 + + logger.info(f"Completed in {perf_elapsed:.1f}s") + logger.info(f"Results: {max(0, total_issued - n_errors)}/{total_issued} successful") + if qps > 0: + logger.info(f"Estimated QPS: {qps:.1f}") if collector.errors: logger.warning(f"Errors: {len(collector.errors)}") @@ -444,9 +680,9 @@ def finalize_benchmark( }, "results": { "total": total_issued, - "successful": success, - "failed": report.n_samples_failed, - 
"elapsed_time": elapsed, + "successful": max(0, total_issued - n_errors), + "failed": n_errors, + "elapsed_time": perf_elapsed, "qps": qps, }, } @@ -473,5 +709,15 @@ def run_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> None: config.model_dump_json(indent=2, exclude_none=True), ) ctx = setup_benchmark(config, test_mode) - report, collector = run_benchmark_threaded(ctx) - finalize_benchmark(ctx, report, collector) + bench: BenchmarkResult | None = None + try: + bench = run_benchmark_async(ctx) + finalize_benchmark(ctx, bench) + except KeyboardInterrupt: + logger.warning("Benchmark interrupted by user") + finally: + if bench and bench.tmpfs_dir.exists(): + # Salvage logs from tmpfs before cleanup (no-op if finalize already copied) + _salvage_tmpfs(ctx.report_dir, bench.tmpfs_dir) + shutil.rmtree(bench.tmpfs_dir, ignore_errors=True) + logger.info(f"Partial results saved to {ctx.report_dir}") diff --git a/src/inference_endpoint/config/ruleset_registry.py b/src/inference_endpoint/config/ruleset_registry.py index 59cddc30..75ce5139 100644 --- a/src/inference_endpoint/config/ruleset_registry.py +++ b/src/inference_endpoint/config/ruleset_registry.py @@ -28,6 +28,8 @@ from typing import TYPE_CHECKING +from .rulesets.mlcommons.rules import CURRENT as mlcommons_current + if TYPE_CHECKING: from .ruleset_base import BenchmarkSuiteRuleset @@ -77,18 +79,10 @@ def list_rulesets() -> list[str]: # Auto-register MLCommons rulesets def _auto_register_mlcommons(): """Auto-register MLCommons rulesets.""" - try: - from .rulesets.mlcommons.rules import CURRENT as mlcommons_current - - # Register with version-specific name - register_ruleset( - f"mlperf-inference-{mlcommons_current.version}", mlcommons_current - ) - # Also register as "mlcommons-current" for convenience - register_ruleset("mlcommons-current", mlcommons_current) - except ImportError: - # MLCommons rulesets not available - pass + # Register with version-specific name + 
register_ruleset(f"mlperf-inference-{mlcommons_current.version}", mlcommons_current) + # Also register as "mlcommons-current" for convenience + register_ruleset("mlcommons-current", mlcommons_current) # Auto-register on import diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index 01f7ce2c..bb7540e8 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -460,7 +460,10 @@ class EndpointConfig(BaseModel): endpoints: Annotated[ list[str], cyclopts.Parameter(alias="--endpoints", help="Endpoint URL(s)", negative=""), - ] = Field(min_length=1) + ] = Field( + min_length=1, + description="Endpoint URL(s). Must include scheme, e.g. 'http://host:port'.", + ) api_key: Annotated[ str | None, cyclopts.Parameter(alias="--api-key", help="API key") ] = None @@ -469,6 +472,16 @@ class EndpointConfig(BaseModel): cyclopts.Parameter(alias="--api-type", help="API type: openai or sglang"), ] = APIType.OPENAI + @field_validator("endpoints", mode="after") + @classmethod + def _validate_endpoint_scheme(cls, v: list[str]) -> list[str]: + for url in v: + if not url.startswith(("http://", "https://")): + raise ValueError( + f"Endpoint URL must include scheme (http:// or https://), got: {url!r}" + ) + return v + class BenchmarkConfig(WithUpdatesMixin, BaseModel): """Benchmark configuration — single source of truth for YAML and CLI. 
@@ -721,7 +734,7 @@ def create_default_config(test_type: TestType) -> BenchmarkConfig: _common = { "model_params": ModelParams(name=""), "datasets": [Dataset(path="")], - "endpoint_config": EndpointConfig(endpoints=[""]), + "endpoint_config": EndpointConfig(endpoints=["http://localhost:8000"]), } if test_type == TestType.OFFLINE: return OfflineBenchmarkConfig(**_common) diff --git a/src/inference_endpoint/config/utils.py b/src/inference_endpoint/config/utils.py index 9b50717d..fb091a7d 100644 --- a/src/inference_endpoint/config/utils.py +++ b/src/inference_endpoint/config/utils.py @@ -124,10 +124,8 @@ def parse_dataset_string(s: str) -> dict[str, object]: # Validate parser remap targets (CLI only — YAML validated in factory) if "parser" in result and isinstance(result["parser"], dict): - # Lazy import to avoid circular dep: schema_utils → dataset_manager → schema - from inference_endpoint.dataset_manager.transforms import ( - MakeAdapterCompatible, - ) + # Lazy import: circular dependency (config.schema → config.utils → dataset_manager → config.schema) + from inference_endpoint.dataset_manager.transforms import MakeAdapterCompatible valid = set(MakeAdapterCompatible().remap.values()) invalid = set(result["parser"].keys()) - valid diff --git a/src/inference_endpoint/core/types.py b/src/inference_endpoint/core/types.py index aa862b66..accd2ca8 100644 --- a/src/inference_endpoint/core/types.py +++ b/src/inference_endpoint/core/types.py @@ -348,19 +348,13 @@ class StreamChunk( display and accurate Time-To-First-Token (TTFT) measurements. Multiple StreamChunks with the same id collectively form the complete response. - The is_complete flag indicates the final chunk in the sequence. + The final QueryResult (sent by the worker after all chunks) signals completion. Attributes: id: Query identifier (matches the originating Query.id). response_chunk: Partial response text for this chunk (delta, not cumulative). 
- is_complete: True if this is the final chunk, False for intermediate chunks. metadata: Additional metadata for this chunk (timing, token info, etc.). - Example: - Streaming "Hello World" might produce: - >>> StreamChunk(id="q1", response_chunk="Hello", is_complete=False) - >>> StreamChunk(id="q1", response_chunk=" World", is_complete=True) - Note: gc=False: Safe because metadata contains only scalar key-value pairs. Do NOT store cyclic references in metadata field. @@ -368,13 +362,12 @@ class StreamChunk( omit_defaults=True: Fields with static defaults (ie. those NOT using default_factory) are omitted if value equals default. - array_like=True: Encodes as array instead of object (e.g. ["id", "chunk", false, {}] + array_like=True: Encodes as array instead of object (e.g. ["id", "chunk", {}] instead of {"id": ..., "response_chunk": ..., ...}). Reduces payload size. """ id: str = "" response_chunk: str = "" - is_complete: bool = False metadata: dict[str, Any] = msgspec.field(default_factory=dict) diff --git a/src/inference_endpoint/dataset_manager/factory.py b/src/inference_endpoint/dataset_manager/factory.py index fb6b5325..b0d6f94f 100644 --- a/src/inference_endpoint/dataset_manager/factory.py +++ b/src/inference_endpoint/dataset_manager/factory.py @@ -19,6 +19,7 @@ """ import logging +from pathlib import Path from inference_endpoint.config.schema import Dataset as DatasetConfig from inference_endpoint.dataset_manager.dataset import Dataset, DatasetFormat @@ -102,8 +103,6 @@ def create_loader(config: DatasetConfig, num_repeats: int = 1, **kwargs) -> Data transforms.append(MakeAdapterCompatible()) assert dataset_path is not None - from pathlib import Path - return Dataset.load_from_file( Path(dataset_path), transforms=transforms, diff --git a/src/inference_endpoint/dataset_manager/predefined/livecodebench/__init__.py b/src/inference_endpoint/dataset_manager/predefined/livecodebench/__init__.py index be2a12af..1f1dd23c 100644 --- 
a/src/inference_endpoint/dataset_manager/predefined/livecodebench/__init__.py +++ b/src/inference_endpoint/dataset_manager/predefined/livecodebench/__init__.py @@ -58,7 +58,7 @@ def _ensure_venv(cls, venv_path: Path) -> Path: """ if not venv_path.exists(): logger.info(f"Creating virtual environment at {venv_path}") - venv.create(venv_path, with_pip=True, clear=True) + venv.create(venv_path, with_pip=True, clear=True, symlinks=False) # Determine Python executable path based on platform if sys.platform == "win32": diff --git a/src/inference_endpoint/dataset_manager/transforms.py b/src/inference_endpoint/dataset_manager/transforms.py index a2e2e3be..79133796 100644 --- a/src/inference_endpoint/dataset_manager/transforms.py +++ b/src/inference_endpoint/dataset_manager/transforms.py @@ -17,6 +17,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from importlib import import_module from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -24,6 +25,7 @@ import pandas as pd +from ..endpoint_client.config import ADAPTER_MAP from ..openai.harmony import Harmonizer @@ -405,10 +407,6 @@ def get_transforms_for_api_type( Returns: A list of transforms required for the given API type """ - from importlib import import_module - - from inference_endpoint.endpoint_client.config import ADAPTER_MAP - adapter_path = ADAPTER_MAP.get(api_type) if not adapter_path: raise ValueError(f"Invalid or unsupported API type: {api_type}") diff --git a/src/inference_endpoint/endpoint_client/http_client.py b/src/inference_endpoint/endpoint_client/http_client.py index 4d158d75..e273f581 100644 --- a/src/inference_endpoint/endpoint_client/http_client.py +++ b/src/inference_endpoint/endpoint_client/http_client.py @@ -68,7 +68,10 @@ def __init__( self.loop = loop assert self.loop is not None - # Initialize on event loop + # Initialize on event loop. + # NOTE: This uses run_coroutine_threadsafe().result() which DEADLOCKS + # if called from the same event loop thread. 
For shared-loop usage, + # use the async factory: await HTTPEndpointClient.create(config, loop) asyncio.run_coroutine_threadsafe(self._initialize(), self.loop).result() logger.info( @@ -79,6 +82,36 @@ def __init__( f"transport={self.config.transport.type if self.config.transport else 'none'}" ) + @classmethod + async def create( + cls, + config: HTTPClientConfig, + loop: asyncio.AbstractEventLoop, + ) -> "HTTPEndpointClient": + """Async factory for shared-loop usage. + + Use this instead of __init__ when the caller is already running on + the target event loop (e.g., inside run_benchmark_async). The regular + constructor uses run_coroutine_threadsafe().result() which deadlocks + when called from the same loop. + """ + self = cls.__new__(cls) + self.client_id = uuid.uuid4().hex[:8] + self.config = config + self._worker_cycle = cycle(range(config.num_workers)) + self._owns_loop = False + self._loop_name = None + self.loop = loop + await self._initialize() + logger.info( + f"EndpointClient initialized with num_workers={config.num_workers}, " + f"endpoints={config.endpoint_urls}, " + f"adapter={config.adapter.__name__ if config.adapter else 'none'}, " + f"accumulator={config.accumulator.__name__ if config.accumulator else 'none'}, " + f"transport={config.transport.type if config.transport else 'none'}" + ) + return self + async def _initialize(self) -> None: """Initialize worker manager and transports.""" self._shutdown: bool = False @@ -113,11 +146,22 @@ def drain(self) -> list[QueryResult | StreamChunk]: return list(iter(self.poll, None)) def shutdown(self) -> None: - """Gracefully shutdown client. Synchronous — blocks the caller until complete.""" - if self._shutdown: # Already shutdown, no-op + """Gracefully shutdown client. Synchronous — blocks the caller until complete. + + NOTE: This uses run_coroutine_threadsafe().result() which DEADLOCKS + if called from the same event loop thread. 
For shared-loop usage, + use: await client.shutdown_async() + """ + if self._shutdown: return asyncio.run_coroutine_threadsafe(self._shutdown_async(), self.loop).result() + async def shutdown_async(self) -> None: + """Async shutdown for shared-loop usage. Must be called from the event loop.""" + if self._shutdown: + return + await self._shutdown_async() + async def _shutdown_async(self) -> None: """Async shutdown internals - must be called on the event loop.""" self._shutdown = True diff --git a/src/inference_endpoint/endpoint_client/http_sample_issuer.py b/src/inference_endpoint/endpoint_client/http_sample_issuer.py index c30379ee..38f38923 100644 --- a/src/inference_endpoint/endpoint_client/http_sample_issuer.py +++ b/src/inference_endpoint/endpoint_client/http_sample_issuer.py @@ -13,90 +13,42 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""LoadGenerator integration for HTTPEndpointClient.""" +"""SampleIssuer implementation wrapping HTTPEndpointClient. -import asyncio -import logging +Thin adapter: delegates issue/recv/shutdown to the underlying HTTP client. +The BenchmarkSession owns the response receive loop — this class does NOT +run its own _handle_responses coroutine. +""" from inference_endpoint.core.types import Query, QueryResult, StreamChunk from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient -from inference_endpoint.load_generator import SampleIssuer -from inference_endpoint.load_generator.sample import Sample, SampleEventHandler -from inference_endpoint.profiling import profile -logger = logging.getLogger(__name__) +class HttpClientSampleIssuer: + """SampleIssuer wrapping an HTTPEndpointClient. -class HttpClientSampleIssuer(SampleIssuer): - """ - SampleIssuer interface for HTTPEndpointClient. - Routes completed responses to SampleEventHandler. + Satisfies the SampleIssuer protocol from load_generator.session. 
Usage: - # Create HTTP client and sample issuer - auto-initializes - client = HTTPEndpointClient(config) + client = await HTTPEndpointClient.create(config, loop) issuer = HttpClientSampleIssuer(client) - # Issue samples - issuer.issue(sample) - - # shutdown() is optional - only needed for early exit + issuer.issue(query) # sync ZMQ push + response = await issuer.recv() # async ZMQ recv + issuer.shutdown() # no-op (client shutdown called separately) """ - def __init__( - self, - http_client: HTTPEndpointClient, - ): - super().__init__() + def __init__(self, http_client: HTTPEndpointClient): self.http_client = http_client - # Start response handler task to route completed responses back to SampleEventHandler - self._response_task = asyncio.run_coroutine_threadsafe( - self._handle_responses(), self.http_client.loop - ) - - @profile - async def _handle_responses(self): - """Route completed responses to SampleEventHandler.""" - while True: - try: - # TODO(vir): consider using recv() + drain - match response := await self.http_client.recv(): - case StreamChunk(is_complete=False): - # NOTE(vir): is_complete=True should not be received, QueryResult is expected instead - SampleEventHandler.stream_chunk_complete(response) - - case QueryResult(error=err): - SampleEventHandler.query_result_complete(response) - if err is not None: - logger.error(f"Error in request {response.id}: {err}") - - case None: - # Transport closed or shutdown - break - - case _: - raise ValueError(f"Unexpected response type: {type(response)}") - - except asyncio.CancelledError: - # Handle shutdown signal - break - except Exception as e: - logger.error(f"Error in response handler: {e}", exc_info=True) - continue + def issue(self, query: Query) -> None: + """Issue query to HTTP endpoint. 
Non-blocking (ZMQ push).""" + self.http_client.issue(query) - @profile - def issue(self, sample: Sample): - """Issue sample to HTTP endpoint.""" - # NOTE(vir): - # If using extra headers (e.g., Authorization), pre-cache them in - # worker.py request-template via HttpRequestTemplate.cache_headers() - # to avoid per-request encoding overhead at runtime. - self.http_client.issue(Query(id=sample.uuid, data=sample.data)) + async def recv(self) -> QueryResult | StreamChunk | None: + """Wait for next response. Returns None when transport is closed.""" + return await self.http_client.recv() - def shutdown(self): - """ - Gracefully shutdown sample issuer. - Will cancel the response-handler task. - """ - self._response_task.cancel() + def shutdown(self) -> None: + """No-op. HTTPEndpointClient.shutdown() is called separately by the caller.""" + pass diff --git a/src/inference_endpoint/evaluation/scoring.py b/src/inference_endpoint/evaluation/scoring.py index 5766e713..f5145a75 100644 --- a/src/inference_endpoint/evaluation/scoring.py +++ b/src/inference_endpoint/evaluation/scoring.py @@ -36,9 +36,16 @@ except ImportError: websocket = None +try: + import evaluate as _evaluate + import nltk as _nltk +except ImportError: + _evaluate = None + _nltk = None + +from ..core.record import EventRecord, EventType, SampleEventType from ..dataset_manager.dataset import Dataset from ..dataset_manager.predefined.shopify_product_catalogue import ProductMetadata -from ..load_generator.sample import SampleEvent from .extractor import Extractor, PythonCodeExtractor @@ -100,10 +107,6 @@ def __init__( self.dataset = dataset self.report_dir = Path(report_dir) self.extractor = extractor - # If the dataset was transformed with a preset, we still treat it as the original - # dataset name for the purposes of scoring - if "::" in dataset_name: - dataset_name = dataset_name.split("::")[0] self.dataset_name = dataset_name self.ground_truth_column = ( @@ -123,22 +126,30 @@ def 
_load_sample_index_map(self): return d[self.dataset_name] # Implicitly raises KeyError def get_outputs(self): - # TODO: Currently, the outputs are only saved in the events.jsonl file, which is quite - # large, and only saved optionally. Later, we should move to saving the outputs in a - # separate file for easier compute. + """Read COMPLETE events from events.jsonl and extract response text. + + The EventLoggerService writes EventRecord objects serialized via msgspec. + We decode them using the EventRecord decoder and extract the response + text from TextModelOutput data. + """ events_log_path = self.report_dir / "events.jsonl" if not events_log_path.exists(): raise FileNotFoundError(f"Events log file not found at {events_log_path}") - outputs = [] + decoder = msgspec.json.Decoder(type=EventRecord, dec_hook=EventType.decode_hook) + outputs: list[dict[str, str]] = [] with events_log_path.open("r") as f: for line in f: - event = msgspec.json.decode(line.strip()) - if event["event_type"] == SampleEvent.COMPLETE.value: - outputs.append(event) - df = pd.DataFrame(outputs, columns=["sample_uuid", "value"]) - df.rename(columns={"value": "output"}, inplace=True) - return df + stripped = line.strip() + if not stripped: + continue + record = decoder.decode(stripped) + if record.event_type == SampleEventType.COMPLETE: + output_text = str(record.data) if record.data is not None else "" + outputs.append( + {"sample_uuid": record.sample_uuid, "output": output_text} + ) + return pd.DataFrame(outputs) def match_sample_index(self, row: pd.Series) -> pd.Series: # Pandas Apply function to create a new 'sample_index' column @@ -226,27 +237,13 @@ class RougeScorer(Scorer, scorer_id="rouge"): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - try: - import importlib.util as _importlib_util - - if ( - _importlib_util.find_spec("evaluate") is None - or _importlib_util.find_spec("nltk") is None - or _importlib_util.find_spec("rouge_score") is None - ): - raise 
ImportError - - import evaluate - import nltk - - self.metric = evaluate.load("rouge") - self.nltk = nltk - - except ImportError: + if _evaluate is None or _nltk is None: raise ImportError( "nltk, evaluate, and rouge_score are required for ROUGE scoring. " "Install with: pip install nltk evaluate rouge_score" - ) from None + ) + self.metric = _evaluate.load("rouge") + self.nltk = _nltk def postprocess_text(self, texts): texts = [text.strip() for text in texts] diff --git a/src/inference_endpoint/load_generator/__init__.py b/src/inference_endpoint/load_generator/__init__.py index c032b4c6..6f44d958 100644 --- a/src/inference_endpoint/load_generator/__init__.py +++ b/src/inference_endpoint/load_generator/__init__.py @@ -13,40 +13,51 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Load Generator for the MLPerf Inference Endpoint Benchmarking System. +"""Async load generator for the MLPerf Inference Endpoint Benchmarking System. -This module handles load pattern generation and query lifecycle management. -Status: To be implemented by the development team. +See docs/load_generator/design.md for the full design. 
""" -from .load_generator import LoadGenerator, SampleIssuer, SchedulerBasedLoadGenerator -from .sample import IssuedSample, Sample, SampleEvent, SampleEventHandler -from .scheduler import ( - ConcurrencyScheduler, - MaxThroughputScheduler, - PoissonDistributionScheduler, +from .delay import make_delay_fn, poisson_delay_fn +from .sample_order import ( SampleOrder, - Scheduler, WithoutReplacementSampleOrder, WithReplacementSampleOrder, + create_sample_order, +) +from .session import ( + BenchmarkSession, + PhaseConfig, + PhaseIssuer, + PhaseResult, + PhaseType, + SessionResult, +) +from .strategy import ( + BurstStrategy, + ConcurrencyStrategy, + LoadStrategy, + TimedIssueStrategy, + create_load_strategy, ) -from .session import BenchmarkSession __all__ = [ - "SampleEvent", - "Sample", - "SampleEventHandler", - "IssuedSample", - "Scheduler", - "ConcurrencyScheduler", - "MaxThroughputScheduler", - "PoissonDistributionScheduler", + # New async API + "BenchmarkSession", + "PhaseConfig", + "PhaseType", + "PhaseResult", + "SessionResult", + "PhaseIssuer", + "LoadStrategy", + "TimedIssueStrategy", + "BurstStrategy", + "ConcurrencyStrategy", + "create_load_strategy", "SampleOrder", - "WithReplacementSampleOrder", "WithoutReplacementSampleOrder", - "LoadGenerator", - "SampleIssuer", - "SchedulerBasedLoadGenerator", - "BenchmarkSession", + "WithReplacementSampleOrder", + "create_sample_order", + "make_delay_fn", + "poisson_delay_fn", ] diff --git a/src/inference_endpoint/load_generator/delay.py b/src/inference_endpoint/load_generator/delay.py new file mode 100644 index 00000000..bfa09a25 --- /dev/null +++ b/src/inference_endpoint/load_generator/delay.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Inter-arrival delay functions for timed load strategies. + +Each function returns a callable that produces delay values in nanoseconds. +Used by TimedIssueStrategy for Poisson and other time-based load patterns. +""" + +from __future__ import annotations + +import random +from collections.abc import Callable + +from ..config.schema import LoadPattern, LoadPatternType + + +def poisson_delay_fn(target_qps: float, rng: random.Random) -> Callable[[], int]: + """Create a Poisson-distributed delay function. + + Returns inter-arrival delays following an exponential distribution + (Poisson process). Models realistic client behavior where requests + arrive independently at a target rate. + + How it works: + + ``expovariate(lambd)`` draws from the exponential distribution with rate + ``lambd``. Critically, the return value is in units of ``1 / lambd`` — + NOT in units of ``lambd``. So if ``lambd`` is expressed in + events-per-nanosecond, the return value is in nanoseconds. + + Step by step for target_qps = 50,000: + 1. lambd = 50,000 / 1e9 = 5e-5 events per nanosecond + 2. expovariate(5e-5) returns values with mean = 1 / 5e-5 = 20,000 ns + 3. So the average inter-arrival delay is 20,000 ns = 20 us + 4. This matches 50,000 QPS: 1 second / 20 us = 50,000 queries + + The return value is cast to int (nanoseconds). The ``max(1, ...)`` guard + prevents zero-delay at extreme QPS (> 500M), where the mean approaches + 1 ns and the exponential distribution produces sub-1 values ~63% of the + time. 
In practice, no system can issue > 500M QPS, so the guard is + purely defensive. + + Reference: https://docs.python.org/3/library/random.html#random.Random.expovariate + + Args: + target_qps: Target queries per second. + rng: Seeded random number generator for reproducibility. + + Returns: + Callable returning delay in nanoseconds (int, always >= 1). + """ + if target_qps <= 0: + raise ValueError(f"target_qps must be > 0, got {target_qps}") + lambd = target_qps / 1_000_000_000 # events per nanosecond + return lambda: max(1, int(rng.expovariate(lambd))) + + +def make_delay_fn(load_pattern: LoadPattern, rng: random.Random) -> Callable[[], int]: + """Create a delay function from a LoadPattern config. + + Only used by TimedIssueStrategy. MAX_THROUGHPUT uses BurstStrategy, + CONCURRENCY uses ConcurrencyStrategy — neither needs a delay function. + + Args: + load_pattern: LoadPattern config from schema.py. + rng: Seeded random number generator for reproducibility. + + Returns: + Callable returning delay in nanoseconds. + + Raises: + ValueError: If load pattern type has no delay function. + """ + if load_pattern.type == LoadPatternType.POISSON: + if load_pattern.target_qps is None or load_pattern.target_qps <= 0: + raise ValueError("Poisson load pattern requires target_qps > 0") + return poisson_delay_fn(load_pattern.target_qps, rng) + + raise ValueError( + f"No delay function for load pattern type: {load_pattern.type}. " + f"MAX_THROUGHPUT uses BurstStrategy, CONCURRENCY uses ConcurrencyStrategy." + ) diff --git a/src/inference_endpoint/load_generator/load_generator.py b/src/inference_endpoint/load_generator/load_generator.py deleted file mode 100644 index 3081e95a..00000000 --- a/src/inference_endpoint/load_generator/load_generator.py +++ /dev/null @@ -1,291 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import time -from abc import ABC, abstractmethod -from typing import Any - -from ..dataset_manager.dataset import Dataset -from ..utils import sleep_ns -from .sample import IssuedSample, Sample -from .scheduler import Scheduler - - -class SampleIssuer(ABC): - """Abstract base class for components that send samples to inference endpoints. - - SampleIssuers are responsible for the complete workflow of sending a sample - to a System Under Test (SUT): - 1. Ingest a Sample object from the Load Generator - 2. Build the appropriate request format (HTTP, gRPC, etc.) - 3. Send the request to the endpoint - 4. Handle the response asynchronously (results arrive via callbacks) - - Implementations must handle: - - Request formatting (converting Sample.data to endpoint-specific format) - - Network communication (HTTP, gRPC, WebSocket, etc.) - - Error handling (timeouts, connection errors, etc.) - - Response routing (back to metrics collector via events) - - Lifecycle: - 1. start() - Initialize connections, setup resources - 2. issue(sample) - Send samples (called repeatedly during benchmark) - 3. 
shutdown() - Clean up connections, release resources - - Example implementations: - - HttpClientSampleIssuer: HTTP/REST endpoints (OpenAI-compatible) - - GrpcSampleIssuer: gRPC endpoints (future) - """ - - def start(self): # noqa: B027 - """Initialize resources and establish connections. - - Called once after instantiation to set up any dependency components - like HTTP client pools, authentication, or connection pooling. - - Optional implementation - default does nothing. - - Raises: - SetupError: If initialization fails. - """ - pass - - @abstractmethod - def issue(self, sample: Sample): - """Send a sample to the SUT endpoint. - - This is the core method that sends a single sample/query to the endpoint. - It should be non-blocking and return quickly - actual response handling - happens asynchronously via the event system. - - The implementation must: - 1. Convert Sample.data to the endpoint's request format - 2. Send the request (typically async/non-blocking) - 3. Ensure response triggers appropriate events (COMPLETE, STREAM_CHUNK, etc.) - - Args: - sample: Sample object containing request data and metadata. - - Raises: - ExecutionError: If request cannot be sent. - """ - raise NotImplementedError - - def shutdown(self): # noqa: B027 - """Clean up resources and close connections. - - Called once when the issuer is no longer needed. Should gracefully - shutdown connections, flush pending requests, and release resources. - - Optional implementation - default does nothing. - """ - pass - - -class LoadGenerator(ABC): - """Abstract base class for load generation strategies. - - LoadGenerators control WHEN samples are issued to the SUT. 
They coordinate: - - Sample selection from the dataset (via DataLoader) - - Timing and scheduling (via Scheduler) - - Actual sample issuance (via SampleIssuer) - - Key responsibilities: - - Load sample data from dataset at the right time - - Apply scheduling/timing delays - - Issue samples via the SampleIssuer - - LoadGenerators are iterators - each iteration issues one sample and - returns information about what was issued. - - Attributes: - sample_issuer: Component that sends samples to endpoints. - dataloader: Component that loads sample data from datasets. - """ - - def __init__( - self, - sample_issuer: SampleIssuer, - dataloader: Dataset, - name: str | None = None, - ): - """Initialize load generator with required dependencies. - - Args: - sample_issuer: SampleIssuer to send samples to endpoint. - dataloader: DataLoader to retrieve sample data from dataset. - """ - self.sample_issuer = sample_issuer - self.dataloader = dataloader - self.name = name - self.uuid_to_index_map: dict[str, int] = {} - - @abstractmethod - def __next__(self) -> IssuedSample: - """Issue the next sample according to the load generation strategy. - - This method should: - 1. Determine which sample to issue next - 2. Load the sample data from dataloader - 3. Apply any scheduling delays (blocking) - 4. Issue the sample via sample_issuer - 5. Return the sample and timestamp - - Note: This method MAY block to implement delays/scheduling. - It should only return AFTER the sample has been issued. - - Returns: - IssuedSample object containing the sample, index, and issue timestamp. - - Raises: - StopIteration: When all samples have been issued. - """ - raise NotImplementedError - - def __iter__(self): - """Return self as an iterator.""" - self.uuid_to_index_map = {} - return self - - def load_sample_data(self, sample_index: int, sample_uuid: str) -> Any: - """Load sample data from dataloader. - - Args: - sample_index: Index of sample in dataset. - sample_uuid: UUID of the sample being created. 
- - Returns: - Sample data loaded from dataloader (format depends on dataset). - """ - return self.dataloader.load_sample(sample_index) - - def issue_sample(self, sample: Sample) -> int: - """Issue a sample via the SampleIssuer. - - Records the current timestamp, issues the sample, and returns the timestamp. - - Args: - sample: Sample to issue to the endpoint. - - Returns: - Monotonic nanosecond timestamp when issue was called. - """ - timestamp_ns = time.monotonic_ns() - logging.debug(f"Issuing sample {sample.uuid} at {timestamp_ns}") - self.sample_issuer.issue(sample) - return timestamp_ns - - -class SchedulerBasedLoadGenerator(LoadGenerator): - """LoadGenerator that uses a Scheduler to control sample timing. - - This is the primary LoadGenerator implementation, delegating timing decisions - to a pluggable Scheduler. It handles: - - Sample ordering (via scheduler's sample_order) - - Timing delays (via scheduler's delay_fn) - - Sample loading and issuance - - Timing measurements - - The scheduler determines: - - Which sample to issue next (sample index) - - How long to wait before issuing (delay in nanoseconds) - - This enables different load patterns (Poisson, max throughput, burst, etc.) - without changing the LoadGenerator code. - - Attributes: - scheduler: Scheduler controlling sample timing. - _iterator: Iterator over scheduler (sample_index, delay) pairs. - last_issue_timestamp_ns: Timestamp of last issued sample (for delay calculation). - """ - - def __init__( - self, - sample_issuer: SampleIssuer, - dataloader: Dataset, - scheduler: Scheduler, - ): - """Initialize scheduler-based load generator. - - Args: - sample_issuer: SampleIssuer to send samples to endpoint. - dataloader: DataLoader to retrieve sample data. - scheduler: Scheduler controlling timing and sample order. 
- """ - super().__init__(sample_issuer, dataloader) - - self.scheduler = scheduler - self._iterator = None - self.last_issue_timestamp_ns = 0 - self._start_time_ns: int | None = None - - def __next__(self) -> IssuedSample: - """Issue next sample according to scheduler timing. - - This method: - 1. Gets next (sample_index, delay_ns) from scheduler - 2. Loads sample data from dataloader - 3. Waits for scheduled time (busy-wait for precision) - 4. Issues sample via sample_issuer - 5. Returns IssuedSample with timing info - - The busy-wait ensures precise timing even for high QPS scenarios - where sleep() precision would be insufficient. - - Returns: - IssuedSample containing sample, index, and actual issue timestamp. - - Raises: - StopIteration: When scheduler has no more samples to issue. - """ - # Check wall-clock timeout before advancing the iterator, so we don't - # consume a (sample_index, delay) pair that will never be issued. - max_duration_ms = self.scheduler.runtime_settings.max_duration_ms - if max_duration_ms is not None and self._start_time_ns is not None: - elapsed_ns = time.monotonic_ns() - self._start_time_ns - if elapsed_ns >= max_duration_ms * 1_000_000: - logging.info( - f"max_duration_ms={max_duration_ms}ms reached after " - f"{elapsed_ns / 1e6:.1f}ms, stopping sample issuance" - ) - raise StopIteration - - # Let raised StopIteration be propagated up the stack - # Ignore mypy error complaining that self._iterator maybe None - s_idx, delay_ns = next(self._iterator) # type: ignore[call-overload] - - # Data loading is not timed for Time-to-Token metrics. It is assumed that the - # hypothetical user would have put the data into memory available for a network - # request beforehand. 
- sample = Sample(None) # Create sample object first to generate uuid - sample.data = self.load_sample_data(s_idx, sample.uuid) - - self.uuid_to_index_map[sample.uuid] = s_idx - - scheduled_issue_timestamp_ns = self.last_issue_timestamp_ns + delay_ns - while (now := time.monotonic_ns()) < scheduled_issue_timestamp_ns: - sleep_ns(scheduled_issue_timestamp_ns - now) - self.last_issue_timestamp_ns = self.issue_sample(sample) - return IssuedSample(sample, s_idx, self.last_issue_timestamp_ns) - - def __iter__(self): - if self._iterator is not None: - raise RuntimeError( - "SchedulerBasedLoadGenerator can only be iterated over once" - ) - self._start_time_ns = time.monotonic_ns() - self._iterator = iter(self.scheduler) - return super().__iter__() diff --git a/src/inference_endpoint/load_generator/sample.py b/src/inference_endpoint/load_generator/sample.py deleted file mode 100644 index 28516a8d..00000000 --- a/src/inference_endpoint/load_generator/sample.py +++ /dev/null @@ -1,201 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -import uuid -from collections.abc import Callable -from dataclasses import dataclass -from enum import Enum -from typing import Any - -from ..core.types import QueryResult, StreamChunk - -logger = logging.getLogger(__name__) - - -class SampleEvent(str, Enum): - """Event types for sample lifecycle hooks. - - Used by _SampleEventHandler for hook registration. These are local to the - hook system and separate from core.record.SampleEventType which is used - for the event publishing pipeline. - """ - - COMPLETE = "complete" - FIRST_CHUNK = "first_chunk_received" - NON_FIRST_CHUNK = "non_first_chunk_received" - - -class Sample: - """Represents a sample/query to be sent to an inference endpoint. - - A Sample encapsulates the request data and provides a unique identifier for - tracking through the benchmark lifecycle. It enforces immutability to prevent - accidental modification during benchmarking. - - Immutability rules: - - UUID is immutable once set (on creation) - - Data can be set once from None to a value, then immutable - - This allows delayed data loading while maintaining safety - - Memory optimization: - - Uses __slots__ to reduce memory overhead - - UUID as hex string (32 chars) instead of UUID object - - Attributes: - uuid: Unique hex string identifier for this sample (32 characters). - data: Request payload (dict, typically with prompt/model/params). - Can be None initially and set once. - - Example: - >>> sample = Sample({"prompt": "Hello", "model": "gpt-4"}) - >>> sample.uuid # '8f3d2a1b...' (32 char hex) - >>> sample.data["prompt"] # 'Hello' - """ - - __slots__ = ["uuid", "data"] - - def __init__(self, data: Any): - """Initialize sample with data and generate unique ID. - - Args: - data: Request data to send to endpoint. Can be None if data - will be loaded later, but can only be set once. 
- """ - # 128-bit UUID might be a little overkill for our use case, we can investigate slimming down memory usage - self.uuid = uuid.uuid4().hex - self.data = data - - def __setattr__(self, name: str, value: Any): - if not hasattr(self, name) or (name == "data" and self.data is None): - object.__setattr__(self, name, value) - else: - raise AttributeError(f"Sample is immutable - cannot set attribute: {name}") - - -class _SampleEventHandler: - """Contains handlers for SampleEvents given a sample UUID. This is also to avoid needing other classes - to do their own bookkeeping for Sample objects, which can be discarded once they are issued, as long as - their UUIDs are saved. - - This class is a singleton rather than a class method mainly because it needs to hold some state (i.e. hooks) - - A user can register hooks to any event type, and will be run in the order they were registered. - A valid hook is a callable that takes a single argument, representing the response object (StreamChunk or QueryResult). - - A simple example use-case of a hook is to update a progress bar on-completion of a sample. - - NOTE: Hook lists are not thread-safe. Hooks must be registered before the benchmark - starts (single-threaded setup phase). This is a known limitation; _SampleEventHandler - is being deprecated in favor of the pub-sub EventLoggerService. 
- """ - - __slots__ = ["first_chunk_hooks", "non_first_chunk_hooks", "complete_hooks"] - - SINGLETON = None - _initialized = False - - def __new__(cls): - if cls.SINGLETON is None: - cls.SINGLETON = super().__new__(cls) - return cls.SINGLETON - - def __init__(self): - if _SampleEventHandler._initialized: - return - _SampleEventHandler._initialized = True - - self.first_chunk_hooks = [] - self.non_first_chunk_hooks = [] - self.complete_hooks = [] - - def register_hook( - self, - event_type: SampleEvent, - hook: Callable[[StreamChunk], None] | Callable[[QueryResult], None], - ) -> None: - if event_type == SampleEvent.FIRST_CHUNK: - self.first_chunk_hooks.append(hook) - elif event_type == SampleEvent.NON_FIRST_CHUNK: - self.non_first_chunk_hooks.append(hook) - elif event_type == SampleEvent.COMPLETE: - self.complete_hooks.append(hook) - else: - raise ValueError(f"Invalid event type: {event_type}") - - def clear_hooks(self, ev_type: SampleEvent | None = None) -> None: - if ev_type is None: - self.first_chunk_hooks.clear() - self.non_first_chunk_hooks.clear() - self.complete_hooks.clear() - elif ev_type == SampleEvent.FIRST_CHUNK: - self.first_chunk_hooks.clear() - elif ev_type == SampleEvent.NON_FIRST_CHUNK: - self.non_first_chunk_hooks.clear() - elif ev_type == SampleEvent.COMPLETE: - self.complete_hooks.clear() - - def stream_chunk_complete(self, chunk: StreamChunk) -> None: - """Handle completion of a streaming chunk. - - Called when a chunk arrives from a streaming response. Invokes - registered hooks for first/non-first chunks. - - Args: - chunk: StreamChunk containing response data and metadata. - """ - assert isinstance(chunk, StreamChunk), f"Invalid chunk type: {type(chunk)}" - - if chunk.metadata.get("first_chunk", False): - hooks = self.first_chunk_hooks - else: - hooks = self.non_first_chunk_hooks - - for hook in hooks: - hook(chunk) - - def query_result_complete(self, result: QueryResult) -> None: - """Handle completion of a query (success or failure). 
- - Called when a query finishes (with response or error). Invokes - registered completion hooks. - - Args: - result: QueryResult containing response data or error information. - """ - assert isinstance(result, QueryResult), f"Invalid result type: {type(result)}" - - if result.error is not None: - logger.error(f"Error in request {result.id}: {result.error}") - - for hook in self.complete_hooks: - hook(result) - - -@dataclass -class IssuedSample: - """Contains data about a sample that has been issued to the inference endpoint. - - SampleIssuer is not allowed to know the actual sample index of the data to prevent cheating - and response caching. This class contains metadata about the sample for bookkeeping by the - LoadGenerator and BenchmarkSession. - """ - - sample: Sample - index: int - issue_timestamp_ns: int - - -SampleEventHandler = _SampleEventHandler() diff --git a/src/inference_endpoint/load_generator/sample_order.py b/src/inference_endpoint/load_generator/sample_order.py new file mode 100644 index 00000000..b798df5b --- /dev/null +++ b/src/inference_endpoint/load_generator/sample_order.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sample ordering strategies for benchmark dataset traversal. + +SampleOrder is an infinite iterator yielding dataset indices. 
It never raises +StopIteration — termination is controlled by BenchmarkSession._should_stop(). +""" + +from __future__ import annotations + +import random +from abc import ABC, abstractmethod + +from ..config.runtime_settings import RuntimeSettings + + +class SampleOrder(ABC): + """Abstract base class for sample ordering strategies. + + Yields dataset sample indices indefinitely. Different strategies enable + different testing scenarios (balanced coverage vs random sampling). + + Attributes: + n_samples_in_dataset: Number of unique samples available in dataset. + rng: Random number generator for reproducible randomness. + """ + + def __init__( + self, + n_samples_in_dataset: int, + rng: random.Random = random, # type: ignore[assignment] + ): + if n_samples_in_dataset <= 0: + raise ValueError( + f"n_samples_in_dataset must be > 0, got {n_samples_in_dataset}" + ) + self.n_samples_in_dataset = n_samples_in_dataset + self.rng = rng + + def __iter__(self): + return self + + def __next__(self) -> int: + return self.next_sample_index() + + @abstractmethod + def next_sample_index(self) -> int: + """Get the next sample index to issue. + + Returns: + Sample index (0 to n_samples_in_dataset-1). + """ + raise NotImplementedError + + +class WithoutReplacementSampleOrder(SampleOrder): + """Shuffle dataset, use all samples before repeating. + + Ensures balanced coverage: shuffles all dataset indices, issues them one + by one until exhausted, then reshuffles and repeats (infinite cycle). 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.index_order = list(range(self.n_samples_in_dataset)) + # Force initial shuffle on first call + self._curr_idx = self.n_samples_in_dataset + 1 + + def _reset(self): + self.rng.shuffle(self.index_order) + self._curr_idx = 0 + + def next_sample_index(self) -> int: + if self._curr_idx >= len(self.index_order): + self._reset() + retval = self.index_order[self._curr_idx] + self._curr_idx += 1 + return retval + + +class WithReplacementSampleOrder(SampleOrder): + """Truly random sampling from dataset with replacement. + + Each sample is chosen uniformly at random, independent of previous choices. + """ + + def next_sample_index(self) -> int: + return self.rng.randint(0, self.n_samples_in_dataset - 1) + + +def create_sample_order(settings: RuntimeSettings) -> SampleOrder: + """Create a SampleOrder from RuntimeSettings.""" + return WithoutReplacementSampleOrder( + n_samples_in_dataset=settings.n_samples_from_dataset, + rng=settings.rng_sample_index, + ) diff --git a/src/inference_endpoint/load_generator/scheduler.py b/src/inference_endpoint/load_generator/scheduler.py deleted file mode 100644 index ae691d09..00000000 --- a/src/inference_endpoint/load_generator/scheduler.py +++ /dev/null @@ -1,420 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import random -import threading -from abc import ABC, abstractmethod -from collections.abc import Callable, Iterator - -from ..config.runtime_settings import RuntimeSettings -from ..config.schema import LoadPatternType -from .sample import SampleEvent, SampleEventHandler - - -class SampleOrder(ABC): - """Abstract base class for sample ordering strategies. - - SampleOrder determines which dataset sample to use next when issuing queries. - Different strategies enable different testing scenarios: - - The SampleOrder is an iterator that yields sample indices from the dataset. - It handles wrapping around when total_samples_to_issue > dataset size. - - Attributes: - total_samples_to_issue: Total number of samples to issue during benchmark. - n_samples_in_dataset: Number of unique samples available in dataset. - rng: Random number generator for reproducible randomness. - _issued_samples: Counter of samples issued so far. - """ - - def __init__( - self, total_samples_to_issue: int, n_samples_in_dataset: int, rng=random - ): - """Initialize sample ordering strategy. - - Args: - total_samples_to_issue: The total number of samples to issue. - May be larger than n_samples_in_dataset. - n_samples_in_dataset: The number of unique samples in the dataset. - rng: Random number generator (for reproducibility via seeding). - """ - self.total_samples_to_issue = total_samples_to_issue - self.n_samples_in_dataset = n_samples_in_dataset - self.rng = rng - - self._issued_samples = 0 - - def __iter__(self) -> Iterator[int]: - """Iterate over sample indices to issue. - - Yields sample indices until total_samples_to_issue is reached. - - Yields: - Sample index (0 to n_samples_in_dataset-1). - """ - while self._issued_samples < self.total_samples_to_issue: - yield self.next_sample_index() - self._issued_samples += 1 - - @abstractmethod - def next_sample_index(self) -> int: - """Get the next sample index to issue. - - Returns: - Sample index (0 to n_samples_in_dataset-1). 
- """ - raise NotImplementedError - - -class WithoutReplacementSampleOrder(SampleOrder): - """Sample ordering without replacement - shuffle dataset, use all samples before repeating. - - This strategy ensures balanced coverage of the dataset: - 1. Shuffles all dataset indices randomly - 2. Issues them one by one until exhausted - 3. Reshuffles and repeats if more samples needed - - Use this for: - - Fair benchmarking (all samples used equally) - - Avoiding bias from repeated samples - - Deterministic results with seed control - - Example with 3-sample dataset, 7 samples to issue: - - Shuffle: [2, 0, 1] - - Issue: 2, 0, 1 (first pass) - - Reshuffle: [1, 2, 0] - - Issue: 1, 2, 0, 1 (second pass, partial) - - Attributes: - index_order: Current shuffled order of indices. - _curr_idx: Position in current shuffle (resets after each complete pass). - """ - - def __init__(self, *args, **kwargs): - """Initialize without-replacement sample ordering. - - Args: - *args: Forwarded to SampleOrder.__init__. - **kwargs: Forwarded to SampleOrder.__init__. - """ - super().__init__(*args, **kwargs) - self.index_order = list(range(self.n_samples_in_dataset)) - self._curr_idx = ( - self.n_samples_in_dataset + 1 - ) # Ensure we start at an invalid index to force shuffle - - def _reset(self): - """Shuffle indices and reset position for next pass.""" - self.rng.shuffle(self.index_order) - self._curr_idx = 0 - - def next_sample_index(self) -> int: - """Get next sample index from current shuffle, reshuffling if needed. - - Returns: - Sample index from dataset. - """ - if self._curr_idx >= len(self.index_order): - self._reset() - retval = self.index_order[self._curr_idx] - self._curr_idx += 1 - return retval - - -class WithReplacementSampleOrder(SampleOrder): - """Sample ordering with replacement - truly random sampling from dataset. - - Each sample is chosen uniformly at random from the entire dataset, - independent of previous choices. 
The same sample can (and will) appear - multiple times, even consecutively. - - Use this for: - - Stress testing with realistic randomness - - Simulating unpredictable user behavior - - When dataset coverage balance is not important - - Example with 3-sample dataset, 7 samples to issue: - - Might produce: [1, 1, 0, 2, 1, 0, 0] - - Note repeated samples even without exhausting dataset - """ - - def __init__(self, *args, **kwargs): - """Initialize with-replacement sample ordering. - - Args: - *args: Forwarded to SampleOrder.__init__. - **kwargs: Forwarded to SampleOrder.__init__. - """ - super().__init__(*args, **kwargs) - - def next_sample_index(self) -> int: - """Get random sample index from dataset. - - Returns: - Random sample index (uniform distribution over dataset). - """ - return self.rng.randint(0, self.n_samples_in_dataset - 1) - - -def uniform_delay_fn( - max_delay_ns: int = 0, rng: random.Random | None = None -) -> Callable[[], float]: - """Create a uniform delay function for schedulers. - - Returns a function that generates delays uniformly distributed between - 0 and max_delay_ns. Used for max throughput (max_delay_ns=0) or uniform - load distribution. - - Args: - max_delay_ns: Maximum delay in nanoseconds. If 0, always returns 0 (no delay). - rng: Random number generator for reproducibility. - - Returns: - Function that returns delay in nanoseconds (float). - """ - rng = rng or random.Random() - - def _fn(): - if max_delay_ns == 0: - return 0 - return rng.uniform(0, max_delay_ns) - - return _fn - - -def poisson_delay_fn( - expected_queries_per_second: float, rng: random.Random | None = None -) -> Callable[[], float]: - """Create a Poisson-distributed delay function for realistic online benchmarking. - - Returns a function that generates delays following an exponential distribution - (inter-arrival times of a Poisson process). This models realistic user/client - behavior where requests arrive independently at a target rate. 
- - The exponential distribution has the property that: - - Mean inter-arrival time = 1 / expected_qps - - Variance = mean^2 (high variability, realistic for network traffic) - - Args: - expected_queries_per_second: Target QPS (queries per second). - rng: Random number generator for reproducibility. - - Returns: - Function that returns delay in nanoseconds (float). - """ - rng = rng or random.Random() - queries_per_ns = expected_queries_per_second / 1e9 - - def _fn(): - if queries_per_ns == 0: - return 0 - return rng.expovariate(lambd=queries_per_ns) # lambd=1/mean, where mean=latency - - return _fn - - -class Scheduler: - """Base class for query scheduling strategies that control benchmark load patterns. - - Schedulers determine: - 1. Sample ordering (which sample to use next) - 2. Timing delays (when to issue the next query) - - They combine a SampleOrder (what to issue) with a delay function (when to issue) - to produce a stream of (sample_index, delay_ns) pairs. - - Scheduler implementations auto-register via __init_subclass__ by specifying - the load_pattern parameter. This enables runtime selection of schedulers: - - scheduler_cls = Scheduler.get_implementation(LoadPatternType.POISSON) - scheduler = scheduler_cls(runtime_settings, sample_order_cls) - - Built-in schedulers: - - MaxThroughputScheduler: Issues all queries immediately (offline mode) - - PoissonDistributionScheduler: Poisson-distributed delays (online mode) - - ConcurrencyScheduler: Fixed concurrency level (online mode) - - Attributes: - _IMPL_MAP: Class-level registry mapping LoadPatternType to Scheduler classes. - runtime_settings: Runtime configuration (QPS, duration, seeds, etc.). - total_samples_to_issue: Total queries to issue during benchmark. - n_unique_samples: Number of unique samples in dataset. - sample_order: Iterator over sample indices to use. - delay_fn: Function returning delay before next query (nanoseconds). 
- """ - - # Registry for scheduler implementations (populated via __init_subclass__) - _IMPL_MAP: dict[LoadPatternType, type["Scheduler"]] = {} - - def __init__( - self, - runtime_settings: RuntimeSettings, - sample_order_cls: type[SampleOrder], - ): - """Initialize scheduler with runtime settings and sample ordering strategy. - - Args: - runtime_settings: Runtime configuration containing QPS, duration, seeds. - sample_order_cls: SampleOrder class to use for sample selection. - """ - self.runtime_settings = runtime_settings - - self.total_samples_to_issue = runtime_settings.total_samples_to_issue() - self.n_unique_samples = runtime_settings.n_samples_from_dataset - self.sample_order = iter( - sample_order_cls( - self.total_samples_to_issue, - self.n_unique_samples, - rng=self.runtime_settings.rng_sample_index, - ) - ) - self.delay_fn: Callable[[], int] | None = None # Subclasses must set this - - def __iter__(self): - """Iterate over (sample_index, delay_ns) pairs. - - Yields: - Tuple of (sample_index, delay_ns): - - sample_index: Index of sample to issue next - - delay_ns: Nanoseconds to wait before issuing - """ - for s_idx in self.sample_order: - yield s_idx, self.delay_fn() - - def __init_subclass__(cls, load_pattern: LoadPatternType | None = None, **kwargs): - """Auto-register scheduler implementations. - - Args: - load_pattern: LoadPatternType to bind this scheduler to - - Raises: - ValueError: If load_pattern already registered - """ - super().__init_subclass__(**kwargs) - - if load_pattern is not None: - if load_pattern in Scheduler._IMPL_MAP: - raise ValueError( - f"Cannot bind {cls.__name__} to {load_pattern} - " - f"Already bound to {Scheduler._IMPL_MAP[load_pattern].__name__}" - ) - Scheduler._IMPL_MAP[load_pattern] = cls - - @classmethod - def get_implementation(cls, load_pattern: LoadPatternType) -> type["Scheduler"]: - """Get scheduler implementation for load pattern. 
- - Args: - load_pattern: LoadPatternType enum - - Returns: - Scheduler subclass - - Raises: - NotImplementedError: If no implementation registered - KeyError: If load_pattern invalid - """ - if load_pattern not in cls._IMPL_MAP: - available_str = ", ".join(p.value for p in cls._IMPL_MAP.keys()) - raise KeyError( - f"No scheduler registered for '{load_pattern.value}'. " - f"Available: {available_str}" - ) - return cls._IMPL_MAP[load_pattern] - - -class MaxThroughputScheduler(Scheduler, load_pattern=LoadPatternType.MAX_THROUGHPUT): - """Offline max throughput scheduler (all queries at t=0). - - Auto-registers for LoadPatternType.MAX_THROUGHPUT. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.delay_fn = uniform_delay_fn(rng=self.runtime_settings.rng_sched) - - -class PoissonDistributionScheduler(Scheduler, load_pattern=LoadPatternType.POISSON): - """Poisson-distributed query scheduler for online benchmarking. - - Simulates realistic client-server network usage by using a Poisson process - to issue queries. The delay between each sample is sampled from an exponential - distribution, centered around the expected latency based on target QPS. - - Use this scheduler for online latency testing with sustained QPS. - - Auto-registers for LoadPatternType.POISSON. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.delay_fn = poisson_delay_fn( - expected_queries_per_second=self.runtime_settings.metric_target.target, - rng=self.runtime_settings.rng_sched, - ) - - -class ConcurrencyScheduler(Scheduler, load_pattern=LoadPatternType.CONCURRENCY): - """Concurrency-based scheduler that maintains fixed concurrent requests. - - Issues queries based on COMPLETION events rather than time delays. - Maintains target concurrency level (e.g., always 32 requests in-flight). - - Auto-registers for LoadPatternType.CONCURRENCY. 
- """ - - def __init__(self, runtime_settings: RuntimeSettings, sample_order_cls): - super().__init__(runtime_settings, sample_order_cls) - assert runtime_settings.load_pattern is not None - target_concurrency = runtime_settings.load_pattern.target_concurrency - if target_concurrency is None or target_concurrency <= 0: - raise ValueError( - f"target_concurrency must be > 0 for CONCURRENCY load pattern, got {target_concurrency}" - ) - - # Use threading.Condition for concurrency control with explicit counter - self._condition = threading.Condition() - self._inflight = 0 - self._target_concurrency = target_concurrency - - # Register completion hook - free up slot when query completes - SampleEventHandler.register_hook(SampleEvent.COMPLETE, self._release_slot) - - # Unused (required by Scheduler interface) - returns 0 delay - self.delay_fn = lambda: 0 - - def _release_slot(self, result=None): - """Release a concurrency slot and notify waiting threads. - - Args: - result: QueryResult from completed query (unused, required by hook signature) - """ - with self._condition: - self._inflight -= 1 - self._condition.notify() - - def __iter__(self): - """ - Iterate over sample indices to issue. - Yields sample indices until total_samples_to_issue is reached. - - Waits for available concurrency slot before yielding each sample index. - """ - for s_idx in self.sample_order: - with self._condition: - while self._inflight >= self._target_concurrency: - self._condition.wait() - self._inflight += 1 - yield s_idx, 0 diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index 79f9feac..ed900910 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -13,212 +13,460 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Async benchmark session: orchestrates phases, issues samples, receives responses. 
+ +See docs/load_generator/design.md for the full design. +""" + from __future__ import annotations +import asyncio import logging -import os -import threading +import time import uuid -from pathlib import Path - -import msgspec.json +from collections.abc import Callable +from dataclasses import dataclass +from enum import Enum +from typing import Protocol from ..config.runtime_settings import RuntimeSettings +from ..core.record import ( + ErrorEventType, + EventRecord, + SampleEventType, + SessionEventType, +) +from ..core.types import PromptData, Query, QueryResult, StreamChunk from ..dataset_manager.dataset import Dataset -from .load_generator import LoadGenerator, SampleIssuer, SchedulerBasedLoadGenerator -from .scheduler import Scheduler, WithoutReplacementSampleOrder +from .sample_order import create_sample_order +from .strategy import LoadStrategy, create_load_strategy logger = logging.getLogger(__name__) -class BenchmarkSession: +# --------------------------------------------------------------------------- +# Phase configuration +# --------------------------------------------------------------------------- + + +class PhaseType(str, Enum): + """Phase types control tracking and reporting behavior.""" + + PERFORMANCE = "performance" + ACCURACY = "accuracy" + SATURATION = "saturation" + + +@dataclass(frozen=True, slots=True) +class PhaseConfig: + """Configuration for a single benchmark phase.""" + + name: str + runtime_settings: RuntimeSettings + dataset: Dataset + phase_type: PhaseType = PhaseType.PERFORMANCE + + +# --------------------------------------------------------------------------- +# Results +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class PhaseResult: + """Result of a single benchmark phase.""" + + name: str + phase_type: PhaseType + uuid_to_index: dict[str, int] + issued_count: int + start_time_ns: int + end_time_ns: int + + +@dataclass(frozen=True) +class SessionResult: + 
"""Combined results from all phases in a session.""" + + session_id: str + phase_results: list[PhaseResult] + start_time_ns: int + end_time_ns: int + + @property + def perf_results(self) -> list[PhaseResult]: + return [r for r in self.phase_results if r.phase_type == PhaseType.PERFORMANCE] + + @property + def accuracy_results(self) -> list[PhaseResult]: + return [r for r in self.phase_results if r.phase_type == PhaseType.ACCURACY] + + +# --------------------------------------------------------------------------- +# SampleIssuer protocol +# --------------------------------------------------------------------------- + + +class SampleIssuer(Protocol): + """Sends queries to an endpoint and receives responses. + + Matches HTTPEndpointClient's interface: issue (sync ZMQ push), + recv (async ZMQ recv), shutdown. + """ + + def issue(self, query: Query) -> None: ... + async def recv(self) -> QueryResult | StreamChunk | None: ... + def shutdown(self) -> None: ... + + +# --------------------------------------------------------------------------- +# EventRecordPublisher protocol +# --------------------------------------------------------------------------- + + +class EventPublisher(Protocol): + """Publishes EventRecords to the metrics pipeline.""" + + def publish(self, event_record: EventRecord) -> None: ... + + +# --------------------------------------------------------------------------- +# PhaseIssuer +# --------------------------------------------------------------------------- + + +class PhaseIssuer: + """Per-phase state holder that wraps the issue logic. + + Created fresh for each phase. Holds the phase-scoped uuid_to_index map, + inflight counter, and issued count. Strategies call issue(sample_index) + to load data, build a Query, publish ISSUED, and send to the endpoint. 
+ """ + + __slots__ = ( + "_dataset", + "_issuer", + "_publisher", + "_stop_check", + "uuid_to_index", + "inflight", + "issued_count", + ) + def __init__( self, - runtime_settings: RuntimeSettings, - session_id: str | None = None, + dataset: Dataset, + issuer: SampleIssuer, + publisher: EventPublisher, + stop_check: Callable[[], bool], ): - self.logger = logging.getLogger(__name__) - self.runtime_settings = runtime_settings - self.session_id = session_id if session_id else uuid.uuid4().hex - - self.end_event = threading.Event() - self.thread: threading.Thread | None = None + self._dataset = dataset + self._issuer = issuer + self._publisher = publisher + self._stop_check = stop_check + self.uuid_to_index: dict[str, int] = {} + self.inflight: int = 0 + self.issued_count: int = 0 + + def issue(self, sample_index: int) -> str | None: + """Load data, build Query, publish ISSUED, send to endpoint. + + Returns query_id on success, None if session is stopping. + + Note: load_sample() runs synchronously before the ISSUED timestamp. + For accurate timing, datasets MUST be pre-loaded into memory. + Disk-backed datasets will inflate timing and delay subsequent issues. 
+ """ + if self._stop_check(): + return None + query_id = uuid.uuid4().hex + data = self._dataset.load_sample(sample_index) + query = Query(id=query_id, data=data) + self.uuid_to_index[query_id] = sample_index + ts = time.monotonic_ns() + prompt_data: PromptData + if isinstance(data, dict): + token_ids = data.get("input_tokens") or data.get("token_ids") + prompt_data = PromptData( + text=data.get("prompt"), + token_ids=tuple(token_ids) if token_ids is not None else None, + ) + else: + prompt_data = PromptData() + self._publisher.publish( + EventRecord( + event_type=SampleEventType.ISSUED, + timestamp_ns=ts, + sample_uuid=query_id, + data=prompt_data, + ) + ) + self._issuer.issue(query) + self.inflight += 1 + self.issued_count += 1 + return query_id - # CPython GIL provides atomic boolean writes, no need for threading.Event() - self.stop_requested = False - # Will be populated after the test finishes by _run_test - self.report = None +# --------------------------------------------------------------------------- +# BenchmarkSession +# --------------------------------------------------------------------------- - self.sample_uuid_map: dict[str, dict[str, int]] | None = None - @property - def is_running(self): - return self.thread is not None and self.thread.is_alive() +class BenchmarkSession: + """Async benchmark orchestrator. Single thread, single event loop. - def stop(self) -> None: - """Signal the session to stop early.""" - self.stop_requested = True - # wakeup _run_test if needed, short-circuit SHUTDOWN_POLL_INTERVAL_S - self.end_event.set() + Runs phases sequentially. Each phase gets its own PhaseIssuer and + LoadStrategy. The receiver coroutine runs concurrently throughout, + processing responses and routing completions to the active strategy. 
+ """ - def _run_test( + def __init__( self, - perf_test_generator: LoadGenerator, - accuracy_test_generators: dict[str, LoadGenerator] | None = None, - report_dir: os.PathLike | None = None, + issuer: SampleIssuer, + event_publisher: EventPublisher, + loop: asyncio.AbstractEventLoop, + on_sample_complete: Callable[[QueryResult | StreamChunk], None] | None = None, + session_id: str | None = None, ): - try: - for _ in perf_test_generator: - pass + self._issuer = issuer + self._publisher = event_publisher + self._loop = loop + self._on_sample_complete = on_sample_complete + self.session_id = session_id or uuid.uuid4().hex + + # Mutable state + self._stop_requested = False + self._done = False + self._current_phase_issuer: PhaseIssuer | None = None + self._current_strategy: LoadStrategy | None = None + self._recv_task: asyncio.Task | None = None + self._strategy_task: asyncio.Task | None = None + self._drain_event = asyncio.Event() - self.logger.info("All performance samples issued") + def stop(self) -> None: + """Signal early termination. Safe to call from signal handler. - if accuracy_test_generators: - for _, generator in accuracy_test_generators.items(): - for _ in generator: - pass + Cancels the running strategy task to unblock strategies that may be + waiting on semaphores or other async primitives. Also sets the drain + event to unblock _drain_inflight if it's waiting for responses. + """ + self._stop_requested = True + self._drain_event.set() + if self._strategy_task and not self._strategy_task.done(): + self._strategy_task.cancel() - self.logger.info("All accuracy samples issued") + async def run(self, phases: list[PhaseConfig]) -> SessionResult: + """Run all benchmark phases sequentially. - # TODO: Wire in EventPublisherService + ServiceLauncher (Phase 5) - # For now, no event recording or report generation. + Returns SessionResult with per-phase results. 
+ """ + session_start = time.monotonic_ns() + self._publish_session_event(SessionEventType.STARTED) - except Exception as e: - logger.error(f"Error running benchmark session: {e}") - raise e + self._recv_task = asyncio.create_task(self._receive_responses()) + phase_results: list[PhaseResult] = [] - # Consolidate UUID->index mappings - perf_name = ( - perf_test_generator.name if perf_test_generator.name else "performance" + try: + for phase in phases: + if self._stop_requested: + break + result = await self._run_phase(phase) + if result is not None: + phase_results.append(result) + finally: + self._done = True + if self._recv_task and not self._recv_task.done(): + self._recv_task.cancel() + try: + await self._recv_task + except asyncio.CancelledError: + pass + self._publish_session_event(SessionEventType.ENDED) + + return SessionResult( + session_id=self.session_id, + phase_results=phase_results, + start_time_ns=session_start, + end_time_ns=time.monotonic_ns(), ) - sample_idx_map = { - perf_name: perf_test_generator.uuid_to_index_map, - } - if accuracy_test_generators: - for default_name, generator in accuracy_test_generators.items(): - name = generator.name if generator.name else default_name - sample_idx_map[name] = generator.uuid_to_index_map - self.sample_uuid_map = sample_idx_map - - # Save runtime settings and UUID map if report_dir provided - if report_dir: - Path(report_dir).mkdir(parents=True, exist_ok=True) - - rt_settings_data: dict[str, int | str | None] = { - "min_duration_ms": self.runtime_settings.min_duration_ms, - "max_duration_ms": self.runtime_settings.max_duration_ms, - "n_samples_from_dataset": self.runtime_settings.n_samples_from_dataset, - "n_samples_to_issue": self.runtime_settings.n_samples_to_issue, - "min_sample_count": self.runtime_settings.min_sample_count, - "total_samples_to_issue": self.runtime_settings.total_samples_to_issue(), - } - has_model = hasattr(self.runtime_settings, "model") - if has_model: - model = 
getattr(self.runtime_settings, "model", None) - if model is not None: - rt_settings_data["model"] = ( - model if isinstance(model, str) else str(model.name) - ) - with (Path(report_dir) / "runtime_settings.json").open("w") as f: - f.write( - msgspec.json.format( - msgspec.json.encode(dict(sorted(rt_settings_data.items()))), - indent=2, - ).decode("utf-8") - ) + async def _run_phase(self, phase: PhaseConfig) -> PhaseResult | None: + """Run a single phase. Returns PhaseResult or None for saturation.""" + logger.info("Starting phase: %s (%s)", phase.name, phase.phase_type.value) + phase_start = time.monotonic_ns() - with (Path(report_dir) / "sample_idx_map.json").open("w") as f: - f.write(msgspec.json.encode(self.sample_uuid_map).decode("utf-8")) + # Create per-phase state + sample_order = create_sample_order(phase.runtime_settings) + strategy = create_load_strategy( + phase.runtime_settings, self._loop, sample_order + ) + phase_issuer = PhaseIssuer( + dataset=phase.dataset, + issuer=self._issuer, + publisher=self._publisher, + stop_check=self._make_stop_check(phase.runtime_settings, phase_start), + ) - def wait_for_test_end(self, timeout: float | None = None) -> bool: - """ - Join the test thread and return True if the test completed, False if it timed out. + self._current_phase_issuer = phase_issuer + self._current_strategy = strategy - Args: - timeout: The maximum time to wait for the test to complete. If None, wait indefinitely. + # Performance phases get tracking events + if phase.phase_type == PhaseType.PERFORMANCE: + self._publish_session_event(SessionEventType.START_PERFORMANCE_TRACKING) - Returns: - bool: True if the test thread has completed, False if it timed out. 
- """ - if not self.thread: - return False - self.thread.join(timeout=timeout) - return not self.thread.is_alive() + # Execute the strategy as a task so it can be cancelled on transport close + self._strategy_task = asyncio.create_task(strategy.execute(phase_issuer)) + try: + await self._strategy_task + except asyncio.CancelledError: + logger.info("Strategy cancelled for phase %s", phase.name) + finally: + self._strategy_task = None + + # Drain in-flight (skip for saturation — keep concurrency hot) + if phase.phase_type != PhaseType.SATURATION: + await self._drain_inflight(phase_issuer) + + if phase.phase_type == PhaseType.PERFORMANCE: + self._publish_session_event(SessionEventType.STOP_PERFORMANCE_TRACKING) + + phase_end = time.monotonic_ns() + logger.info( + "Phase %s complete: %d samples issued", + phase.name, + phase_issuer.issued_count, + ) - @classmethod - def start( - cls, - runtime_settings: RuntimeSettings, - dataset: Dataset, - sample_issuer: SampleIssuer, - scheduler: Scheduler, - *args, - accuracy_datasets: list[Dataset] | None = None, - load_generator_cls: type[LoadGenerator] = SchedulerBasedLoadGenerator, - name: str | None = None, - report_dir: os.PathLike | None = None, - ) -> BenchmarkSession: - """Start a new BenchmarkSession in a thread. - - Args: - runtime_settings: The runtime settings to use for the session. - dataset: The dataset to use for the performance test. - sample_issuer: The sample issuer to use for the session. - scheduler: The scheduler to use for the session. - accuracy_datasets: The datasets to use for the accuracy tests. - load_generator_cls: The load generator class to use for the session. - name: The name of the session. - report_dir: The path to save the report to. If None, no report will be saved. - - Returns: - The new BenchmarkSession. 
+ # Saturation phases produce no result + if phase.phase_type == PhaseType.SATURATION: + return None + + return PhaseResult( + name=phase.name, + phase_type=phase.phase_type, + uuid_to_index=phase_issuer.uuid_to_index, + issued_count=phase_issuer.issued_count, + start_time_ns=phase_start, + end_time_ns=phase_end, + ) + + async def _drain_inflight(self, phase_issuer: PhaseIssuer) -> None: + """Wait for all in-flight responses from this phase to complete. + + Currently, there is no timeout for the drain step. In the future, + we can possibly add a dynamic timeout based on the rate of completion + throughout the current phase.""" + if phase_issuer.inflight <= 0 or self._stop_requested: + return + logger.info("Draining %d in-flight responses...", phase_issuer.inflight) + self._drain_event.clear() + await self._drain_event.wait() + + async def _receive_responses(self) -> None: + """Receive responses from the issuer. Runs as a concurrent task.""" + while not self._done: + resp = await self._issuer.recv() + if resp is None: + # Transport closed unexpectedly — trigger stop so strategy + # and drain don't hang waiting for responses that will never arrive. + logger.warning("Issuer recv() returned None — transport closed") + self._stop_requested = True + self._drain_event.set() # Unblock _drain_inflight + # Cancel the strategy task if it's blocked (e.g., ConcurrencyStrategy + # awaiting sem.acquire() that will never be released). + if self._strategy_task and not self._strategy_task.done(): + self._strategy_task.cancel() + break + self._handle_response(resp) + + def _handle_response(self, resp: QueryResult | StreamChunk) -> None: + """Route a response to the appropriate handler. + + Transport contract for streaming: the worker sends intermediate + StreamChunk messages for timing events, then a final QueryResult + with accumulated output for completion. 
""" - session = cls(runtime_settings, session_id=name) - load_generator = load_generator_cls(sample_issuer, dataset, scheduler, *args) # type: ignore[arg-type] - - # Create accuracy test generators - accuracy_test_generators = None - if accuracy_datasets: - accuracy_test_generators = {} - for ds in accuracy_datasets: - if hasattr(ds.__class__, "DATASET_ID"): - ds_name = ds.__class__.DATASET_ID - else: - ds_name = ds.__class__.__name__ - - # Create accuracy dataset specific runtime settings - acc_rt_settings = RuntimeSettings( - metric_target=runtime_settings.metric_target, - reported_metrics=runtime_settings.reported_metrics, - min_duration_ms=0, - max_duration_ms=None, - n_samples_from_dataset=ds.num_samples(), - n_samples_to_issue=ds.num_samples() * ds.repeats, - min_sample_count=ds.num_samples() * ds.repeats, - rng_sched=runtime_settings.rng_sched, - rng_sample_index=runtime_settings.rng_sample_index, - load_pattern=runtime_settings.load_pattern, + phase_issuer = self._current_phase_issuer + + if isinstance(resp, QueryResult): + query_id = resp.id + self._publisher.publish( + EventRecord( + event_type=SampleEventType.COMPLETE, + timestamp_ns=resp.completed_at + if isinstance(resp.completed_at, int) + else time.monotonic_ns(), + sample_uuid=query_id, + data=resp.response_output, ) - acc_sched = scheduler.__class__( - acc_rt_settings, WithoutReplacementSampleOrder + ) + if resp.error is not None: + self._publisher.publish( + EventRecord( + event_type=ErrorEventType.GENERIC, + timestamp_ns=time.monotonic_ns(), + sample_uuid=query_id, + data=resp.error, + ) ) - - accuracy_test_generators[ds_name] = load_generator_cls( - sample_issuer, - ds, - acc_sched, # type: ignore[arg-type] - *args, + if phase_issuer is not None and query_id in phase_issuer.uuid_to_index: + phase_issuer.inflight -= 1 + if phase_issuer.inflight <= 0: + self._drain_event.set() + if self._current_strategy: + self._current_strategy.on_query_complete(query_id) + if self._on_sample_complete: + 
self._on_sample_complete(resp) + + elif isinstance(resp, StreamChunk): + is_first = resp.metadata.get("first_chunk", False) + event_type = ( + SampleEventType.RECV_FIRST + if is_first + else SampleEventType.RECV_NON_FIRST + ) + self._publisher.publish( + EventRecord( + event_type=event_type, + timestamp_ns=time.monotonic_ns(), + sample_uuid=resp.id, ) + ) + + def _make_stop_check( + self, settings: RuntimeSettings, phase_start_ns: int + ) -> Callable[[], bool]: + """Create a stop-check closure for a phase. + + Reads self._current_phase_issuer at call time (not creation time). + Invariant: _current_phase_issuer must not change while a phase's + strategy is executing. This is guaranteed by sequential phase execution. + """ + max_duration_ns = ( + settings.max_duration_ms * 1_000_000 + if settings.max_duration_ms is not None + else 0 + ) + total_samples = settings.total_samples_to_issue() + + def check() -> bool: + if self._stop_requested: + return True + if ( + self._current_phase_issuer + and self._current_phase_issuer.issued_count >= total_samples + ): + return True + if ( + max_duration_ns > 0 + and (time.monotonic_ns() - phase_start_ns) >= max_duration_ns + ): + return True + return False + + return check - session.thread = threading.Thread( - target=session._run_test, - args=(load_generator,), - kwargs={ - "accuracy_test_generators": accuracy_test_generators, - "report_dir": report_dir, - }, + def _publish_session_event(self, event_type: SessionEventType) -> None: + self._publisher.publish( + EventRecord(event_type=event_type, timestamp_ns=time.monotonic_ns()) ) - session.thread.start() - return session diff --git a/src/inference_endpoint/load_generator/strategy.py b/src/inference_endpoint/load_generator/strategy.py new file mode 100644 index 00000000..d9b77970 --- /dev/null +++ b/src/inference_endpoint/load_generator/strategy.py @@ -0,0 +1,301 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Load strategies: controls the pacing of sample issuance. + +Three implementations, each optimized for a different load pattern: +- TimedIssueStrategy: Poisson (loop.call_at or run_in_executor) +- BurstStrategy: Max throughput (loop.call_soon) +- ConcurrencyStrategy: Fixed concurrency (asyncio.Semaphore) + +See docs/load_generator/design.md for benchmark data and design rationale. +""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Callable, Iterator +from time import monotonic_ns +from typing import Protocol + +from ..config.runtime_settings import RuntimeSettings +from ..config.schema import LoadPatternType +from .delay import make_delay_fn +from .sample_order import SampleOrder, create_sample_order + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# LoadStrategy Protocol +# --------------------------------------------------------------------------- + + +class PhaseIssuerProtocol(Protocol): + """Minimal interface that strategies see for issuing samples.""" + + def issue(self, sample_index: int) -> str | None: + """Issue a sample. Returns query_id, or None if the session is stopping.""" + ... + + issued_count: int + + +class LoadStrategy(Protocol): + """Controls the pacing strategy for issuing requests. 
+ + Strategies call phase_issuer.issue(sample_index) to issue each sample. + issue() returns query_id on success, None when the session should stop. + """ + + async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: + """Drive sample issuance. Returns count of samples issued.""" + ... + + def on_query_complete(self, query_id: str) -> None: + """Called by session on each QueryResult. Default: no-op. + + Used by ConcurrencyStrategy to release semaphore slots. + """ + ... + + +# --------------------------------------------------------------------------- +# TimedIssueStrategy (Poisson) +# --------------------------------------------------------------------------- + + +def _busy_wait_until(target_ns: int) -> None: + """Busy-wait in a thread pool thread until target timestamp.""" + while monotonic_ns() < target_ns: + pass + + +class TimedIssueStrategy: + """Schedule-driven load strategy with inter-arrival delays. + + Default mode (call_at): schedules each issue as an event loop callback + at the precise target time. Zero GIL contention, sub-ms precision. + Good for <= 50k QPS. + + Executor mode (opt-in): offloads busy-wait to thread pool for sub-100us + precision. Introduces GIL contention that adds latency at low QPS. 
+ """ + + def __init__( + self, + delay_fn: Callable[[], int], + sample_order: Iterator[int], + loop: asyncio.AbstractEventLoop, + use_executor: bool = False, + ): + self._delay_fn = delay_fn + self._sample_order = sample_order + self._loop = loop + self._use_executor = use_executor + + async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: + if self._use_executor: + return await self._execute_executor(phase_issuer) + return await self._execute_call_at(phase_issuer) + + def on_query_complete(self, query_id: str) -> None: + pass + + async def _execute_call_at(self, phase_issuer: PhaseIssuerProtocol) -> int: + done = asyncio.Event() + start_time = self._loop.time() + cumulative_s = 0.0 + + def schedule_next(): + nonlocal cumulative_s, error + try: + idx = next(self._sample_order, None) + if idx is None: + done.set() + return + cumulative_s += self._delay_fn() / 1e9 + self._loop.call_at(start_time + cumulative_s, fire, idx) + except Exception as exc: + error = exc + done.set() + + error: BaseException | None = None + + def fire(idx: int): + nonlocal error + try: + if phase_issuer.issue(idx) is None: + done.set() + return + schedule_next() + except Exception as exc: + error = exc + done.set() + + schedule_next() + await done.wait() + if error is not None: + raise error + return phase_issuer.issued_count + + async def _execute_executor(self, phase_issuer: PhaseIssuerProtocol) -> int: + start = monotonic_ns() + cumulative = 0 + for idx in self._sample_order: + cumulative += self._delay_fn() + target = start + cumulative + now = monotonic_ns() + if target > now: + await self._loop.run_in_executor(None, _busy_wait_until, target) + if phase_issuer.issue(idx) is None: + break + return phase_issuer.issued_count + + +# --------------------------------------------------------------------------- +# BurstStrategy (Max Throughput) +# --------------------------------------------------------------------------- + + +class BurstStrategy: + """Fire-as-fast-as-possible 
strategy using loop.call_soon. + + Each issue is scheduled as an event loop callback, yielding between + issues so the receiver coroutine can process responses. Achieves + 100k+ QPS without starving the event loop. + """ + + def __init__( + self, + sample_order: Iterator[int], + loop: asyncio.AbstractEventLoop, + ): + self._sample_order = sample_order + self._loop = loop + + async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: + done = asyncio.Event() + error: BaseException | None = None + + def issue_next(): + nonlocal error + try: + idx = next(self._sample_order, None) + if idx is None or phase_issuer.issue(idx) is None: + done.set() + return + self._loop.call_soon(issue_next) + except Exception as exc: + error = exc + done.set() + + self._loop.call_soon(issue_next) + await done.wait() + if error is not None: + raise error + return phase_issuer.issued_count + + def on_query_complete(self, query_id: str) -> None: + pass + + +# --------------------------------------------------------------------------- +# ConcurrencyStrategy +# --------------------------------------------------------------------------- + + +class ConcurrencyStrategy: + """Completion-driven strategy maintaining fixed concurrent requests. + + Uses asyncio.Semaphore for gating: acquire before issue, release on + completion via on_query_complete(). With eager_task_factory, the woken + waiter executes synchronously within release(), minimizing jitter. 
+ """ + + def __init__( + self, + target_concurrency: int, + sample_order: Iterator[int], + ): + if target_concurrency <= 0: + raise ValueError( + f"target_concurrency must be > 0, got {target_concurrency}" + ) + self._target = target_concurrency + self._sem = asyncio.Semaphore(target_concurrency) + self._sample_order = sample_order + + async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: + for idx in self._sample_order: + await self._sem.acquire() + if phase_issuer.issue(idx) is None: + self._sem.release() + break + return phase_issuer.issued_count + + def on_query_complete(self, query_id: str) -> None: + self._sem.release() + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + + +def create_load_strategy( + runtime_settings: RuntimeSettings, + loop: asyncio.AbstractEventLoop, + sample_order: SampleOrder | None = None, + use_executor: bool = False, +) -> LoadStrategy: + """Create a LoadStrategy from RuntimeSettings. + + Args: + runtime_settings: Runtime configuration with load_pattern. + loop: Event loop for scheduling callbacks. + sample_order: Sample ordering iterator. If None, created from settings. + use_executor: For Poisson, use run_in_executor for sub-100us precision. + + Returns: + LoadStrategy implementation for the configured load pattern. 
+ """ + lp = runtime_settings.load_pattern + if lp is None: + raise ValueError("RuntimeSettings.load_pattern must not be None") + + if sample_order is None: + sample_order = create_sample_order(runtime_settings) + + match lp.type: + case LoadPatternType.MAX_THROUGHPUT: + return BurstStrategy(sample_order, loop) + + case LoadPatternType.POISSON: + delay_fn = make_delay_fn(lp, runtime_settings.rng_sched) + return TimedIssueStrategy( + delay_fn, sample_order, loop, use_executor=use_executor + ) + + case LoadPatternType.CONCURRENCY: + if lp.target_concurrency is None or lp.target_concurrency <= 0: + raise ValueError( + "Concurrency load pattern requires target_concurrency > 0" + ) + return ConcurrencyStrategy(lp.target_concurrency, sample_order) + + case _: + raise ValueError(f"Unsupported load pattern type: {lp.type}") diff --git a/src/inference_endpoint/metrics/report.py b/src/inference_endpoint/metrics/report.py index 4f639605..e24d954b 100644 --- a/src/inference_endpoint/metrics/report.py +++ b/src/inference_endpoint/metrics/report.py @@ -169,15 +169,17 @@ def _summarize(key: str) -> dict: return {} version_info = get_version_info() - duration_ns = _counter("duration_ns") + duration_ns = _counter("tracked_duration_ns") return cls( version=str(version_info.get("version", "unknown")), git_sha=version_info.get("git_sha"), - test_started_at=_counter("test_started_at"), - n_samples_issued=_counter("n_samples_issued"), - n_samples_completed=_counter("n_samples_completed"), - n_samples_failed=_counter("n_samples_failed"), + test_started_at=0, # TODO: add test_started_at counter to aggregator + n_samples_issued=_counter("tracked_samples_issued"), + n_samples_completed=_counter("tracked_samples_completed"), + # TODO: Add tracked_samples_failed to MetricCounterKey. + # For now, total_samples_failed is the best available. 
+ n_samples_failed=_counter("total_samples_failed"), duration_ns=duration_ns if duration_ns > 0 else None, ttft=_summarize("ttft_ns"), tpot=_summarize("tpot_ns"), diff --git a/src/inference_endpoint/openai/accumulator.py b/src/inference_endpoint/openai/accumulator.py index 4400766c..6cb23ed8 100644 --- a/src/inference_endpoint/openai/accumulator.py +++ b/src/inference_endpoint/openai/accumulator.py @@ -57,10 +57,8 @@ def add_chunk(self, delta: OpenAISSEDelta) -> StreamChunk | None: chunk = StreamChunk( id=self.query_id, response_chunk=content, - is_complete=False, metadata={ "first_chunk": not self.first_chunk_sent, - "final_chunk": False, }, ) self.first_chunk_sent = True diff --git a/src/inference_endpoint/profiling/line_profiler.py b/src/inference_endpoint/profiling/line_profiler.py index 79aaa144..56c2d659 100644 --- a/src/inference_endpoint/profiling/line_profiler.py +++ b/src/inference_endpoint/profiling/line_profiler.py @@ -28,6 +28,11 @@ import io import os import sys + +try: + from line_profiler import LineProfiler +except ImportError: + LineProfiler = None from collections.abc import Callable from pathlib import Path from typing import Any, Optional, TypeVar @@ -68,18 +73,15 @@ def __init__(self): self._atexit_registered = False if self.enabled: - try: - from line_profiler import LineProfiler - - self.profiler = LineProfiler() - self.profiler.enable() - atexit.register(self._safe_cleanup) - self._atexit_registered = True - except ImportError as e: + if LineProfiler is None: raise ImportError( f"line_profiler not installed but {ENV_VAR_ENABLE_LINE_PROFILER}={enable_profiler} is set. 
" f"Install with: pip install line_profiler" - ) from e + ) + self.profiler = LineProfiler() + self.profiler.enable() + atexit.register(self._safe_cleanup) + self._atexit_registered = True def _safe_cleanup(self): """Safe cleanup wrapper that suppresses all errors during atexit.""" diff --git a/src/inference_endpoint/profiling/pytest_profiling_plugin.py b/src/inference_endpoint/profiling/pytest_profiling_plugin.py index 71641cf9..3a268061 100644 --- a/src/inference_endpoint/profiling/pytest_profiling_plugin.py +++ b/src/inference_endpoint/profiling/pytest_profiling_plugin.py @@ -26,6 +26,7 @@ import atexit import glob import os +import shutil import sys from inference_endpoint.profiling import shutdown @@ -102,8 +103,6 @@ def _print_worker_profiles(): def _cleanup_profile_files(output_file: str): """Remove profile directory and files after displaying results.""" try: - import shutil - profile_dir = os.path.dirname(output_file) if profile_dir and os.path.exists(profile_dir): shutil.rmtree(profile_dir, ignore_errors=True) diff --git a/src/inference_endpoint/sglang/accumulator.py b/src/inference_endpoint/sglang/accumulator.py index 29579e7f..081106eb 100644 --- a/src/inference_endpoint/sglang/accumulator.py +++ b/src/inference_endpoint/sglang/accumulator.py @@ -65,7 +65,6 @@ def add_chunk(self, delta: SGLangSSEDelta) -> StreamChunk | None: chunk = StreamChunk( id=self.query_id, response_chunk=content_diff, - is_complete=False, metadata=metadata, ) self.first_chunk_sent = True diff --git a/src/inference_endpoint/testing/echo_server.py b/src/inference_endpoint/testing/echo_server.py index a957c39a..6555f2e6 100644 --- a/src/inference_endpoint/testing/echo_server.py +++ b/src/inference_endpoint/testing/echo_server.py @@ -29,6 +29,7 @@ from inference_endpoint.core.types import QueryResult, TextModelOutput from inference_endpoint.openai.openai_adapter import OpenAIAdapter from inference_endpoint.openai.openai_types_gen import CreateChatCompletionRequest +from 
inference_endpoint.utils.logging import setup_logging class HTTPServer: @@ -427,8 +428,6 @@ def main(): """ # - from inference_endpoint.utils.logging import setup_logging - setup_logging() parser = create_parser() args = parser.parse_args() diff --git a/src/inference_endpoint/testing/max_throughput_server.py b/src/inference_endpoint/testing/max_throughput_server.py index 63bacb38..12f3267d 100644 --- a/src/inference_endpoint/testing/max_throughput_server.py +++ b/src/inference_endpoint/testing/max_throughput_server.py @@ -30,12 +30,14 @@ import argparse import asyncio +import gc import multiprocessing import multiprocessing.sharedctypes import multiprocessing.synchronize import os import signal import socket +import sys import threading import time @@ -301,8 +303,6 @@ def _worker( global _req_counter, _resp_counter, _byte_counter _req_counter, _resp_counter, _byte_counter = counters - import gc - gc.disable() uvloop.install() @@ -329,8 +329,6 @@ def protocol_factory(): try: asyncio.run(run()) except Exception as exc: - import sys - print( f"[MaxThroughputServer] Worker {wid} failed: {exc}", file=sys.stderr, diff --git a/src/inference_endpoint/testing/variable_throughput_server.py b/src/inference_endpoint/testing/variable_throughput_server.py index fae8d1a9..b640b6b4 100644 --- a/src/inference_endpoint/testing/variable_throughput_server.py +++ b/src/inference_endpoint/testing/variable_throughput_server.py @@ -49,6 +49,7 @@ import argparse import asyncio +import gc import math import multiprocessing import multiprocessing.sharedctypes @@ -57,8 +58,10 @@ import random import signal import socket +import sys import threading import time +import warnings import httptools import uvloop @@ -446,8 +449,6 @@ def _worker( global _req_counter, _resp_counter, _byte_counter _req_counter, _resp_counter, _byte_counter = counters - import gc - gc.disable() uvloop.install() @@ -501,8 +502,6 @@ def protocol_factory(): try: asyncio.run(run()) except Exception as exc: - import sys - 
print( f"[VariableResponseServer] Worker {wid} failed: {exc}", file=sys.stderr, @@ -648,8 +647,6 @@ def __init__( if max_concurrency > 0: self._max_concurrency_per_worker = max(1, max_concurrency // num_workers) if max_concurrency < num_workers: - import warnings - warnings.warn( f"max_concurrency ({max_concurrency}) < num_workers ({num_workers}): " f"each worker gets 1 slot, effective total={num_workers} exceeds cap.", diff --git a/src/inference_endpoint/utils/benchmark_httpclient.py b/src/inference_endpoint/utils/benchmark_httpclient.py index 3785af21..cb0e4ecb 100644 --- a/src/inference_endpoint/utils/benchmark_httpclient.py +++ b/src/inference_endpoint/utils/benchmark_httpclient.py @@ -37,6 +37,8 @@ import time from dataclasses import dataclass +import uvloop + from inference_endpoint.core.types import Query, QueryResult from inference_endpoint.endpoint_client.config import HTTPClientConfig from inference_endpoint.endpoint_client.cpu_affinity import ( @@ -50,6 +52,19 @@ build_response, ) +try: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.colors as mcolors + import matplotlib.pyplot as plt + import matplotlib.ticker as ticker +except ImportError: + matplotlib = None + mcolors = None + plt = None + ticker = None + # Suppress transformers "no framework found" warning (only tokenizers used) os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error") @@ -1054,14 +1069,7 @@ def generate_sweep_plot( 4 params: MxNxK facet grid — rows=param3, columns=param4. Where N = 2 (non-streaming) or 3 (streaming, adds SSE Rate). """ - try: - import matplotlib - - matplotlib.use("Agg") - import matplotlib.colors as mcolors - import matplotlib.pyplot as plt - import matplotlib.ticker as ticker - except ImportError: + if plt is None: print("\nMatplotlib not installed. 
Skipping plot generation.") print(" Install with: pip install matplotlib") return @@ -1446,9 +1454,6 @@ def main() -> None: ) gc.set_threshold(70000, 10, 100) - - import uvloop - uvloop.install() server: MaxThroughputServer | None = None diff --git a/tests/conftest.py b/tests/conftest.py index e49d5163..79972aff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,10 +31,8 @@ from inference_endpoint import metrics from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import LoadPattern, LoadPatternType -from inference_endpoint.core.types import TextModelOutput from inference_endpoint.dataset_manager.dataset import Dataset, DatasetFormat from inference_endpoint.dataset_manager.transforms import ColumnRemap -from inference_endpoint.load_generator.sample import SampleEventHandler from inference_endpoint.testing.docker_server import DockerServer from inference_endpoint.testing.echo_server import EchoServer, HTTPServer @@ -238,8 +236,6 @@ def fake_outputs(sample_uuids): } - - class CharacterTokenizer: def tokenize(self, text: str) -> list[str]: return list(text) @@ -470,11 +466,3 @@ def concurrency_runtime_settings(random_seed, target_concurrency): type=LoadPatternType.CONCURRENCY, target_concurrency=target_concurrency ), ) - - -@pytest.fixture -def clean_sample_event_hooks(): - """Fixture to ensure SampleEventHandler hooks are cleared before and after each test.""" - SampleEventHandler.clear_hooks() - yield SampleEventHandler - SampleEventHandler.clear_hooks() diff --git a/tests/futures_client.py b/tests/futures_client.py index 24bcf05d..871378e1 100644 --- a/tests/futures_client.py +++ b/tests/futures_client.py @@ -69,7 +69,7 @@ async def _handle_responses(self): break # None signals transport closed - exit handler match response: - case StreamChunk(is_complete=False): + case StreamChunk(): # Intermediate stream chunk - future stays pending pass diff --git 
a/tests/integration/commands/test_accuracy_pipeline.py b/tests/integration/commands/test_accuracy_pipeline.py new file mode 100644 index 00000000..8574599f --- /dev/null +++ b/tests/integration/commands/test_accuracy_pipeline.py @@ -0,0 +1,170 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration test: full accuracy scoring pipeline with echo server. + +The echo server returns the user message content unchanged. We create a +dataset where some prompts match their ground_truth (correct) and some +don't (incorrect), then verify the scorer produces the expected accuracy. +""" + +import json +from pathlib import Path + +import msgspec.json +import pandas as pd +import pytest +from inference_endpoint.commands.benchmark.execute import run_benchmark +from inference_endpoint.config.schema import ( + AccuracyConfig, + BenchmarkConfig, + DatasetType, + EndpointConfig, + LoadPattern, + LoadPatternType, + ModelParams, + RuntimeConfig, + Settings, + StreamingMode, + TestMode, + TestType, +) +from inference_endpoint.config.schema import Dataset as DatasetConfig +from inference_endpoint.endpoint_client.config import HTTPClientConfig + + +def _create_accuracy_dataset(tmp_path: Path) -> Path: + """Create a CSV dataset with some correct and some incorrect ground truths. + + The echo server returns the prompt verbatim. 
So: + - If ground_truth == prompt → score 1.0 (correct) + - If ground_truth != prompt → score 0.0 (incorrect) + + Dataset: 5 samples, 3 correct + 2 incorrect = 60% accuracy. + """ + data = { + "prompt": [ + "alpha", # correct: echo returns "alpha", ground_truth is "alpha" + "beta", # correct + "gamma", # correct + "What is the answer?", # INCORRECT: echo returns prompt, ground_truth is "42" + "Tell me a joke", # INCORRECT: echo returns prompt, ground_truth is "knock knock" + ], + "ground_truth": [ + "alpha", + "beta", + "gamma", + "42", + "knock knock", + ], + } + df = pd.DataFrame(data) + csv_path = tmp_path / "accuracy_dataset.csv" + df.to_csv(csv_path, index=False) + return csv_path + + +def _create_perf_dataset(tmp_path: Path) -> Path: + """Create a minimal perf dataset (CSV with prompt column).""" + data = {"prompt": ["hello"] * 3} + df = pd.DataFrame(data) + csv_path = tmp_path / "perf_dataset.csv" + df.to_csv(csv_path, index=False) + return csv_path + + +@pytest.mark.integration +class TestAccuracyPipeline: + def test_accuracy_scoring_with_echo_server( + self, mock_http_echo_server, tmp_path, caplog + ): + """Full end-to-end: perf phase + accuracy phase + scoring. + + Expected: 3/5 correct = 60% accuracy (0.6 score). 
+ """ + perf_path = _create_perf_dataset(tmp_path) + acc_path = _create_accuracy_dataset(tmp_path) + + report_dir = tmp_path / "report" + config = BenchmarkConfig( + type=TestType.OFFLINE, + endpoint_config=EndpointConfig(endpoints=[mock_http_echo_server.url]), + model_params=ModelParams(name="echo-server", streaming=StreamingMode.OFF), + datasets=[ + DatasetConfig( + path=str(perf_path), + type=DatasetType.PERFORMANCE, + ), + DatasetConfig( + name="echo_accuracy", + path=str(acc_path), + type=DatasetType.ACCURACY, + accuracy_config=AccuracyConfig( + eval_method="string_match", + ground_truth="ground_truth", + extractor="identity_extractor", + ), + ), + ], + settings=Settings( + runtime=RuntimeConfig(min_duration_ms=0), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + client=HTTPClientConfig( + num_workers=1, warmup_connections=0, max_connections=10 + ), + ), + report_dir=str(report_dir), + ) + + with caplog.at_level("INFO"): + run_benchmark(config, TestMode.BOTH) + + # Verify scoring artifacts were written + assert (report_dir / "sample_idx_map.json").exists() + assert (report_dir / "events.jsonl").exists() + + # Verify sample_idx_map has both phases + with (report_dir / "sample_idx_map.json").open("rb") as f: + idx_map = msgspec.json.decode(f.read()) + assert "performance" in idx_map + assert "echo_accuracy" in idx_map + assert len(idx_map["echo_accuracy"]) == 5 # 5 accuracy samples + + # Verify events.jsonl has COMPLETE events (EventRecord format: "sample.complete") + events_path = report_dir / "events.jsonl" + with events_path.open() as f: + events = [msgspec.json.decode(line.strip()) for line in f if line.strip()] + complete_events = [ + e for e in events if e.get("event_type") == "sample.complete" + ] + # Should have both perf (3) and accuracy (5) completions + assert len(complete_events) == 8 + + # Verify results.json was written with accuracy scores + results_path = report_dir / "results.json" + assert results_path.exists() + with 
results_path.open() as f: + results = json.load(f) + + assert "accuracy_scores" in results + assert "echo_accuracy" in results["accuracy_scores"] + score_data = results["accuracy_scores"]["echo_accuracy"] + score = score_data["score"] + + # 3 correct out of 5 = 0.6 accuracy + assert abs(score - 0.6) < 0.01, f"Expected 0.6, got {score}" + + # Verify logs mention scoring + assert "Score for echo_accuracy" in caplog.text diff --git a/tests/integration/commands/test_benchmark_command.py b/tests/integration/commands/test_benchmark_command.py index 3a4c1d64..a87aa669 100644 --- a/tests/integration/commands/test_benchmark_command.py +++ b/tests/integration/commands/test_benchmark_command.py @@ -83,7 +83,7 @@ def test_offline_benchmark( assert "Completed in" in caplog.text assert "successful" in caplog.text assert "QPS:" in caplog.text - assert "MaxThroughputScheduler" in caplog.text + assert "Starting phase:" in caplog.text @pytest.mark.integration @pytest.mark.parametrize("streaming", [StreamingMode.OFF, StreamingMode.ON]) @@ -102,8 +102,7 @@ def test_online_benchmark( assert "Completed in" in caplog.text assert "successful" in caplog.text - assert "PoissonDistributionScheduler" in caplog.text - assert "50" in caplog.text + assert "Starting phase:" in caplog.text @pytest.mark.integration def test_results_json_output( diff --git a/tests/integration/test_end_to_end_oracle.py b/tests/integration/test_end_to_end_oracle.py index 562c0238..892bcf99 100644 --- a/tests/integration/test_end_to_end_oracle.py +++ b/tests/integration/test_end_to_end_oracle.py @@ -13,7 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging +"""End-to-end oracle test: verify responses match expected dataset outputs. + +Uses the async BenchmarkSession to issue all samples to a mock oracle server, +then checks each response against the expected ground-truth output. 
+""" + +import asyncio import random from pathlib import Path from urllib.parse import urljoin @@ -22,6 +28,7 @@ from inference_endpoint import metrics from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import LoadPattern, LoadPatternType +from inference_endpoint.core.record import EventRecord from inference_endpoint.core.types import QueryResult from inference_endpoint.dataset_manager import Dataset from inference_endpoint.dataset_manager.transforms import ( @@ -30,199 +37,157 @@ ) from inference_endpoint.endpoint_client.config import HTTPClientConfig from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient -from inference_endpoint.endpoint_client.http_sample_issuer import HttpClientSampleIssuer -from inference_endpoint.load_generator import ( +from inference_endpoint.endpoint_client.http_sample_issuer import ( + HttpClientSampleIssuer, +) +from inference_endpoint.load_generator.session import ( BenchmarkSession, - MaxThroughputScheduler, - SampleEvent, - SampleEventHandler, - WithoutReplacementSampleOrder, + PhaseConfig, + PhaseType, ) -class DeepSeekR1SampleIssuer(HttpClientSampleIssuer): - def __init__(self, tmp_path: Path, url: str): - self.http_config = HTTPClientConfig( - endpoint_urls=[urljoin(url, "/v1/chat/completions")], - warmup_connections=0, - ) - super().__init__( - HTTPEndpointClient( - self.http_config, - ) - ) +class _NoOpPublisher: + def publish(self, event_record: EventRecord) -> None: + pass -async def run_benchmark(server_url, dataloader, tmp_path, rt_settings): - # Step 1. Register the complete hook to store the responses from the server. - server_responses: {str: str} = {} +async def _run_benchmark( + server_url: str, + dataloader: Dataset, + rt_settings: RuntimeSettings, +) -> tuple[dict[str, int], dict[str, str]]: + """Run a benchmark and return (uuid_to_index, responses). 
- def on_complete_hook(result: QueryResult): - """Callback to store the responses from the server.""" - server_responses[result.id] = result.get_response_output_string() + Uses the async BenchmarkSession with MAX_THROUGHPUT strategy. + Responses are collected via the on_sample_complete callback. + """ + loop = asyncio.get_running_loop() - SampleEventHandler.register_hook(SampleEvent.COMPLETE, on_complete_hook) - - # Step 2. Create the scheduler. - scheduler = MaxThroughputScheduler( - rt_settings, - WithoutReplacementSampleOrder, + http_config = HTTPClientConfig( + endpoint_urls=[urljoin(server_url, "/v1/chat/completions")], + warmup_connections=0, ) - logging.info(f"Number of samples to issue: {scheduler.total_samples_to_issue}") + http_client = await HTTPEndpointClient.create(http_config, loop) + issuer = HttpClientSampleIssuer(http_client) - sample_issuer = None - try: - # Step 3. Create the sample issuer. - sample_issuer = DeepSeekR1SampleIssuer(tmp_path, server_url) + responses: dict[str, str] = {} - # Step 4. Create the benchmark session. - sess = BenchmarkSession.start( + def on_complete(result: QueryResult) -> None: + responses[result.id] = result.get_response_output_string() + + session = BenchmarkSession( + issuer=issuer, + event_publisher=_NoOpPublisher(), + loop=loop, + on_sample_complete=on_complete, + ) + + phases = [ + PhaseConfig( + "performance", rt_settings, dataloader, - sample_issuer, - scheduler, - name="pytest_run_benchmark", - ) + PhaseType.PERFORMANCE, + ), + ] - # Step 5. Wait for the test to end. - logging.info("Waiting for the test to end...") - sess.wait_for_test_end() - # Step 6. Return the sample UUID map and the server responses. - return sess.sample_uuid_map, server_responses + try: + result = await session.run(phases) finally: - # Step 7. Shutdown the sample issuer and the HTTP client. 
- if sample_issuer is not None: - sample_issuer.shutdown() - sample_issuer.http_client.shutdown() - + await http_client.shutdown_async() -""" -Test the load generator full run with a given URL. -""" + perf = result.perf_results[0] + return perf.uuid_to_index, responses -async def _run_load_generator_full_run_url( - url, dataset_path, tmp_path, clean_sample_event_hooks, hf_model_name +@pytest.mark.integration +@pytest.mark.asyncio +async def test_load_generator_full_run_mock_http_oracle_server( + mock_http_oracle_server, + ds_pickle_dataset_path, + hf_model_name, ): dummy_dataloader = Dataset.load_from_file( - dataset_path, + ds_pickle_dataset_path, transforms=[ ColumnRemap({"text_input": "prompt", "ref_output": "output"}), AddStaticColumns({"model": hf_model_name}), ], ) dummy_dataloader.load() - assert dummy_dataloader.num_samples() > 0 + n_samples = dummy_dataloader.num_samples() + assert n_samples > 0 rt_settings = RuntimeSettings( - metrics.Throughput(50), - [metrics.Throughput(50)], - min_duration_ms=1_00, - max_duration_ms=1_000, - n_samples_from_dataset=dummy_dataloader.num_samples(), - n_samples_to_issue=dummy_dataloader.num_samples(), + metrics.Throughput(5000), + [metrics.Throughput(5000)], + min_duration_ms=1_000, + max_duration_ms=10_000_000, + n_samples_from_dataset=n_samples, + n_samples_to_issue=n_samples, + min_sample_count=1, rng_sched=random.Random(1234), rng_sample_index=random.Random(1234), load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), ) - scheduler = MaxThroughputScheduler( - rt_settings, - WithoutReplacementSampleOrder, + uuid_to_index, responses = await _run_benchmark( + mock_http_oracle_server.url, dummy_dataloader, rt_settings ) - logging.info(f"Number of samples to issue: {scheduler.total_samples_to_issue}") - # Now call the benchmark - sample_uuid_map, response_cache = await run_benchmark( - url, dummy_dataloader, tmp_path, rt_settings - ) - num_responses_in_cache = len(response_cache) + + # Verify all samples received 
responses assert ( - num_responses_in_cache == scheduler.total_samples_to_issue - ), "Number of samples in response cache and number of samples in dataset should be the same" - vals = {} - for i in range(dummy_dataloader.num_samples()): + len(responses) == n_samples + ), f"Expected {n_samples} responses, got {len(responses)}" + + # Build expected outputs from dataset + expected: dict[int, str] = {} + for i in range(n_samples): entry = dummy_dataloader.load_sample(i) - vals[i] = entry["output"] - num_samples_in_dataset = len(vals) - logging.info(f"Number of samples in dataset: {num_samples_in_dataset}") - logging.info(f"Total samples to issue: {scheduler.total_samples_to_issue}") - logging.info(f"Request data: {num_responses_in_cache}") - - for sample_uuid, resp in response_cache.items(): - if resp is None: - logging.error(f"Sample {sample_uuid} has no response") - else: - sample_index = sample_uuid_map[sample_uuid].index - logging.info( - f"Sample {sample_uuid} should have been response {vals[sample_index][0:30]}, but was response {resp[0:30]}" - ) + expected[i] = entry["output"] + + # Verify each response matches the expected oracle output + for sample_uuid, resp in responses.items(): + sample_index = uuid_to_index[sample_uuid] + assert resp == expected[sample_index], ( + f"Sample {sample_uuid} (index={sample_index}): " + f"expected {expected[sample_index][:30]!r}, got {resp[:30]!r}" + ) -@pytest.mark.integration -@pytest.mark.asyncio -async def test_load_generator_full_run_mock_http_oracle_server( - mock_http_oracle_server, - ds_pickle_dataset_path, - tmp_path, - clean_sample_event_hooks, - hf_model_name, -): +async def _run_load_generator_full_run_url( + url: str, + dataset_path: Path, + hf_model_name: str, +) -> None: + """Helper for docker server tests.""" dummy_dataloader = Dataset.load_from_file( - ds_pickle_dataset_path, + dataset_path, transforms=[ ColumnRemap({"text_input": "prompt", "ref_output": "output"}), AddStaticColumns({"model": hf_model_name}), ], 
) dummy_dataloader.load() - assert dummy_dataloader.num_samples() > 0 + n_samples = dummy_dataloader.num_samples() + assert n_samples > 0 rt_settings = RuntimeSettings( - metrics.Throughput(5000), - [metrics.Throughput(5000)], - min_duration_ms=1_000, - max_duration_ms=10_000_000, - n_samples_from_dataset=dummy_dataloader.num_samples(), - n_samples_to_issue=dummy_dataloader.num_samples(), + metrics.Throughput(50), + [metrics.Throughput(50)], + min_duration_ms=100, + max_duration_ms=1_000, + n_samples_from_dataset=n_samples, + n_samples_to_issue=n_samples, min_sample_count=1, rng_sched=random.Random(1234), rng_sample_index=random.Random(1234), load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), ) - scheduler = MaxThroughputScheduler( - rt_settings, - WithoutReplacementSampleOrder, - ) - logging.info(f"Number of samples to issue: {scheduler.total_samples_to_issue}") - - sample_uuid_map, response_cache = await run_benchmark( - mock_http_oracle_server.url, dummy_dataloader, tmp_path, rt_settings - ) - num_responses_in_cache = len(response_cache) - assert ( - num_responses_in_cache == scheduler.total_samples_to_issue - ), "Number of samples in response cache and number of samples in dataset should be the same" - vals = {} - for i in range(dummy_dataloader.num_samples()): - entry = dummy_dataloader.load_sample(i) - vals[i] = entry["output"] - num_samples_in_dataset = len(vals) - logging.info(f"Number of samples in dataset: {num_samples_in_dataset}") - logging.info(f"Total samples to issue: {scheduler.total_samples_to_issue}") - logging.info(f"Request data: {num_responses_in_cache}") - assert ( - num_samples_in_dataset == scheduler.total_samples_to_issue - ), "Number of samples in dataset and number of samples in request data should be the same" - - for sample_uuid, resp in response_cache.items(): - sample_index = sample_uuid_map["performance"][sample_uuid] - logging.info( - f"Sample {sample_uuid} should have been response {vals[sample_index][0:30]}, but was 
response {resp[0:30]}" - ) - assert ( - resp == vals[sample_index] - ), f"Sample {sample_uuid} should have been response {vals[sample_index][0:30]}, but was response {resp[0:30]}" + _, responses = await _run_benchmark(url, dummy_dataloader, rt_settings) + assert len(responses) == n_samples @pytest.mark.asyncio @@ -232,15 +197,11 @@ async def test_load_generator_full_run_mock_http_oracle_server( async def test_load_generator_full_run_vllm_docker_server( vllm_docker_server, ds_pickle_dataset_path, - tmp_path, - clean_sample_event_hooks, hf_model_name, ): await _run_load_generator_full_run_url( vllm_docker_server.url, ds_pickle_dataset_path, - tmp_path, - clean_sample_event_hooks, hf_model_name, ) @@ -252,15 +213,11 @@ async def test_load_generator_full_run_vllm_docker_server( async def test_load_generator_full_run_sglang_docker_server( sglang_docker_server, ds_pickle_dataset_path, - tmp_path, - clean_sample_event_hooks, hf_model_name, ): await _run_load_generator_full_run_url( sglang_docker_server.url, ds_pickle_dataset_path, - tmp_path, - clean_sample_event_hooks, hf_model_name, ) @@ -272,14 +229,10 @@ async def test_load_generator_full_run_sglang_docker_server( async def test_load_generator_full_run_trtllm_docker_server( trtllm_docker_server, ds_pickle_dataset_path, - tmp_path, - clean_sample_event_hooks, hf_model_name, ): await _run_load_generator_full_run_url( trtllm_docker_server.url, ds_pickle_dataset_path, - tmp_path, - clean_sample_event_hooks, hf_model_name, ) diff --git a/tests/performance/async_utils/transport/test_zmq.py b/tests/performance/async_utils/transport/test_zmq.py index 4df5cebc..f31baa51 100644 --- a/tests/performance/async_utils/transport/test_zmq.py +++ b/tests/performance/async_utils/transport/test_zmq.py @@ -77,7 +77,6 @@ def make_stream_chunk(payload_chars: int, idx: int) -> StreamChunk: return StreamChunk( id=str(idx), response_chunk="x" * payload_chars, - is_complete=False, ) diff --git a/tests/test_helpers.py b/tests/test_helpers.py 
index 7fa586e4..6e68f3cc 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -24,20 +24,13 @@ import random import string import uuid -from asyncio import Future -from concurrent.futures import ThreadPoolExecutor from pathlib import Path import zmq from inference_endpoint.core.types import ( Query, - QueryResult, - StreamChunk, - TextModelOutput, ) from inference_endpoint.dataset_manager.dataset import Dataset -from inference_endpoint.load_generator.load_generator import SampleIssuer -from inference_endpoint.load_generator.sample import SampleEventHandler def _generate_random_word( @@ -184,111 +177,3 @@ def get_test_socket_path(tmp_path: Path, test_name: str, suffix: str = "") -> st len(socket_path) <= zmq.IPC_PATH_MAX_LEN ), "socket path is too long for ZMQ IPC" return socket_path - - -class SerialSampleIssuer(SampleIssuer): - """SampleIssuer for testing. No threading, and is blocking. Whenever issue is called, - it performs the provided compute function, calling callbacks when necessary. - - The compute function should be a generator, yielding the 'chunks' of the supposed - response. - """ - - def __init__(self, compute_func=None): - if compute_func is None: - self.compute_func = lambda x: [x] - else: - self.compute_func = compute_func - - def issue(self, sample): - first = True - chunks = [] - for chunk in self.compute_func(sample.data): - chunks.append(chunk) - stream_chunk = StreamChunk( - id=sample.uuid, metadata={"first_chunk": first}, response_chunk=chunk - ) - SampleEventHandler.stream_chunk_complete(stream_chunk) - first = False - query_result = QueryResult( - id=sample.uuid, response_output=TextModelOutput(output="".join(chunks)) - ) - SampleEventHandler.query_result_complete(query_result) - - -class PooledSampleIssuer(SampleIssuer): - """SampleIssuer that has a non-blocking issue() method. Has a pool of workers which compute - the samples in parallel. 
- - Uses ThreadPoolExecutor to properly propagate exceptions from worker threads to the main thread. - Call check_errors() to raise any exceptions that occurred in workers and clean up completed futures. - """ - - def __init__(self, compute_func=None, n_workers: int = 4): - self.n_workers = n_workers - if compute_func is None: - self.compute_func = lambda x: [x] - else: - self.compute_func = compute_func - self.executor = ThreadPoolExecutor(max_workers=n_workers) - self.futures: list[Future[None]] = [] - - def shutdown(self, wait: bool = True): - """Shutdown the executor and wait for all tasks to complete. - - Args: - wait: Whether to wait for all tasks to complete before returning. - If False, the executor will be shutdown and the method will return immediately. - The caller is responsible for checking the futures for exceptions. (Default: True) - - Raises any exceptions that occurred in worker threads. - """ - self.executor.shutdown(wait=wait) - - if wait: - # Check all futures for exceptions - for future in self.futures: - future.result() # This will raise if the worker raised an exception - self.futures.clear() - - def handle_sample(self, sample): - first = True - chunks = [] - for chunk in self.compute_func(sample.data): - chunks.append(chunk) - stream_chunk = StreamChunk( - id=sample.uuid, metadata={"first_chunk": first}, response_chunk=chunk - ) - SampleEventHandler.stream_chunk_complete(stream_chunk) - first = False - query_result = QueryResult( - id=sample.uuid, response_output=TextModelOutput(output="".join(chunks)) - ) - SampleEventHandler.query_result_complete(query_result) - - def check_errors(self): - """Check if any worker thread has raised an exception and re-raise it. - - This checks completed futures without blocking and removes them from the list - to prevent unbounded memory growth. 
- """ - remaining_futures = [] - for future in self.futures: - if future.done(): - # This will raise if the worker raised an exception - future.result() - # Don't keep completed futures - else: - # Keep incomplete futures - remaining_futures.append(future) - self.futures = remaining_futures - - def issue(self, sample): - """Submit a sample to be processed by the worker pool.""" - future = self.executor.submit(self.handle_sample, sample) - self.futures.append(future) - - # Periodically clean up completed futures to prevent unbounded growth - # Check every 100 submissions to balance cleanup overhead vs memory usage - if len(self.futures) >= 100: - self.check_errors() diff --git a/tests/unit/async_utils/transport/test_zmq_context.py b/tests/unit/async_utils/transport/test_zmq_context.py index 85d58927..9eb6f38a 100644 --- a/tests/unit/async_utils/transport/test_zmq_context.py +++ b/tests/unit/async_utils/transport/test_zmq_context.py @@ -18,6 +18,7 @@ import os import tempfile import time +from pathlib import Path import pytest import zmq @@ -46,7 +47,10 @@ def test_bind_creates_temp_socket_dir(self): assert isinstance(ctx.socket_dir, str) assert os.path.isdir(ctx.socket_dir) assert "zmq_" in ctx.socket_dir - assert ctx.socket_dir.startswith(tempfile.gettempdir()) + # Socket dir is on /dev/shm if available, otherwise tempdir + shm = Path("/dev/shm") + expected_parent = str(shm) if shm.is_dir() else tempfile.gettempdir() + assert ctx.socket_dir.startswith(expected_parent) def test_scoped_accepts_io_threads(self): """scoped(io_threads=N) creates context without error.""" diff --git a/tests/unit/async_utils/transport/zmq/test_ready_check.py b/tests/unit/async_utils/transport/zmq/test_ready_check.py index e90d3858..17bdbeb9 100644 --- a/tests/unit/async_utils/transport/zmq/test_ready_check.py +++ b/tests/unit/async_utils/transport/zmq/test_ready_check.py @@ -82,13 +82,15 @@ async def test_close_idempotent(self): receiver.close() receiver.close() - async def 
test_close_on_timeout(self): + async def test_socket_survives_timeout(self): + """Socket must NOT be closed on timeout — caller may retry.""" with tempfile.TemporaryDirectory() as tmpdir: with ManagedZMQContext.scoped(socket_dir=tmpdir) as ctx: receiver = ReadyCheckReceiver("ready_close_timeout", ctx, count=1) with pytest.raises(TimeoutError): await receiver.wait(timeout=0.1) - assert receiver._sock.closed + assert not receiver._sock.closed + receiver.close() def _child_send_ready(socket_dir: str, path: str, identity: int) -> None: diff --git a/tests/unit/commands/test_benchmark.py b/tests/unit/commands/test_benchmark.py index 7ebdc1ab..a91f878b 100644 --- a/tests/unit/commands/test_benchmark.py +++ b/tests/unit/commands/test_benchmark.py @@ -45,6 +45,7 @@ from inference_endpoint.config.schema import ( OnlineBenchmarkConfig as OnlineConfig, ) +from inference_endpoint.config.utils import cli_error_formatter as _error_formatter from inference_endpoint.core.types import QueryResult from inference_endpoint.endpoint_client.config import HTTPClientConfig from inference_endpoint.exceptions import InputValidationError @@ -414,10 +415,6 @@ class TestErrorFormatter: @pytest.mark.unit def test_cyclopts_arg_with_children(self): - from inference_endpoint.config.utils import ( - cli_error_formatter as _error_formatter, - ) - child = SimpleNamespace( name="--endpoints", names=("--endpoints",), required=True, has_tokens=False ) @@ -429,10 +426,6 @@ def test_cyclopts_arg_with_children(self): @pytest.mark.unit def test_cyclopts_leaf_arg(self): - from inference_endpoint.config.utils import ( - cli_error_formatter as _error_formatter, - ) - arg = SimpleNamespace( name="--model", names=("--model-params.name", "--model"), children=[] ) @@ -443,10 +436,6 @@ def test_cyclopts_leaf_arg(self): @pytest.mark.unit def test_pydantic_validation_error(self): - from inference_endpoint.config.utils import ( - cli_error_formatter as _error_formatter, - ) - try: BenchmarkConfig( 
type=TestType.OFFLINE, @@ -461,10 +450,6 @@ def test_pydantic_validation_error(self): @pytest.mark.unit def test_generic_error_fallback(self): - from inference_endpoint.config.utils import ( - cli_error_formatter as _error_formatter, - ) - class FakeError: argument = None __cause__ = None diff --git a/tests/unit/commands/test_util_commands.py b/tests/unit/commands/test_util_commands.py index 96503c12..54fe7ca9 100644 --- a/tests/unit/commands/test_util_commands.py +++ b/tests/unit/commands/test_util_commands.py @@ -15,6 +15,7 @@ """Tests for utility commands (info, validate, init, probe) and main.py dispatch.""" +import asyncio from pathlib import Path from unittest.mock import MagicMock, patch @@ -22,7 +23,7 @@ from inference_endpoint import __version__ from inference_endpoint.commands.info import execute_info from inference_endpoint.commands.init import execute_init -from inference_endpoint.commands.probe import ProbeConfig, execute_probe +from inference_endpoint.commands.probe import ProbeConfig, _probe_async, execute_probe from inference_endpoint.commands.validate import execute_validate from inference_endpoint.config.schema import APIType from inference_endpoint.exceptions import ( @@ -31,6 +32,7 @@ InputValidationError, SetupError, ) +from inference_endpoint.main import run class TestInfoCommand: @@ -176,10 +178,6 @@ def test_execute_probe_calls_async(self, mock_run_async): def test_empty_model_raises(self): config = ProbeConfig(endpoints="http://localhost:8000", model="") with pytest.raises(InputValidationError, match="Model required"): - import asyncio - - from inference_endpoint.commands.probe import _probe_async - asyncio.run(_probe_async(config)) @pytest.mark.unit @@ -189,10 +187,6 @@ def test_setup_failure_raises(self, mock_client_cls): config = ProbeConfig(endpoints="http://localhost:8000", model="test") with pytest.raises(SetupError, match="Probe setup failed"): - import asyncio - - from inference_endpoint.commands.probe import _probe_async - 
asyncio.run(_probe_async(config)) @pytest.mark.unit @@ -206,10 +200,6 @@ def test_all_issues_fail_raises(self, mock_client_cls): endpoints="http://localhost:8000", model="test", requests=2 ) with pytest.raises(ExecutionError, match="no queries could be issued"): - import asyncio - - from inference_endpoint.commands.probe import _probe_async - asyncio.run(_probe_async(config)) @@ -229,8 +219,6 @@ class TestMainRunExceptionHandling: ], ) def test_exception_exit_codes(self, exc, code): - from inference_endpoint.main import run - with patch("inference_endpoint.main.app") as mock_app: mock_app.meta.side_effect = exc with pytest.raises(SystemExit) as exc_info: diff --git a/tests/unit/core/test_types.py b/tests/unit/core/test_types.py index b33f7eda..52bdbe77 100644 --- a/tests/unit/core/test_types.py +++ b/tests/unit/core/test_types.py @@ -479,7 +479,6 @@ def test_stream_chunk_minimal(self): assert decoded.id == "" assert decoded.response_chunk == "" - assert decoded.is_complete is False assert decoded.metadata == {} def test_stream_chunk_with_basic_content(self): @@ -493,14 +492,12 @@ def test_stream_chunk_with_basic_content(self): assert decoded.id == "query-123" assert decoded.response_chunk == "Hello, this is a chunk of text." 
- assert decoded.is_complete is False def test_stream_chunk_first_chunk(self): """Test StreamChunk representing first chunk with metadata.""" chunk = StreamChunk( id="query-456", response_chunk="First token", - is_complete=False, metadata={"first_chunk": True, "latency_ns": 1234567}, ) @@ -510,24 +507,11 @@ def test_stream_chunk_first_chunk(self): assert decoded.metadata["first_chunk"] is True assert decoded.metadata["latency_ns"] == 1234567 - def test_stream_chunk_final_chunk(self): - """Test StreamChunk representing final chunk.""" - chunk = StreamChunk( - id="query-789", response_chunk="Final text.", is_complete=True - ) - - encoded = msgspec.msgpack.encode(chunk) - decoded = msgspec.msgpack.decode(encoded, type=StreamChunk) - - assert decoded.is_complete is True - assert decoded.response_chunk == "Final text." - def test_stream_chunk_with_comprehensive_metadata(self): """Test StreamChunk with detailed metadata.""" chunk = StreamChunk( id="query-meta", response_chunk=" next token", - is_complete=False, metadata={ "model": "llama-2-70b", "chunk_index": 5, @@ -569,7 +553,6 @@ def test_stream_chunk_all_fields_populated(self): chunk = StreamChunk( id="query-full-chunk", response_chunk="Complete chunk text", - is_complete=True, metadata={ "model": "gpt-4", "finish_reason": "stop", @@ -582,7 +565,6 @@ def test_stream_chunk_all_fields_populated(self): assert decoded.id == "query-full-chunk" assert decoded.response_chunk == "Complete chunk text" - assert decoded.is_complete is True assert decoded.metadata["finish_reason"] == "stop" def test_stream_chunk_multiple_roundtrips(self): @@ -590,7 +572,6 @@ def test_stream_chunk_multiple_roundtrips(self): original = StreamChunk( id="query-roundtrip", response_chunk="Test chunk", - is_complete=False, metadata={"index": 1}, ) @@ -605,7 +586,6 @@ def test_stream_chunk_multiple_roundtrips(self): # Verify all fields remain consistent assert decoded2.id == original.id assert decoded2.response_chunk == original.response_chunk - assert 
decoded2.is_complete == original.is_complete assert decoded2.metadata == original.metadata @@ -811,7 +791,7 @@ def test_serialize_list_of_stream_chunks(self): id="q1", response_chunk="First", metadata={"first_chunk": True} ), StreamChunk(id="q1", response_chunk=" second"), - StreamChunk(id="q1", response_chunk=" final", is_complete=True), + StreamChunk(id="q1", response_chunk=" final"), ] encoded = msgspec.msgpack.encode(chunks) @@ -819,7 +799,6 @@ def test_serialize_list_of_stream_chunks(self): assert len(decoded) == 3 assert decoded[0].metadata.get("first_chunk") is True - assert decoded[2].is_complete is True def test_query_result_with_nested_metadata(self): """Test QueryResult with deeply nested metadata and TextModelOutput.""" diff --git a/tests/unit/load_generator/test_async_session.py b/tests/unit/load_generator/test_async_session.py new file mode 100644 index 00000000..76b3d10a --- /dev/null +++ b/tests/unit/load_generator/test_async_session.py @@ -0,0 +1,873 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for the async BenchmarkSession.""" + +from __future__ import annotations + +import asyncio +import random + +import pytest +from inference_endpoint.config.runtime_settings import RuntimeSettings +from inference_endpoint.config.schema import LoadPattern, LoadPatternType +from inference_endpoint.core.record import ( + ErrorEventType, + EventRecord, + SampleEventType, + SessionEventType, +) +from inference_endpoint.core.types import ErrorData, Query, QueryResult, StreamChunk +from inference_endpoint.dataset_manager.dataset import Dataset +from inference_endpoint.load_generator.session import ( + BenchmarkSession, + PhaseConfig, + PhaseIssuer, + PhaseResult, + PhaseType, + SessionResult, +) +from inference_endpoint.metrics.metric import Throughput + +# --------------------------------------------------------------------------- +# Test doubles +# --------------------------------------------------------------------------- + + +class FakeDataset(Dataset): + """In-memory dataset for tests.""" + + def __init__(self, n_samples: int = 10): + self._n = n_samples + + def load_sample(self, index: int) -> dict: + return {"prompt": f"sample_{index}", "model": "test"} + + def num_samples(self) -> int: + return self._n + + +class FakeIssuer: + """Fake SampleIssuer that queues responses for controlled delivery.""" + + def __init__(self, response_delay: float = 0.001): + self._issued: list[Query] = [] + self._response_queue: asyncio.Queue[QueryResult | StreamChunk | None] = ( + asyncio.Queue() + ) + self._response_delay = response_delay + self._auto_respond = True + self._loop: asyncio.AbstractEventLoop | None = None + + def issue(self, query: Query) -> None: + self._issued.append(query) + if self._auto_respond and self._loop: + + def _enqueue_response(q: Query = query) -> None: + self._response_queue.put_nowait( + QueryResult(id=q.id, response_output=None) + ) + + self._loop.call_later(self._response_delay, _enqueue_response) + + async def recv(self) -> QueryResult | 
StreamChunk | None: + return await self._response_queue.get() + + def shutdown(self) -> None: + self._response_queue.put_nowait(None) + + def inject_response(self, resp: QueryResult | StreamChunk) -> None: + self._response_queue.put_nowait(resp) + + @property + def issued_queries(self) -> list[Query]: + return self._issued + + +class FakePublisher: + """Captures published EventRecords.""" + + def __init__(self): + self.events: list[EventRecord] = [] + + def publish(self, event_record: EventRecord) -> None: + self.events.append(event_record) + + def events_of_type(self, event_type) -> list[EventRecord]: + return [e for e in self.events if e.event_type == event_type] + + +def _make_settings( + load_pattern: LoadPattern | None = None, + n_samples: int = 10, + max_duration_ms: int | None = None, +) -> RuntimeSettings: + return RuntimeSettings( + metric_target=Throughput(100), + reported_metrics=[], + min_duration_ms=0, + max_duration_ms=max_duration_ms, + n_samples_from_dataset=n_samples, + n_samples_to_issue=n_samples, + min_sample_count=n_samples, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=load_pattern or LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + + +# --------------------------------------------------------------------------- +# PhaseIssuer tests +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestPhaseIssuer: + def test_issue_builds_query_and_publishes(self): + dataset = FakeDataset(5) + issuer = FakeIssuer() + issuer._auto_respond = False + publisher = FakePublisher() + phase_issuer = PhaseIssuer(dataset, issuer, publisher, lambda: False) + + result = phase_issuer.issue(3) + assert result is not None + assert phase_issuer.issued_count == 1 + assert phase_issuer.inflight == 1 + assert len(issuer.issued_queries) == 1 + assert issuer.issued_queries[0].id == result + assert 3 in phase_issuer.uuid_to_index.values() + + # Should have published ISSUED event 
+ issued_events = publisher.events_of_type(SampleEventType.ISSUED) + assert len(issued_events) == 1 + assert issued_events[0].sample_uuid == result + + def test_issue_returns_none_when_stopped(self): + dataset = FakeDataset(5) + issuer = FakeIssuer() + issuer._auto_respond = False + publisher = FakePublisher() + phase_issuer = PhaseIssuer(dataset, issuer, publisher, lambda: True) + + result = phase_issuer.issue(0) + assert result is None + assert phase_issuer.issued_count == 0 + + def test_uuid_is_unique_per_issue(self): + dataset = FakeDataset(5) + issuer = FakeIssuer() + issuer._auto_respond = False + publisher = FakePublisher() + phase_issuer = PhaseIssuer(dataset, issuer, publisher, lambda: False) + + ids = [phase_issuer.issue(i % 5) for i in range(10)] + assert len(set(ids)) == 10 + + +# --------------------------------------------------------------------------- +# BenchmarkSession tests +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestBenchmarkSession: + @pytest.mark.asyncio + async def test_single_perf_phase(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + phases = [ + PhaseConfig("perf", _make_settings(n_samples=5), FakeDataset(5)), + ] + result = await session.run(phases) + + assert len(result.phase_results) == 1 + assert result.perf_results[0].name == "perf" + assert result.perf_results[0].issued_count == 5 + assert len(result.perf_results[0].uuid_to_index) == 5 + + # Check session events + started = publisher.events_of_type(SessionEventType.STARTED) + ended = publisher.events_of_type(SessionEventType.ENDED) + start_track = publisher.events_of_type( + SessionEventType.START_PERFORMANCE_TRACKING + ) + stop_track = publisher.events_of_type( + SessionEventType.STOP_PERFORMANCE_TRACKING + ) + assert len(started) == 1 + assert len(ended) == 1 + assert len(start_track) == 
1 + assert len(stop_track) == 1 + + @pytest.mark.asyncio + async def test_accuracy_phase(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + phases = [ + PhaseConfig( + "acc", _make_settings(n_samples=3), FakeDataset(3), PhaseType.ACCURACY + ), + ] + result = await session.run(phases) + + assert len(result.accuracy_results) == 1 + assert result.accuracy_results[0].issued_count == 3 + # No tracking events for accuracy + assert ( + len(publisher.events_of_type(SessionEventType.START_PERFORMANCE_TRACKING)) + == 0 + ) + + @pytest.mark.asyncio + async def test_saturation_produces_no_result(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + phases = [ + PhaseConfig( + "warmup", + _make_settings(n_samples=3), + FakeDataset(3), + PhaseType.SATURATION, + ), + ] + result = await session.run(phases) + assert len(result.phase_results) == 0 + + @pytest.mark.asyncio + async def test_multi_phase(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + phases = [ + PhaseConfig( + "warmup", + _make_settings(n_samples=2), + FakeDataset(2), + PhaseType.SATURATION, + ), + PhaseConfig( + "perf", + _make_settings(n_samples=5), + FakeDataset(5), + PhaseType.PERFORMANCE, + ), + PhaseConfig( + "acc", _make_settings(n_samples=3), FakeDataset(3), PhaseType.ACCURACY + ), + ] + result = await session.run(phases) + + assert len(result.perf_results) == 1 + assert result.perf_results[0].issued_count == 5 + assert len(result.accuracy_results) == 1 + assert result.accuracy_results[0].issued_count == 3 + + @pytest.mark.asyncio + async def test_stop_terminates_early(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + 
issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + + # Stop after a short delay + loop.call_later(0.05, session.stop) + + phases = [ + PhaseConfig( + "perf", + _make_settings(n_samples=100_000, max_duration_ms=10_000), + FakeDataset(100), + ), + ] + result = await session.run(phases) + # Should have stopped early, not issued all 100k + assert result.perf_results[0].issued_count < 100_000 + + @pytest.mark.asyncio + async def test_on_sample_complete_callback(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + completed: list[str] = [] + + def on_complete(result: QueryResult) -> None: + completed.append(result.id) + + session = BenchmarkSession( + issuer, publisher, loop, on_sample_complete=on_complete + ) + phases = [ + PhaseConfig("perf", _make_settings(n_samples=5), FakeDataset(5)), + ] + await session.run(phases) + assert len(completed) == 5 + + @pytest.mark.asyncio + async def test_stale_completions_ignored_by_strategy(self): + """Responses from saturation phase should not affect perf phase strategy.""" + loop = asyncio.get_running_loop() + publisher = FakePublisher() + + # Issuer that delays responses significantly so they arrive in next phase + issuer = FakeIssuer(response_delay=0.1) + issuer._loop = loop + + session = BenchmarkSession(issuer, publisher, loop) + + concurrency_settings = _make_settings( + load_pattern=LoadPattern( + type=LoadPatternType.CONCURRENCY, target_concurrency=2 + ), + n_samples=3, + ) + phases = [ + PhaseConfig( + "sat", _make_settings(n_samples=2), FakeDataset(2), PhaseType.SATURATION + ), + PhaseConfig( + "perf", concurrency_settings, FakeDataset(3), PhaseType.PERFORMANCE + ), + ] + result = await session.run(phases) + + # Perf phase should complete with its own samples, not be confused by stale ones + assert len(result.perf_results) == 1 + assert result.perf_results[0].issued_count == 3 + + 
@pytest.mark.asyncio + async def test_recv_none_triggers_stop(self): + """If issuer.recv() returns None mid-phase, drain should abort quickly.""" + loop = asyncio.get_running_loop() + publisher = FakePublisher() + + issuer = FakeIssuer() + issuer._loop = loop + issuer._auto_respond = False + + session = BenchmarkSession(issuer, publisher, loop) + phases = [ + PhaseConfig("perf", _make_settings(n_samples=5), FakeDataset(5)), + ] + + # Schedule transport close after a short delay — recv returns None + loop.call_later(0.05, issuer.shutdown) + + # Session should complete quickly — recv None sets stop_requested, + # which aborts drain. wait_for prevents CI hang if this regresses. + result = await asyncio.wait_for(session.run(phases), timeout=10.0) + assert result is not None + + @pytest.mark.asyncio + async def test_streaming_query_completes_via_queryresult(self): + """Streaming: StreamChunks publish timing events, QueryResult handles completion. + + The worker sends StreamChunk(first) → StreamChunk(delta) → QueryResult. + Only the QueryResult decrements inflight and releases the concurrency + semaphore. StreamChunks only publish timing events. 
+ """ + loop = asyncio.get_running_loop() + publisher = FakePublisher() + + issuer = FakeIssuer() + issuer._loop = loop + issuer._auto_respond = False + + session = BenchmarkSession(issuer, publisher, loop) + + settings = _make_settings( + load_pattern=LoadPattern( + type=LoadPatternType.CONCURRENCY, target_concurrency=1 + ), + n_samples=2, + ) + phases = [PhaseConfig("perf", settings, FakeDataset(2))] + + async def inject_streaming_responses(): + """Simulate worker: StreamChunk(first) → StreamChunk(delta) → QueryResult.""" + while not issuer._issued: + await asyncio.sleep(0.005) + q1 = issuer._issued[0] + issuer.inject_response( + StreamChunk(id=q1.id, metadata={"first_chunk": True}) + ) + issuer.inject_response(StreamChunk(id=q1.id, response_chunk="more")) + issuer.inject_response(QueryResult(id=q1.id, response_output="out1")) + while len(issuer._issued) < 2: + await asyncio.sleep(0.005) + q2 = issuer._issued[1] + issuer.inject_response( + StreamChunk(id=q2.id, metadata={"first_chunk": True}) + ) + issuer.inject_response(StreamChunk(id=q2.id, response_chunk="more")) + issuer.inject_response(QueryResult(id=q2.id, response_output="out2")) + + injector = asyncio.create_task(inject_streaming_responses()) + result = await asyncio.wait_for(session.run(phases), timeout=5.0) + await injector + + assert result.perf_results[0].issued_count == 2 + recv_first = publisher.events_of_type(SampleEventType.RECV_FIRST) + assert len(recv_first) == 2 + + @pytest.mark.asyncio + async def test_concurrency_strategy_transport_close_no_deadlock(self): + """ConcurrencyStrategy must not deadlock when transport closes mid-phase.""" + loop = asyncio.get_running_loop() + publisher = FakePublisher() + + issuer = FakeIssuer(response_delay=999) # Responses never arrive in time + issuer._loop = loop + issuer._auto_respond = False + + session = BenchmarkSession(issuer, publisher, loop) + settings = _make_settings( + load_pattern=LoadPattern( + type=LoadPatternType.CONCURRENCY, target_concurrency=2 
+ ), + n_samples=100, + ) + phases = [PhaseConfig("perf", settings, FakeDataset(10))] + + # Close transport after strategy issues initial batch and blocks on semaphore + loop.call_later(0.1, issuer.shutdown) + + # Must complete without deadlock — wait_for prevents CI hang + result = await asyncio.wait_for(session.run(phases), timeout=5.0) + assert result is not None + + @pytest.mark.asyncio + async def test_on_sample_complete_called_for_streaming_query(self): + """on_sample_complete fires exactly once per streaming query (on QueryResult). + + StreamChunks only publish timing events — callback fires only for QueryResult. + """ + loop = asyncio.get_running_loop() + publisher = FakePublisher() + + issuer = FakeIssuer() + issuer._loop = loop + issuer._auto_respond = False + + completed: list[QueryResult | StreamChunk] = [] + + def on_complete(result: QueryResult | StreamChunk) -> None: + completed.append(result) + + session = BenchmarkSession( + issuer, publisher, loop, on_sample_complete=on_complete + ) + settings = _make_settings( + load_pattern=LoadPattern( + type=LoadPatternType.CONCURRENCY, target_concurrency=1 + ), + n_samples=1, + ) + phases = [PhaseConfig("perf", settings, FakeDataset(1))] + + async def inject(): + while not issuer._issued: + await asyncio.sleep(0.005) + q = issuer._issued[0] + issuer.inject_response(StreamChunk(id=q.id, metadata={"first_chunk": True})) + issuer.inject_response(StreamChunk(id=q.id, response_chunk="more")) + issuer.inject_response(QueryResult(id=q.id, response_output="done")) + + asyncio.create_task(inject()) + await asyncio.wait_for(session.run(phases), timeout=5.0) + + assert len(completed) == 1 + assert isinstance(completed[0], QueryResult) + + @pytest.mark.asyncio + async def test_failed_query_published_as_error_event(self): + """Bug #5: QueryResult with error should publish ErrorEventType, not just COMPLETE.""" + loop = asyncio.get_running_loop() + publisher = FakePublisher() + + issuer = FakeIssuer() + issuer._loop = loop 
+ issuer._auto_respond = False + + session = BenchmarkSession(issuer, publisher, loop) + settings = _make_settings(n_samples=1) + phases = [PhaseConfig("perf", settings, FakeDataset(1))] + + async def inject_error(): + while not issuer._issued: + await asyncio.sleep(0.005) + q = issuer._issued[0] + issuer.inject_response( + QueryResult( + id=q.id, + error=ErrorData(error_type="timeout", error_message="timed out"), + ) + ) + + asyncio.create_task(inject_error()) + await asyncio.wait_for(session.run(phases), timeout=5.0) + + # Should have published both COMPLETE and an error event + complete_events = publisher.events_of_type(SampleEventType.COMPLETE) + error_events = [ + e for e in publisher.events if isinstance(e.event_type, ErrorEventType) + ] + assert len(complete_events) == 1 + # Bug #5: error event should also be published + assert len(error_events) == 1 + + +@pytest.mark.unit +class TestBenchmarkSessionPoissonIntegration: + """Poisson strategy (TimedIssueStrategy) integration with session.""" + + @pytest.mark.asyncio + async def test_poisson_issues_all_samples(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + poisson_settings = _make_settings( + load_pattern=LoadPattern(type=LoadPatternType.POISSON, target_qps=5000.0), + n_samples=8, + ) + phases = [ + PhaseConfig("perf", poisson_settings, FakeDataset(8)), + ] + result = await asyncio.wait_for(session.run(phases), timeout=10.0) + + assert len(result.perf_results) == 1 + assert result.perf_results[0].issued_count == 8 + + @pytest.mark.asyncio + async def test_poisson_respects_stop(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + poisson_settings = _make_settings( + load_pattern=LoadPattern(type=LoadPatternType.POISSON, target_qps=100.0), + n_samples=100_000, + 
max_duration_ms=60_000, + ) + phases = [ + PhaseConfig("perf", poisson_settings, FakeDataset(100)), + ] + loop.call_later(0.05, session.stop) + result = await asyncio.wait_for(session.run(phases), timeout=10.0) + assert result.perf_results[0].issued_count < 100_000 + + +@pytest.mark.unit +class TestBenchmarkSessionMaxDuration: + """max_duration_ms timeout: phase stops after duration even with samples remaining.""" + + @pytest.mark.asyncio + async def test_max_duration_stops_phase(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + # Very short max_duration with many samples to issue + settings = _make_settings( + load_pattern=LoadPattern(type=LoadPatternType.POISSON, target_qps=10.0), + n_samples=100_000, + max_duration_ms=50, + ) + phases = [PhaseConfig("perf", settings, FakeDataset(100))] + result = await asyncio.wait_for(session.run(phases), timeout=10.0) + + # Should have stopped well before issuing all samples + assert result.perf_results[0].issued_count < 100_000 + + @pytest.mark.asyncio + async def test_max_duration_with_burst(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + settings = _make_settings(n_samples=1_000_000, max_duration_ms=20) + phases = [PhaseConfig("perf", settings, FakeDataset(100))] + result = await asyncio.wait_for(session.run(phases), timeout=10.0) + + # Burst fires fast, but stop_check should cut it short + assert result.perf_results[0].issued_count < 1_000_000 + + +@pytest.mark.unit +class TestBenchmarkSessionAccuracyErrorHandling: + """Error handling in accuracy phase: query fails, verify it doesn't corrupt scoring.""" + + @pytest.mark.asyncio + async def test_failed_query_in_accuracy_phase_preserves_uuid_map(self): + loop = asyncio.get_running_loop() + publisher = FakePublisher() + 
issuer = FakeIssuer() + issuer._loop = loop + issuer._auto_respond = False + + completed_results: list[QueryResult | StreamChunk] = [] + + def on_complete(result: QueryResult | StreamChunk) -> None: + completed_results.append(result) + + session = BenchmarkSession( + issuer, publisher, loop, on_sample_complete=on_complete + ) + settings = _make_settings(n_samples=3) + phases = [ + PhaseConfig("acc", settings, FakeDataset(3), PhaseType.ACCURACY), + ] + + async def inject_mixed_responses(): + while len(issuer._issued) < 3: + await asyncio.sleep(0.005) + # First query: success + issuer.inject_response( + QueryResult(id=issuer._issued[0].id, response_output="answer1") + ) + # Second query: error + issuer.inject_response( + QueryResult( + id=issuer._issued[1].id, + error=ErrorData(error_type="timeout", error_message="timed out"), + ) + ) + # Third query: success + issuer.inject_response( + QueryResult(id=issuer._issued[2].id, response_output="answer3") + ) + + asyncio.create_task(inject_mixed_responses()) + result = await asyncio.wait_for(session.run(phases), timeout=5.0) + + assert len(result.accuracy_results) == 1 + acc = result.accuracy_results[0] + # All 3 samples should be in uuid_to_index, including the failed one + assert acc.issued_count == 3 + assert len(acc.uuid_to_index) == 3 + # on_sample_complete should have fired for all 3 + assert len(completed_results) == 3 + + @pytest.mark.asyncio + async def test_error_event_published_in_accuracy_phase(self): + loop = asyncio.get_running_loop() + publisher = FakePublisher() + issuer = FakeIssuer() + issuer._loop = loop + issuer._auto_respond = False + + session = BenchmarkSession(issuer, publisher, loop) + settings = _make_settings(n_samples=1) + phases = [ + PhaseConfig("acc", settings, FakeDataset(1), PhaseType.ACCURACY), + ] + + async def inject_error(): + while not issuer._issued: + await asyncio.sleep(0.005) + issuer.inject_response( + QueryResult( + id=issuer._issued[0].id, + 
error=ErrorData(error_type="server_error", error_message="500"), + ) + ) + + asyncio.create_task(inject_error()) + await asyncio.wait_for(session.run(phases), timeout=5.0) + + error_events = [ + e for e in publisher.events if isinstance(e.event_type, ErrorEventType) + ] + assert len(error_events) == 1 + + +@pytest.mark.unit +class TestBenchmarkSessionMultiPhaseSatPerfSequence: + """Multi-perf + saturation sequence (sat -> perf -> sat -> perf).""" + + @pytest.mark.asyncio + async def test_sat_perf_sat_perf(self): + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + issuer._loop = loop + publisher = FakePublisher() + + session = BenchmarkSession(issuer, publisher, loop) + phases = [ + PhaseConfig( + "warmup1", + _make_settings(n_samples=2), + FakeDataset(2), + PhaseType.SATURATION, + ), + PhaseConfig( + "perf1", + _make_settings(n_samples=4), + FakeDataset(4), + PhaseType.PERFORMANCE, + ), + PhaseConfig( + "warmup2", + _make_settings(n_samples=3), + FakeDataset(3), + PhaseType.SATURATION, + ), + PhaseConfig( + "perf2", + _make_settings(n_samples=6), + FakeDataset(6), + PhaseType.PERFORMANCE, + ), + ] + result = await asyncio.wait_for(session.run(phases), timeout=10.0) + + # Both perf phases should produce results + assert len(result.perf_results) == 2 + assert result.perf_results[0].name == "perf1" + assert result.perf_results[0].issued_count == 4 + assert result.perf_results[1].name == "perf2" + assert result.perf_results[1].issued_count == 6 + + # Saturation phases produce no results + assert len(result.phase_results) == 2 + + # Should have start/stop tracking for each perf phase + start_track = publisher.events_of_type( + SessionEventType.START_PERFORMANCE_TRACKING + ) + stop_track = publisher.events_of_type( + SessionEventType.STOP_PERFORMANCE_TRACKING + ) + assert len(start_track) == 2 + assert len(stop_track) == 2 + + +@pytest.mark.unit +class TestBenchmarkSessionStaleStreamChunk: + """Stale StreamChunk from previous phase is ignored.""" + + 
@pytest.mark.asyncio + async def test_stale_stream_chunk_ignored(self): + """StreamChunk from saturation phase should not affect perf phase counts.""" + loop = asyncio.get_running_loop() + publisher = FakePublisher() + + issuer = FakeIssuer() + issuer._loop = loop + issuer._auto_respond = False + + completed: list[str] = [] + + def on_complete(result: QueryResult | StreamChunk) -> None: + completed.append(result.id) + + session = BenchmarkSession( + issuer, publisher, loop, on_sample_complete=on_complete + ) + + # Saturation with slow responses, perf with concurrency + sat_settings = _make_settings(n_samples=2) + perf_settings = _make_settings( + load_pattern=LoadPattern( + type=LoadPatternType.CONCURRENCY, target_concurrency=1 + ), + n_samples=2, + ) + + phases = [ + PhaseConfig("sat", sat_settings, FakeDataset(2), PhaseType.SATURATION), + PhaseConfig("perf", perf_settings, FakeDataset(2), PhaseType.PERFORMANCE), + ] + + async def inject_responses(): + # Wait for saturation queries + while len(issuer._issued) < 2: + await asyncio.sleep(0.005) + sat_ids = [q.id for q in issuer._issued[:2]] + + # Wait for perf phase queries to start + while len(issuer._issued) < 3: + await asyncio.sleep(0.005) + + # Inject stale StreamChunks from saturation phase into perf phase + issuer.inject_response(StreamChunk(id=sat_ids[0], response_chunk="stale")) + issuer.inject_response(StreamChunk(id=sat_ids[1], response_chunk="stale")) + + # Now complete the perf queries + perf_queries = issuer._issued[2:] + for q in perf_queries: + issuer.inject_response(QueryResult(id=q.id, response_output="ok")) + # Wait for second perf query if not yet issued + while len(issuer._issued) < 4: + await asyncio.sleep(0.005) + for q in issuer._issued[2:]: + if q.id not in list(completed): + issuer.inject_response( + QueryResult(id=q.id, response_output="ok") + ) + + asyncio.create_task(inject_responses()) + result = await asyncio.wait_for(session.run(phases), timeout=5.0) + + # Perf phase should have 
exactly 2 issued samples + assert len(result.perf_results) == 1 + assert result.perf_results[0].issued_count == 2 + # on_sample_complete should only be called for perf-phase queries + # (stale sat queries are not in perf's uuid_to_index) + for cid in completed: + assert cid in result.perf_results[0].uuid_to_index + + +@pytest.mark.unit +class TestSessionResult: + def test_perf_results_filter(self): + results = [ + PhaseResult("sat", PhaseType.SATURATION, {}, 0, 0, 0), + PhaseResult("perf1", PhaseType.PERFORMANCE, {"a": 1}, 10, 0, 100), + PhaseResult("perf2", PhaseType.PERFORMANCE, {"b": 2}, 20, 100, 200), + PhaseResult("acc", PhaseType.ACCURACY, {"c": 3}, 5, 200, 300), + ] + sr = SessionResult("test", results, 0, 300) + assert len(sr.perf_results) == 2 + assert len(sr.accuracy_results) == 1 + assert sr.perf_results[0].name == "perf1" diff --git a/tests/unit/load_generator/test_load_generator.py b/tests/unit/load_generator/test_load_generator.py deleted file mode 100644 index 5fd4ae4b..00000000 --- a/tests/unit/load_generator/test_load_generator.py +++ /dev/null @@ -1,279 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import random -import time -from collections import defaultdict -from unittest.mock import patch - -import inference_endpoint.metrics as metrics -from inference_endpoint.config.runtime_settings import RuntimeSettings -from inference_endpoint.config.schema import LoadPattern, LoadPatternType -from inference_endpoint.core.types import QueryResult, StreamChunk -from inference_endpoint.load_generator.load_generator import ( - SampleIssuer, - SchedulerBasedLoadGenerator, -) -from inference_endpoint.load_generator.sample import SampleEvent, SampleEventHandler -from inference_endpoint.load_generator.scheduler import ( - MaxThroughputScheduler, - PoissonDistributionScheduler, - SampleOrder, - WithoutReplacementSampleOrder, -) - -from tests.test_helpers import DummyDataLoader, SerialSampleIssuer - - -class FibonacciSampleOrder(SampleOrder): - """Sample order where the corresponding value for a sample index is that number value in - the Fibonacci sequence. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.a = 0 - self.b = 1 - - def next_sample_index(self) -> int: - retval = self.a - c = self.a + self.b - self.a = self.b - self.b = c - return retval - - -@patch( - "inference_endpoint.load_generator.load_generator.LoadGenerator.load_sample_data" -) -def test_load_generator(load_sample_data_mock, max_throughput_runtime_settings): - load_sample_data_mock.side_effect = lambda index, _uuid: index**2 - - class ListAppendIssuer(SampleIssuer): - def __init__(self): - self.issued = [] - - def issue(self, sample): - self.issued.append(sample) - - fake_sample_issuer = ListAppendIssuer() - - load_generator = SchedulerBasedLoadGenerator( - fake_sample_issuer, - None, # No Dataloader needed — load_sample_data is mocked - scheduler=MaxThroughputScheduler( - max_throughput_runtime_settings, - FibonacciSampleOrder, - ), - ) - a = 0 - b = 1 - for i, issued_sample in enumerate(load_generator): - assert issued_sample.sample.data == a**2 - assert 
issued_sample.sample == fake_sample_issuer.issued[i] - assert len(fake_sample_issuer.issued) == i + 1 - - c = a + b - a = b - b = c - - -def test_full_run(): - rt_settings = RuntimeSettings( - metrics.Throughput(5000), - [metrics.Throughput(5000)], - min_duration_ms=1000, - max_duration_ms=10_000, - n_samples_from_dataset=100, - n_samples_to_issue=10_000, - min_sample_count=100, - rng_sched=random.Random(1234), - rng_sample_index=random.Random(1234), - load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), - ) - - def compute_digits_of_square(n: int): - yield from str(n**2) - - sample_issuer = SerialSampleIssuer(compute_digits_of_square) - load_generator = SchedulerBasedLoadGenerator( - sample_issuer, - DummyDataLoader(100), - scheduler=MaxThroughputScheduler( - rt_settings, - WithoutReplacementSampleOrder, - ), - ) - - # Hooks for chunk data and query results - received_chunks = defaultdict(list) - - def save_chunk(chunk: StreamChunk): - received_chunks[chunk.id].append(chunk.response_chunk) - - SampleEventHandler.register_hook(SampleEvent.FIRST_CHUNK, save_chunk) - SampleEventHandler.register_hook(SampleEvent.NON_FIRST_CHUNK, save_chunk) - - results = {} - - def save_query_result(result: QueryResult): - results[result.id] = result.get_response_output_string() - - SampleEventHandler.register_hook(SampleEvent.COMPLETE, save_query_result) - - sent_hist = defaultdict(int) - sent_uuids = defaultdict(list) - seen_uuids = set() - for issued_sample in load_generator: - # The test issuer is serial, so we can confirm that a sample is completed before the next - # is issued. 
- expected = str(issued_sample.index**2) - assert received_chunks[issued_sample.sample.uuid][0] == expected[0] - assert len(received_chunks[issued_sample.sample.uuid]) == len(expected) - assert "".join(received_chunks[issued_sample.sample.uuid]) == expected - assert results[issued_sample.sample.uuid] == expected - - sent_hist[issued_sample.index] += 1 - sent_uuids[issued_sample.index].append(issued_sample.sample.uuid) - seen_uuids.add(issued_sample.sample.uuid) - - # WithoutReplacementSampleOrder should ensure that as long as total # of samples issued is a multiple of dataset size, - # the number of issues per sample is the same - target_issues = rt_settings.n_samples_to_issue // rt_settings.n_samples_from_dataset - for index, n_sent in sent_hist.items(): - assert ( - n_sent == target_issues - ), f"Sample {index} should have been issued {target_issues} times, but was issued {n_sent} times" - - # Check uuid uniqueness - n_distinct_uuids = len(set(sent_uuids[index])) - assert ( - n_distinct_uuids == n_sent - ), f"Sample {index} should have {n_sent} unique uuids, but has {n_distinct_uuids}" - - # Check that ALL uuids are unique - assert ( - len(seen_uuids) == rt_settings.n_samples_to_issue - ), f"Should have seen {rt_settings.n_samples_to_issue} unique uuids, but saw {len(seen_uuids)}" - - -@patch( - "inference_endpoint.load_generator.load_generator.LoadGenerator.load_sample_data" -) -def test_max_duration_ms_stops_issuance(load_sample_data_mock): - """max_duration_ms should stop iteration before n_samples_to_issue is exhausted.""" - load_sample_data_mock.side_effect = lambda index, _uuid: index - - max_duration_ms = 50 - rt_settings = RuntimeSettings( - metrics.Throughput(5000), - reported_metrics=[], - min_duration_ms=0, - max_duration_ms=max_duration_ms, - n_samples_from_dataset=100, - n_samples_to_issue=1_000_000, - min_sample_count=100, - rng_sched=random.Random(42), - rng_sample_index=random.Random(42), - 
load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), - ) - - issued_count = 0 - - class CountingIssuer(SampleIssuer): - def issue(self, sample): - pass - - load_generator = SchedulerBasedLoadGenerator( - CountingIssuer(), - None, - scheduler=MaxThroughputScheduler(rt_settings, WithoutReplacementSampleOrder), - ) - - start = time.monotonic() - for _ in load_generator: - issued_count += 1 - elapsed_s = time.monotonic() - start - - # Should have stopped well before issuing 1,000,000 samples - assert ( - issued_count < 1_000_000 - ), f"Expected timeout to stop issuance, but {issued_count} samples were issued" - # Elapsed wall-clock should be reasonably close to max_duration_ms: - # lower bound ensures the timeout (not an early exit) was responsible for stopping, - # upper bound is generous to accommodate slow CI runners. - max_duration_s = max_duration_ms / 1000 - assert ( - elapsed_s >= max_duration_s * 0.5 - ), f"Elapsed time {elapsed_s:.3f}s is unexpectedly below max_duration_ms={max_duration_ms}ms" - assert ( - elapsed_s < max_duration_s * 2 - ), f"Elapsed time {elapsed_s:.3f}s far exceeds max_duration_ms={max_duration_ms}ms" - - -@patch( - "inference_endpoint.load_generator.load_generator.LoadGenerator.load_sample_data" -) -def test_max_duration_ms_stops_issuance_with_poisson_scheduler(load_sample_data_mock): - """max_duration_ms should stop iteration even when the scheduler has inter-sample delays. - - Uses PoissonDistributionScheduler at low QPS so each inter-sample wait is measurable. - No sample should be issued after the wall-clock deadline has elapsed. 
- """ - load_sample_data_mock.side_effect = lambda index, _uuid: index - - max_duration_ms = 200 - target_qps = 50 # ~20ms average inter-sample delay - rt_settings = RuntimeSettings( - metrics.Throughput(target_qps), - reported_metrics=[], - min_duration_ms=0, - max_duration_ms=max_duration_ms, - n_samples_from_dataset=100, - n_samples_to_issue=1_000_000, - min_sample_count=1, - rng_sched=random.Random(42), - rng_sample_index=random.Random(42), - load_pattern=LoadPattern(type=LoadPatternType.POISSON, target_qps=target_qps), - ) - - class CountingIssuer(SampleIssuer): - def issue(self, sample): - pass - - load_generator = SchedulerBasedLoadGenerator( - CountingIssuer(), - None, - scheduler=PoissonDistributionScheduler( - rt_settings, WithoutReplacementSampleOrder - ), - ) - - issued_count = 0 - start = time.monotonic() - for _ in load_generator: - issued_count += 1 - elapsed_s = time.monotonic() - start - - assert ( - issued_count < 1_000_000 - ), f"Expected timeout to stop issuance, but {issued_count} samples were issued" - max_duration_s = max_duration_ms / 1000 - assert ( - elapsed_s >= max_duration_s * 0.5 - ), f"Elapsed time {elapsed_s:.3f}s is unexpectedly below max_duration_ms={max_duration_ms}ms" - assert ( - elapsed_s < max_duration_s * 3 - ), f"Elapsed time {elapsed_s:.3f}s far exceeds max_duration_ms={max_duration_ms}ms" diff --git a/tests/unit/load_generator/test_sample.py b/tests/unit/load_generator/test_sample.py deleted file mode 100644 index 1b2d6115..00000000 --- a/tests/unit/load_generator/test_sample.py +++ /dev/null @@ -1,141 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -import pytest -from inference_endpoint.core.types import QueryResult, StreamChunk -from inference_endpoint.load_generator.sample import ( - Sample, - SampleEvent, - SampleEventHandler, -) - - -def test_sample_uniqueness(): - sample_uuids = [Sample(None).uuid for _ in range(1000)] - assert len(set(sample_uuids)) == len(sample_uuids), "Sample UUIDs should be unique" - - -def test_sample_lazy_data_loading(): - sample = Sample(None) - sample.data = "test_data" - assert sample.data == "test_data" - - with pytest.raises(AttributeError): - sample.data = "test_data2" - - -def test_sample_eager_data_loading(): - sample = Sample("my data") - - with pytest.raises(AttributeError): - sample.data = "test_data2" - - assert sample.data == "my data" - - -def test_sample_callback_times(): - """Test that hooks are invoked in the correct order for stream chunks and completion.""" - events = [] - - sample = Sample(None) - first_chunk = StreamChunk(id=sample.uuid, metadata={"first_chunk": True}) - non_first_chunk = StreamChunk(id=sample.uuid, metadata={"first_chunk": False}) - complete_result = QueryResult(id=sample.uuid) - - def record_first_chunk(chunk): - events.append((SampleEvent.FIRST_CHUNK, time.monotonic_ns())) - - def record_non_first_chunk(chunk): - events.append((SampleEvent.NON_FIRST_CHUNK, time.monotonic_ns())) - - def record_complete(result): - events.append((SampleEvent.COMPLETE, time.monotonic_ns())) - - SampleEventHandler.register_hook(SampleEvent.FIRST_CHUNK, record_first_chunk) - SampleEventHandler.register_hook( - 
SampleEvent.NON_FIRST_CHUNK, record_non_first_chunk - ) - SampleEventHandler.register_hook(SampleEvent.COMPLETE, record_complete) - - sleep_time_sec = 0.01 - - SampleEventHandler.stream_chunk_complete(first_chunk) - time.sleep(sleep_time_sec) - SampleEventHandler.stream_chunk_complete(non_first_chunk) - time.sleep(sleep_time_sec) - SampleEventHandler.query_result_complete(complete_result) - - assert len(events) == 3 - - assert events[0][0] == SampleEvent.FIRST_CHUNK - assert events[1][0] == SampleEvent.NON_FIRST_CHUNK - assert events[2][0] == SampleEvent.COMPLETE - assert events[0][1] < events[1][1] - assert events[1][1] < events[2][1] - - # Times are in nanoseconds - convert to seconds to compare with sleep time - tpot_1_sec = (events[1][1] - events[0][1]) / 1e9 - tpot_2_sec = (events[2][1] - events[1][1]) / 1e9 - - # Resolution of time.sleep is very coarse, so simply check that duration is - # greater than the sleep time - assert tpot_1_sec > sleep_time_sec - assert tpot_2_sec > sleep_time_sec - - SampleEventHandler.clear_hooks() - - -def test_sample_invalid_type_errors(): - chunk = StreamChunk(id="123", metadata={"first_chunk": True}) - result = QueryResult(id="123") - - with pytest.raises(AssertionError, match="Invalid chunk type"): - SampleEventHandler.stream_chunk_complete(result) - - with pytest.raises(AssertionError, match="Invalid result type"): - SampleEventHandler.query_result_complete(chunk) - - -def test_sample_event_handler_register_hook(): - progress_counter = [0, 0] - - def progress_hook(_): - progress_counter[1] += 1 - - def non_first_chunk_hook(_): - progress_counter[0] += 1 - - SampleEventHandler.register_hook(SampleEvent.COMPLETE, progress_hook) - SampleEventHandler.register_hook(SampleEvent.NON_FIRST_CHUNK, non_first_chunk_hook) - - SampleEventHandler.stream_chunk_complete( - StreamChunk(id="123", metadata={"first_chunk": True}) - ) - assert progress_counter == [0, 0] - - SampleEventHandler.query_result_complete(QueryResult(id="123")) - assert 
progress_counter == [0, 1] - - SampleEventHandler.stream_chunk_complete( - StreamChunk(id="123", metadata={"first_chunk": True}) - ) - assert progress_counter == [0, 1] - - SampleEventHandler.stream_chunk_complete( - StreamChunk(id="123", metadata={"first_chunk": False}) - ) - assert progress_counter == [1, 1] diff --git a/tests/unit/load_generator/test_sample_order.py b/tests/unit/load_generator/test_sample_order.py new file mode 100644 index 00000000..7dacd38f --- /dev/null +++ b/tests/unit/load_generator/test_sample_order.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for sample_order.py.""" + +import random + +import pytest +from inference_endpoint.load_generator.sample_order import ( + WithoutReplacementSampleOrder, + WithReplacementSampleOrder, +) + + +@pytest.mark.unit +class TestWithoutReplacementSampleOrder: + def test_yields_all_indices(self): + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=5, rng=random.Random(42) + ) + indices = [next(order) for _ in range(5)] + assert sorted(indices) == [0, 1, 2, 3, 4] + + def test_reshuffles_after_exhaustion(self): + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=3, rng=random.Random(42) + ) + first_pass = [next(order) for _ in range(3)] + second_pass = [next(order) for _ in range(3)] + assert sorted(first_pass) == [0, 1, 2] + assert sorted(second_pass) == [0, 1, 2] + + def test_never_raises_stop_iteration(self): + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=2, rng=random.Random(42) + ) + # Should be able to draw far more than dataset size + indices = [next(order) for _ in range(100)] + assert len(indices) == 100 + assert all(0 <= i < 2 for i in indices) + + def test_reproducible_with_seed(self): + order1 = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + order2 = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + seq1 = [next(order1) for _ in range(20)] + seq2 = [next(order2) for _ in range(20)] + assert seq1 == seq2 + + def test_invalid_size_raises(self): + with pytest.raises(ValueError, match="n_samples_in_dataset must be > 0"): + WithoutReplacementSampleOrder(n_samples_in_dataset=0) + + +@pytest.mark.unit +class TestWithReplacementSampleOrder: + def test_yields_valid_indices(self): + order = WithReplacementSampleOrder( + n_samples_in_dataset=5, rng=random.Random(42) + ) + indices = [next(order) for _ in range(100)] + assert all(0 <= i < 5 for i in indices) + + def test_reproducible_with_seed(self): + order1 = WithReplacementSampleOrder( + 
n_samples_in_dataset=10, rng=random.Random(42) + ) + order2 = WithReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + seq1 = [next(order1) for _ in range(20)] + seq2 = [next(order2) for _ in range(20)] + assert seq1 == seq2 diff --git a/tests/unit/load_generator/test_scheduler.py b/tests/unit/load_generator/test_scheduler.py deleted file mode 100644 index 05de0600..00000000 --- a/tests/unit/load_generator/test_scheduler.py +++ /dev/null @@ -1,257 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import math -import random -import threading - -import pytest -from inference_endpoint.load_generator.sample import SampleEventHandler -from inference_endpoint.load_generator.scheduler import ( - ConcurrencyScheduler, - MaxThroughputScheduler, - PoissonDistributionScheduler, - WithoutReplacementSampleOrder, - WithReplacementSampleOrder, -) -from scipy import stats - - -def test_without_replacement_sample_order(): - ordering = WithoutReplacementSampleOrder(12345, 100) - indices = list(iter(ordering)) - for i in range(0, 12345, 100): - assert len(set(indices[i : i + 100])) == min( - 100, 12345 - i - ), "Indices should be unique, and occur at least once" - - # Assert that order is different in each pass of the dataset - assert ( - indices[:100] != indices[100:200] - ), "Order should be different in each pass of the dataset" - - -def test_with_replacement_sample_order(random_seed): - ordering = WithReplacementSampleOrder(12345, 100, rng=random.Random(random_seed)) - indices = list(iter(ordering)) - - # With Python random.Random(42), the order can be deterministic - assert indices[:10] == [ - 81, - 14, - 3, - 94, - 35, - 31, - 28, - 17, - 94, - 13, - ], "Order does not match expected deterministic order" - # Note with this specific seed and order, 94 occurs twice in the first 10 indices - assert indices[:10].count(94) == 2, "94 should occur twice in the first 10 indices" - - -def test_max_throughput_scheduler(max_throughput_runtime_settings): - scheduler = MaxThroughputScheduler( - max_throughput_runtime_settings, WithReplacementSampleOrder - ) - indices = list(iter(scheduler)) - assert len(indices) == 100 - for _, delay in indices: - assert delay == 0 - assert [s_idx for s_idx, _ in indices[:10]] == [ - 81, - 14, - 3, - 94, - 35, - 31, - 28, - 17, - 94, - 13, - ], "Order does not match expected deterministic order" - - -@pytest.mark.parametrize("target_concurrency", [1, 2, 100, 1000], indirect=True) -def test_concurrency_scheduler(concurrency_runtime_settings, 
target_concurrency): - """Test ConcurrencyScheduler properly gates issuance by completions.""" - total_samples = concurrency_runtime_settings.n_samples_to_issue - - scheduler = ConcurrencyScheduler( - concurrency_runtime_settings, WithReplacementSampleOrder - ) - - # State tracking - state_lock = threading.RLock() - issued_count = 0 - completed_count = 0 - current_inflight = 0 - max_inflight = 0 - - # Synchronization: signal when queries can complete and when they're done - can_complete = [threading.Event() for _ in range(total_samples)] - completed = [threading.Event() for _ in range(total_samples)] - # Signal when each query is issued - issued = [threading.Event() for _ in range(total_samples)] - - def completion_worker(): - """Waits for signals to complete queries.""" - nonlocal completed_count, current_inflight - - for position in range(total_samples): - can_complete[position].wait() - - with state_lock: - completed_count += 1 - current_inflight -= 1 - assert current_inflight >= 0, "Inflight count went negative" - - scheduler._release_slot() - completed[position].set() - - threading.Thread(target=completion_worker, daemon=True).start() - - def issue_worker(): - """Issues queries through scheduler.""" - nonlocal issued_count, current_inflight, max_inflight - - for position, _ in enumerate(scheduler): - with state_lock: - issued_count += 1 - current_inflight += 1 - max_inflight = max(max_inflight, current_inflight) - assert ( - current_inflight <= target_concurrency - ), f"Concurrency {current_inflight} exceeded limit {target_concurrency}" - issued[position].set() - - issue_thread = threading.Thread(target=issue_worker, daemon=True) - issue_thread.start() - - try: - # Phase 1: First target_concurrency queries issue immediately - for position in range(target_concurrency): - issued[position].wait() - - with state_lock: - assert issued_count == target_concurrency - assert completed_count == 0 - assert current_inflight == target_concurrency - - # Phase 2: Verify 
scheduler blocks when at capacity, unblocks on completion - for position in range(target_concurrency, total_samples): - position_to_complete = position - target_concurrency - - # Verify next query hasn't issued yet (scheduler is blocking) - assert not issued[ - position - ].is_set(), f"Query {position} issued before slot was freed" - - # Free a slot - can_complete[position_to_complete].set() - completed[position_to_complete].wait() - - # Verify next query now issues - issued[position].wait() - - with state_lock: - assert current_inflight == target_concurrency - - # Phase 3: Complete remaining queries and cleanup - for position in range(target_concurrency, total_samples): - can_complete[position].set() - completed[position].wait() - - issue_thread.join() - - # Final validation - with state_lock: - assert issued_count == total_samples - assert completed_count == total_samples - assert current_inflight == 0 - assert max_inflight == target_concurrency - - finally: - SampleEventHandler.clear_hooks() - - -@pytest.mark.parametrize("target_qps", [50.0, 100.0, 500.0, 1000.0], indirect=True) -def test_poisson_scheduler_distribution(poisson_runtime_settings, target_qps): - """Test PoissonDistributionScheduler produces exponentially distributed inter-arrival times. - - For a Poisson process with rate λ (target QPS), inter-arrival times must follow - exponential distribution with mean = 1/λ. - - Three-tier validation: - 1. Mean with 99.9% confidence interval - 2. Coefficient of Variation (CV) ≈ 1.0 (exponential signature) - 3. 
Kolmogorov-Smirnov test for distribution shape - """ - scheduler = PoissonDistributionScheduler( - poisson_runtime_settings, WithReplacementSampleOrder - ) - - # Test configuration - TARGET_QPS = target_qps - expected_mean_s = 1.0 / TARGET_QPS - - # Collect delays from scheduler (in seconds) for statistical analysis - delays_s = [] - for _, delay_ns in scheduler: - delays_s.append(delay_ns / 1e9) # Convert ns to seconds - - # Validate sufficient sample size - n = len(delays_s) - - # Calculate sample statistics using Bessel's correction for unbiased variance (whitened) - sample_mean = sum(delays_s) / n - sample_variance = sum((x - sample_mean) ** 2 for x in delays_s) / (n - 1) - sample_std = math.sqrt(sample_variance) - cv = sample_std / sample_mean - - # Test 1: Mean with statistical confidence interval (99.9% CI) - # For exponential: std(X̄) = sigma/√n = mu/√n - z_critical = 3.29 # 99.9% two-tailed - margin_of_error = z_critical * (sample_std / math.sqrt(n)) - assert abs(sample_mean - expected_mean_s) < margin_of_error, ( - f"Mean {sample_mean*1000:.3f}ms outside 99.9% CI: " - f"[{(expected_mean_s - margin_of_error)*1000:.3f}, " - f"{(expected_mean_s + margin_of_error)*1000:.3f}] ms" - ) - - # Test 2: CV should be close to 1.0 (exponential property: std = mean) - # Use adaptive tolerance based on sample size, max(10%, 1 std. 
error) - cv_tolerance = max(0.10, 1.0 / math.sqrt(n)) - assert ( - abs(cv - 1.0) < cv_tolerance - ), f"CV {cv:.3f} deviates from 1.0 by more than {cv_tolerance:.3f}" - - # Test 3: Kolmogorov-Smirnov test for exponential distribution - # kstest compares data against exponential CDF with scale parameter = mean - ks_statistic, p_value = stats.kstest( - delays_s, - "expon", - args=(0, sample_mean), # loc=0 (no shift), scale=mean - alternative="two-sided", - ) - - # Reject if p-value < 0.0001 (99.99% confidence that distribution is NOT exponential) - ALPHA = 0.0001 - assert p_value > ALPHA, ( - f"KS test rejected exponential distribution: " - f"p-value={p_value:.4f} < alpha={ALPHA} (D={ks_statistic:.4f})" - ) diff --git a/tests/unit/load_generator/test_session.py b/tests/unit/load_generator/test_session.py deleted file mode 100644 index 22554b4c..00000000 --- a/tests/unit/load_generator/test_session.py +++ /dev/null @@ -1,144 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import random - -import inference_endpoint.metrics as metrics -import pytest -from inference_endpoint.config.runtime_settings import RuntimeSettings -from inference_endpoint.config.schema import LoadPattern, LoadPatternType -from inference_endpoint.load_generator.sample import ( - Sample, - SampleEvent, - SampleEventHandler, -) -from inference_endpoint.load_generator.scheduler import ( - MaxThroughputScheduler, - WithoutReplacementSampleOrder, -) -from inference_endpoint.load_generator.session import BenchmarkSession -from tqdm import tqdm - -from tests.test_helpers import ( - DummyDataLoader, - PooledSampleIssuer, -) - -# The following are tests for PooledSampleIssuer in test_helpers.py. If these tests pass -# but session.py tests fail, it's probably not the PooledSampleIssuer's fault. - - -def test_pooled_issuer_exception_propagation(): - """Test that exceptions in worker threads are properly propagated to the main thread.""" - - def failing_compute(sample): - raise ValueError("Worker thread error!") - - issuer = PooledSampleIssuer(compute_func=failing_compute, n_workers=2) - - sample1 = Sample(b"sample1") - sample2 = Sample(b"sample2") - - # Submit some work that will fail - issuer.issue(sample1) - issuer.issue(sample2) - - # Shutdown should raise the exception from the worker thread - with pytest.raises(ValueError, match="Worker thread error!"): - issuer.shutdown() - - -def test_pooled_issuer_futures_cleanup(): - """Test that completed futures are cleaned up to prevent memory leaks.""" - import time - - def slow_compute(sample): - time.sleep(0.01) # Small delay - return [sample.decode("utf-8")] - - issuer = PooledSampleIssuer(compute_func=slow_compute, n_workers=4) - - # Submit 250 samples (should trigger cleanup at 100 and 200) - for _ in range(250): - issuer.issue(Sample(b"sample")) - - # Let some time pass first - time.sleep(0.2) - - # Manually check errors to trigger cleanup - issuer.check_errors() - - for _ in range(250): - 
issuer.issue(Sample(b"sample")) - - issuer.shutdown() - - # After shutdown, all futures should be cleared - assert len(issuer.futures) == 0, "Futures not cleared after shutdown" - - -# session.py tests - - -def test_session_start(clean_sample_event_hooks): - rt_settings = RuntimeSettings( - metrics.Throughput(5000), - [metrics.Throughput(5000)], - min_duration_ms=1000, - max_duration_ms=None, - n_samples_from_dataset=100, - n_samples_to_issue=10_000, - min_sample_count=100, - rng_sched=random.Random(1234), - rng_sample_index=random.Random(1234), - load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), - ) - - def compute_digits_of_square(n: int): - yield from str(n**2) - - dl = DummyDataLoader(n_samples=100) - sample_issuer = PooledSampleIssuer(compute_digits_of_square) - sched = MaxThroughputScheduler(rt_settings, WithoutReplacementSampleOrder) - - class ProgressBarHook: - def __init__(self, pbar: tqdm | None = None): - self.pbar = pbar - - def __call__(self, _): - if isinstance(self.pbar, tqdm): - self.pbar.update(1) - - def set_pbar(self, pbar: tqdm): - self.pbar = pbar - - pbar_hook = ProgressBarHook() - SampleEventHandler.register_hook(SampleEvent.COMPLETE, pbar_hook) - - with tqdm(desc="pytest_test_session_start", total=10_000, unit="samples") as pbar: - pbar_hook.set_pbar(pbar) - sess = BenchmarkSession.start( - rt_settings, - dl, - sample_issuer, - sched, - name="pytest_test_session_start", - ) - assert sess.wait_for_test_end( - timeout=120.0 - ), "Session did not complete within timeout" - - # Shutdown the sample issuer to ensure proper cleanup and error propagation - sample_issuer.shutdown() diff --git a/tests/unit/load_generator/test_strategy.py b/tests/unit/load_generator/test_strategy.py new file mode 100644 index 00000000..650a92c1 --- /dev/null +++ b/tests/unit/load_generator/test_strategy.py @@ -0,0 +1,701 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for load strategies.""" + +from __future__ import annotations + +import asyncio +import random +from collections.abc import Callable +from time import monotonic_ns + +import pytest +from inference_endpoint.config.runtime_settings import RuntimeSettings +from inference_endpoint.config.schema import LoadPattern, LoadPatternType +from inference_endpoint.load_generator.delay import make_delay_fn, poisson_delay_fn +from inference_endpoint.load_generator.sample_order import WithoutReplacementSampleOrder +from inference_endpoint.load_generator.strategy import ( + BurstStrategy, + ConcurrencyStrategy, + TimedIssueStrategy, + create_load_strategy, +) +from inference_endpoint.metrics.metric import Throughput + + +def _constant_delay(ns: int = 1_000) -> Callable[[], int]: + return lambda: ns + + +# --------------------------------------------------------------------------- +# Mock PhaseIssuer +# --------------------------------------------------------------------------- + + +class MockPhaseIssuer: + """Minimal PhaseIssuer for strategy tests.""" + + def __init__(self, max_issues: int = 100): + self.issued_indices: list[int] = [] + self.issued_count: int = 0 + self._max = max_issues + + def issue(self, sample_index: int) -> str | None: + if self.issued_count >= self._max: + return None + self.issued_indices.append(sample_index) + self.issued_count += 1 + return f"q{self.issued_count}" 
+ + +# --------------------------------------------------------------------------- +# TimedIssueStrategy +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestTimedIssueStrategyCallAt: + @pytest.mark.asyncio + async def test_issues_correct_count(self): + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + delay_fn = _constant_delay(1_000) + strategy = TimedIssueStrategy(delay_fn, order, loop, use_executor=False) + + issuer = MockPhaseIssuer(max_issues=20) + count = await strategy.execute(issuer) + assert count == 20 + assert issuer.issued_count == 20 + + @pytest.mark.asyncio + async def test_stops_on_none(self): + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=5, rng=random.Random(42) + ) + delay_fn = _constant_delay(1_000) + strategy = TimedIssueStrategy(delay_fn, order, loop, use_executor=False) + + issuer = MockPhaseIssuer(max_issues=3) + count = await strategy.execute(issuer) + assert count == 3 + + @pytest.mark.asyncio + async def test_timing_precision(self): + """call_at should achieve sub-ms precision for moderate delays.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=100, rng=random.Random(42) + ) + delay_fn = _constant_delay(1_000_000) + + timestamps: list[int] = [] + + class TimingIssuer: + issued_count = 0 + + def issue(self, idx): + timestamps.append(monotonic_ns()) + self.issued_count += 1 + if self.issued_count >= 10: + return None + return f"q{self.issued_count}" + + strategy = TimedIssueStrategy(delay_fn, order, loop, use_executor=False) + await strategy.execute(TimingIssuer()) + + # Check inter-arrival times are positive (callbacks fire in order) + for i in range(1, len(timestamps)): + delta_ns = timestamps[i] - timestamps[i - 1] + assert delta_ns > 0, f"Issue {i}: non-monotonic timestamps" + # Total elapsed 
should be roughly 9ms (9 delays of 1ms) + total_ns = timestamps[-1] - timestamps[0] + assert ( + total_ns > 5_000_000 + ), f"Total elapsed {total_ns}ns too small for 9x1ms delays" + + +@pytest.mark.unit +class TestTimedIssueStrategyExecutor: + @pytest.mark.asyncio + async def test_issues_correct_count(self): + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + delay_fn = _constant_delay(1_000) + strategy = TimedIssueStrategy(delay_fn, order, loop, use_executor=True) + + issuer = MockPhaseIssuer(max_issues=20) + count = await strategy.execute(issuer) + assert count == 20 + + +# --------------------------------------------------------------------------- +# BurstStrategy +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestBurstStrategy: + @pytest.mark.asyncio + async def test_issues_all(self): + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + strategy = BurstStrategy(order, loop) + + issuer = MockPhaseIssuer(max_issues=50) + count = await strategy.execute(issuer) + assert count == 50 + + @pytest.mark.asyncio + async def test_stops_on_none(self): + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=5, rng=random.Random(42) + ) + strategy = BurstStrategy(order, loop) + + issuer = MockPhaseIssuer(max_issues=7) + count = await strategy.execute(issuer) + assert count == 7 + + @pytest.mark.asyncio + async def test_does_not_starve_event_loop(self): + """Verify other coroutines get to run during burst issuance. + + We schedule a coroutine that increments a counter each time it wakes. + If burst issuance yields properly, the counter should be > 0 before + issuance completes. 
+ """ + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=200, rng=random.Random(42) + ) + strategy = BurstStrategy(order, loop) + + wakeup_count = 0 + stop = asyncio.Event() + + async def competing_task(): + nonlocal wakeup_count + while not stop.is_set(): + await asyncio.sleep(0) + wakeup_count += 1 + + task = asyncio.create_task(competing_task()) + issuer = MockPhaseIssuer(max_issues=200) + await strategy.execute(issuer) + stop.set() + await task + # The competing task should have woken up multiple times during issuance + assert wakeup_count > 1, f"Competing task only woke {wakeup_count} times" + + +# --------------------------------------------------------------------------- +# ConcurrencyStrategy +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestConcurrencyStrategy: + @pytest.mark.asyncio + async def test_issues_up_to_concurrency_then_waits(self): + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + strategy = ConcurrencyStrategy(target_concurrency=3, sample_order=order) + issuer = MockPhaseIssuer(max_issues=10) + + # Start strategy but don't await — it should block after 3 issues + task = asyncio.create_task(strategy.execute(issuer)) + await asyncio.sleep(0.01) # let it run + assert issuer.issued_count == 3 + + # Simulate completions + for i in range(1, 4): + strategy.on_query_complete(f"q{i}") + await asyncio.sleep(0.01) + assert issuer.issued_count == 6 + + # Complete remaining + for i in range(4, 11): + strategy.on_query_complete(f"q{i}") + count = await asyncio.wait_for(task, timeout=2.0) + assert count == 10 + + @pytest.mark.asyncio + async def test_stops_on_none(self): + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=100, rng=random.Random(42) + ) + strategy = ConcurrencyStrategy(target_concurrency=5, sample_order=order) + issuer = MockPhaseIssuer(max_issues=3) + + # Complete queries as 
they arrive so strategy doesn't block + async def completer(): + while True: + await asyncio.sleep(0.005) + for i in range(1, issuer.issued_count + 1): + strategy.on_query_complete(f"q{i}") + + completer_task = asyncio.create_task(completer()) + count = await asyncio.wait_for(strategy.execute(issuer), timeout=2.0) + completer_task.cancel() + try: + await completer_task + except asyncio.CancelledError: + pass + assert count == 3 + + def test_invalid_concurrency_raises(self): + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + with pytest.raises(ValueError, match="target_concurrency must be > 0"): + ConcurrencyStrategy(target_concurrency=0, sample_order=order) + + +# --------------------------------------------------------------------------- +# Delay functions +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestDelayFunctions: + def test_poisson_delay_positive(self): + fn = poisson_delay_fn(1000.0, random.Random(42)) + delays = [fn() for _ in range(100)] + assert all(d >= 1 for d in delays) + + def test_poisson_delay_mean(self): + """Mean delay should be close to 1/target_qps in ns.""" + target_qps = 10_000.0 + fn = poisson_delay_fn(target_qps, random.Random(42)) + delays = [fn() for _ in range(10_000)] + mean_ns = sum(delays) / len(delays) + expected_ns = 1e9 / target_qps # 100_000 ns + assert abs(mean_ns - expected_ns) / expected_ns < 0.1 # within 10% + + def test_poisson_delay_invalid_qps(self): + with pytest.raises(ValueError, match="target_qps must be > 0"): + poisson_delay_fn(0, random.Random(42)) + + def test_make_delay_fn_unsupported_pattern(self): + lp = LoadPattern(type=LoadPatternType.MAX_THROUGHPUT) + with pytest.raises(ValueError, match="No delay function"): + make_delay_fn(lp, random.Random(42)) + + +# --------------------------------------------------------------------------- +# create_load_strategy factory +# 
--------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestCreateLoadStrategy: + def test_max_throughput(self): + loop = asyncio.new_event_loop() + try: + settings = _make_settings(LoadPattern(type=LoadPatternType.MAX_THROUGHPUT)) + strategy = create_load_strategy(settings, loop) + assert isinstance(strategy, BurstStrategy) + finally: + loop.close() + + def test_poisson_default(self): + loop = asyncio.new_event_loop() + try: + settings = _make_settings( + LoadPattern(type=LoadPatternType.POISSON, target_qps=1000.0) + ) + strategy = create_load_strategy(settings, loop) + assert isinstance(strategy, TimedIssueStrategy) + assert not strategy._use_executor + finally: + loop.close() + + def test_poisson_executor(self): + loop = asyncio.new_event_loop() + try: + settings = _make_settings( + LoadPattern(type=LoadPatternType.POISSON, target_qps=1000.0) + ) + strategy = create_load_strategy(settings, loop, use_executor=True) + assert isinstance(strategy, TimedIssueStrategy) + assert strategy._use_executor + finally: + loop.close() + + def test_concurrency(self): + loop = asyncio.new_event_loop() + try: + settings = _make_settings( + LoadPattern(type=LoadPatternType.CONCURRENCY, target_concurrency=32) + ) + strategy = create_load_strategy(settings, loop) + assert isinstance(strategy, ConcurrencyStrategy) + assert strategy._target == 32 + finally: + loop.close() + + def test_no_load_pattern_raises(self): + loop = asyncio.new_event_loop() + try: + settings = _make_settings(None) + with pytest.raises(ValueError, match="load_pattern must not be None"): + create_load_strategy(settings, loop) + finally: + loop.close() + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestEdgeCases: + @pytest.mark.asyncio + async def test_burst_single_sample(self): + loop = 
asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=1, rng=random.Random(42) + ) + strategy = BurstStrategy(order, loop) + issuer = MockPhaseIssuer(max_issues=1) + count = await strategy.execute(issuer) + assert count == 1 + + @pytest.mark.asyncio + async def test_burst_stop_immediately(self): + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + strategy = BurstStrategy(order, loop) + issuer = MockPhaseIssuer(max_issues=0) + count = await strategy.execute(issuer) + assert count == 0 + + @pytest.mark.asyncio + async def test_burst_exception_in_issue_does_not_hang(self): + """If issue() raises, strategy should not hang forever.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + strategy = BurstStrategy(order, loop) + + class FailingIssuer: + issued_count = 0 + + def issue(self, idx: int) -> str | None: + self.issued_count += 1 + if self.issued_count == 3: + raise RuntimeError("load_sample failed") + return f"q{self.issued_count}" + + issuer = FailingIssuer() + # Must not hang — should complete (with error) within timeout + with pytest.raises(RuntimeError, match="load_sample failed"): + await asyncio.wait_for(strategy.execute(issuer), timeout=5.0) + + @pytest.mark.asyncio + async def test_timed_call_at_exception_in_issue_does_not_hang(self): + """If issue() raises in call_at callback, strategy should not hang.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + strategy = TimedIssueStrategy( + _constant_delay(1_000), order, loop, use_executor=False + ) + + class FailingIssuer: + issued_count = 0 + + def issue(self, idx: int) -> str | None: + self.issued_count += 1 + if self.issued_count == 3: + raise RuntimeError("load_sample failed") + return f"q{self.issued_count}" + + issuer = 
FailingIssuer() + with pytest.raises(RuntimeError, match="load_sample failed"): + await asyncio.wait_for(strategy.execute(issuer), timeout=5.0) + + def test_sample_order_single_element(self): + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=1, rng=random.Random(42) + ) + indices = [next(order) for _ in range(10)] + assert all(i == 0 for i in indices) + + +# --------------------------------------------------------------------------- +# Executor mode exceptions +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestTimedIssueStrategyExecutorExceptions: + @pytest.mark.asyncio + async def test_executor_issue_raises(self): + """If issue() raises inside run_in_executor path, exception propagates.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + strategy = TimedIssueStrategy( + _constant_delay(1_000), order, loop, use_executor=True + ) + + call_count = 0 + + class FailingIssuer: + issued_count = 0 + + def issue(self, idx: int) -> str | None: + nonlocal call_count + call_count += 1 + self.issued_count += 1 + if call_count == 3: + raise ValueError("executor callback failed") + return f"q{call_count}" + + with pytest.raises(ValueError, match="executor callback failed"): + await asyncio.wait_for(strategy.execute(FailingIssuer()), timeout=5.0) + + @pytest.mark.asyncio + async def test_executor_delay_fn_raises(self): + """If delay_fn raises inside executor path, exception propagates.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=10, rng=random.Random(42) + ) + call_count = 0 + + def bad_delay(): + nonlocal call_count + call_count += 1 + if call_count == 3: + raise RuntimeError("delay computation failed") + return 1_000 + + strategy = TimedIssueStrategy(bad_delay, order, loop, use_executor=True) + issuer = MockPhaseIssuer(max_issues=100) + + with 
pytest.raises(RuntimeError, match="delay computation failed"): + await asyncio.wait_for(strategy.execute(issuer), timeout=5.0) + + +# --------------------------------------------------------------------------- +# Concurrent on_query_complete calls +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestConcurrencyStrategyConcurrentCompletions: + @pytest.mark.asyncio + async def test_multiple_completions_simultaneously(self): + """Multiple on_query_complete calls arriving at the same time.""" + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=20, rng=random.Random(42) + ) + strategy = ConcurrencyStrategy(target_concurrency=5, sample_order=order) + issuer = MockPhaseIssuer(max_issues=20) + + task = asyncio.create_task(strategy.execute(issuer)) + + # Let strategy issue initial batch of 5 + await asyncio.sleep(0.02) + assert issuer.issued_count == 5 + + # Release all 5 at once + for i in range(1, 6): + strategy.on_query_complete(f"q{i}") + await asyncio.sleep(0.02) + assert issuer.issued_count == 10 + + # Release next batch all at once + for i in range(6, 11): + strategy.on_query_complete(f"q{i}") + await asyncio.sleep(0.02) + assert issuer.issued_count == 15 + + # Release rest + for i in range(11, 21): + strategy.on_query_complete(f"q{i}") + count = await asyncio.wait_for(task, timeout=2.0) + assert count == 20 + + @pytest.mark.asyncio + async def test_completions_interleaved_with_issues(self): + """Completions arriving while new issues are being scheduled.""" + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=50, rng=random.Random(42) + ) + strategy = ConcurrencyStrategy(target_concurrency=2, sample_order=order) + issuer = MockPhaseIssuer(max_issues=10) + + task = asyncio.create_task(strategy.execute(issuer)) + await asyncio.sleep(0.01) + assert issuer.issued_count == 2 + + # Alternate: complete one, let it issue one more + for i in range(1, 11): + strategy.on_query_complete(f"q{i}") + 
await asyncio.sleep(0.005) + + count = await asyncio.wait_for(task, timeout=2.0) + assert count == 10 + + +# --------------------------------------------------------------------------- +# Near-zero delay (high QPS poisson) +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestTimedIssueStrategyNearZeroDelay: + @pytest.mark.asyncio + async def test_very_high_qps(self): + """Poisson with extremely high QPS should still issue all samples.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=50, rng=random.Random(42) + ) + # 1ns delay -- essentially zero + strategy = TimedIssueStrategy( + _constant_delay(1), order, loop, use_executor=False + ) + issuer = MockPhaseIssuer(max_issues=50) + count = await asyncio.wait_for(strategy.execute(issuer), timeout=5.0) + assert count == 50 + + @pytest.mark.asyncio + async def test_very_high_qps_executor(self): + """Near-zero delay in executor mode.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=50, rng=random.Random(42) + ) + strategy = TimedIssueStrategy( + _constant_delay(1), order, loop, use_executor=True + ) + issuer = MockPhaseIssuer(max_issues=50) + count = await asyncio.wait_for(strategy.execute(issuer), timeout=5.0) + assert count == 50 + + @pytest.mark.asyncio + async def test_poisson_high_qps_statistical(self): + """Real poisson distribution at 1M QPS should complete quickly.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=100, rng=random.Random(42) + ) + delay_fn = poisson_delay_fn(1_000_000.0, random.Random(42)) + strategy = TimedIssueStrategy(delay_fn, order, loop, use_executor=False) + issuer = MockPhaseIssuer(max_issues=100) + count = await asyncio.wait_for(strategy.execute(issuer), timeout=5.0) + assert count == 100 + + +# --------------------------------------------------------------------------- +# 
Large-scale burst +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestBurstStrategyLargeScale: + @pytest.mark.asyncio + async def test_burst_1000_samples(self): + """BurstStrategy should handle 1000+ samples without issues.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=200, rng=random.Random(42) + ) + strategy = BurstStrategy(order, loop) + issuer = MockPhaseIssuer(max_issues=1000) + count = await asyncio.wait_for(strategy.execute(issuer), timeout=10.0) + assert count == 1000 + + @pytest.mark.asyncio + async def test_burst_5000_samples(self): + """BurstStrategy at 5000 samples -- verify count and no event loop starvation.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=500, rng=random.Random(42) + ) + strategy = BurstStrategy(order, loop) + + wakeups = 0 + stop = asyncio.Event() + + async def observer(): + nonlocal wakeups + while not stop.is_set(): + await asyncio.sleep(0) + wakeups += 1 + + obs_task = asyncio.create_task(observer()) + issuer = MockPhaseIssuer(max_issues=5000) + count = await asyncio.wait_for(strategy.execute(issuer), timeout=10.0) + stop.set() + await obs_task + + assert count == 5000 + assert wakeups > 10, f"Event loop starved: observer only ran {wakeups} times" + + @pytest.mark.asyncio + async def test_burst_indices_wrap_around(self): + """With dataset_size < issue_count, indices should wrap around.""" + loop = asyncio.get_running_loop() + order = WithoutReplacementSampleOrder( + n_samples_in_dataset=3, rng=random.Random(42) + ) + strategy = BurstStrategy(order, loop) + issuer = MockPhaseIssuer(max_issues=10) + count = await asyncio.wait_for(strategy.execute(issuer), timeout=5.0) + assert count == 10 + # All indices should be 0, 1, or 2 + assert all(0 <= idx <= 2 for idx in issuer.issued_indices) + + +# 
--------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_settings(load_pattern): + """Create minimal RuntimeSettings for factory tests.""" + return RuntimeSettings( + metric_target=Throughput(100), + reported_metrics=[], + min_duration_ms=0, + max_duration_ms=None, + n_samples_from_dataset=10, + n_samples_to_issue=10, + min_sample_count=10, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=load_pattern, + ) diff --git a/tests/unit/metrics/test_report_builder.py b/tests/unit/metrics/test_report_builder.py index 6e4e79e9..c4cb276d 100644 --- a/tests/unit/metrics/test_report_builder.py +++ b/tests/unit/metrics/test_report_builder.py @@ -77,23 +77,26 @@ def _make_store(tmp_path: Path, n_samples: int = 50): store_dir = tmp_path / "kv" w = BasicKVStore(store_dir) + # Counter keys matching MetricCounterKey enum for key in [ - "n_samples_issued", - "n_samples_completed", - "n_samples_failed", - "duration_ns", - "test_started_at", + "total_samples_issued", + "total_samples_completed", + "total_samples_failed", + "tracked_samples_issued", + "tracked_samples_completed", + "tracked_duration_ns", + "total_duration_ns", ]: w.create_key(key, "counter") for key in ["ttft_ns", "sample_latency_ns", "osl", "isl", "chunk_delta_ns"]: w.create_key(key, "series") w.create_key("tpot_ns", "series", dtype=float) - w.update("n_samples_issued", n_samples) - w.update("n_samples_completed", n_samples) - w.update("n_samples_failed", 0) + w.update("tracked_samples_issued", n_samples) + w.update("tracked_samples_completed", n_samples) + w.update("total_samples_failed", 0) if n_samples > 0: - w.update("duration_ns", 10_000_000_000) + w.update("tracked_duration_ns", 10_000_000_000) for i in range(n_samples): w.update("ttft_ns", 1_000_000 + i * 10_000) @@ -102,11 +105,13 @@ def _make_store(tmp_path: Path, n_samples: int = 50): r = 
BasicKVStoreReader(store_dir) for key in [ - "n_samples_issued", - "n_samples_completed", - "n_samples_failed", - "duration_ns", - "test_started_at", + "total_samples_issued", + "total_samples_completed", + "total_samples_failed", + "tracked_samples_issued", + "tracked_samples_completed", + "tracked_duration_ns", + "total_duration_ns", ]: r.register_key(key, "counter") for key in ["ttft_ns", "sample_latency_ns", "osl", "isl", "chunk_delta_ns"]: diff --git a/tests/unit/test_core_types.py b/tests/unit/test_core_types.py index f1a574bf..99f1c169 100644 --- a/tests/unit/test_core_types.py +++ b/tests/unit/test_core_types.py @@ -97,11 +97,10 @@ class TestStreamChunk: def test_stream_chunk_creation(self) -> None: """Test creating a stream chunk.""" - chunk = StreamChunk(id="test-123", response_chunk="partial", is_complete=False) + chunk = StreamChunk(id="test-123", response_chunk="partial") assert chunk.id == "test-123" assert chunk.response_chunk == "partial" - assert chunk.is_complete is False assert chunk.metadata == {} diff --git a/tests/unit/transport/test_zmq_pool_transport.py b/tests/unit/transport/test_zmq_pool_transport.py new file mode 100644 index 00000000..db816f04 --- /dev/null +++ b/tests/unit/transport/test_zmq_pool_transport.py @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for ZmqWorkerPoolTransport and ReadyCheckReceiver. + +Includes regression test for the 'Socket operation on non-socket' bug where +ReadyCheckReceiver.wait() closed its socket on TimeoutError, breaking the +retry loop in WorkerManager._wait_for_workers_with_liveness_check(). +""" + +import asyncio +import uuid + +import pytest +import zmq +from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext +from inference_endpoint.async_utils.transport.zmq.pubsub import ( + ZmqEventRecordPublisher, +) +from inference_endpoint.async_utils.transport.zmq.ready_check import ( + ReadyCheckReceiver, +) +from inference_endpoint.async_utils.transport.zmq.transport import ( + ZMQTransportConfig, + ZmqWorkerPoolTransport, +) + + +@pytest.fixture(autouse=True) +def reset_zmq_singleton(): + """Ensure each test gets a fresh ManagedZMQContext singleton.""" + yield + instance = ManagedZMQContext._instance + if instance is not None and getattr(instance, "_initialized", False): + instance.cleanup() + ManagedZMQContext._instance = None + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestReadyCheckReceiverTimeout: + """Regression: ReadyCheckReceiver must survive timeout for retry.""" + + async def test_socket_survives_timeout(self): + """After wait() times out, the socket must still be usable for retry. + + This is the core regression test for the ENOTSOCK bug. The old code + had `except BaseException: self.close()` which closed the socket on + TimeoutError. The caller (_wait_for_workers_with_liveness_check) + catches TimeoutError and retries, hitting a dead socket. 
+ """ + zmq_ctx = ManagedZMQContext(io_threads=1) + dummy = zmq_ctx.socket(zmq.PUB) + zmq_ctx.bind(dummy, "dummy_pub") + + receiver = ReadyCheckReceiver("ready_test", zmq_ctx, count=1) + + # First wait should timeout (no signals sent) + with pytest.raises(TimeoutError): + await receiver.wait(timeout=0.05) + + # Socket must still be usable after timeout + assert not receiver._sock.closed, ( + "ReadyCheckReceiver closed its socket on TimeoutError — " + "this breaks the retry loop in _wait_for_workers_with_liveness_check" + ) + _ = receiver._sock.rcvtimeo # Would raise ENOTSOCK if socket is dead + + # Second wait should also timeout cleanly (not ENOTSOCK) + with pytest.raises(TimeoutError): + await receiver.wait(timeout=0.05) + + receiver.close() + dummy.close() + zmq_ctx.cleanup() + + async def test_socket_closed_on_cancellation(self): + """Socket SHOULD be closed on non-timeout exceptions (e.g. cancel).""" + zmq_ctx = ManagedZMQContext(io_threads=1) + dummy = zmq_ctx.socket(zmq.PUB) + zmq_ctx.bind(dummy, "dummy_pub") + + receiver = ReadyCheckReceiver("ready_test", zmq_ctx, count=1) + + task = asyncio.create_task(receiver.wait(timeout=10.0)) + await asyncio.sleep(0.05) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + assert receiver._sock.closed + + dummy.close() + zmq_ctx.cleanup() + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestZmqPoolTransportWithPublisher: + """Test pool transport creation with a publisher on the same context.""" + + async def _create_publisher_and_pool( + self, loop: asyncio.AbstractEventLoop, num_workers: int + ): + """Helper: create publisher + pool transport, test ready check socket.""" + sid = uuid.uuid4().hex[:8] + zmq_ctx = ManagedZMQContext(io_threads=2) + publisher = ZmqEventRecordPublisher(f"ev_pub_{sid}", zmq_ctx, loop=loop) + + pool = ZmqWorkerPoolTransport.create( + loop, num_workers, config=ZMQTransportConfig() + ) + + rc = pool._ready_check + assert not rc._sock.closed + _ = rc._sock.rcvtimeo 
+ + with pytest.raises(TimeoutError): + await pool.wait_for_workers_ready(timeout=0.1) + + pool.cleanup() + publisher.close() + zmq_ctx.cleanup() + + async def test_2_workers(self): + loop = asyncio.get_running_loop() + await self._create_publisher_and_pool(loop, 2) + + async def test_3_workers(self): + loop = asyncio.get_running_loop() + await self._create_publisher_and_pool(loop, 3) + + async def test_4_workers(self): + loop = asyncio.get_running_loop() + await self._create_publisher_and_pool(loop, 4) + + async def test_8_workers(self): + loop = asyncio.get_running_loop() + await self._create_publisher_and_pool(loop, 8) + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestZmqPoolTransportWithoutPublisher: + """Test pool transport creation without a publisher (baseline).""" + + @pytest.mark.parametrize("num_workers", [2, 3, 4, 8]) + async def test_pool_only(self, num_workers: int): + loop = asyncio.get_running_loop() + zmq_ctx = ManagedZMQContext(io_threads=2) + dummy = zmq_ctx.socket(zmq.PUB) + zmq_ctx.bind(dummy, "dummy") + + pool = ZmqWorkerPoolTransport.create( + loop, num_workers, config=ZMQTransportConfig() + ) + + rc = pool._ready_check + assert not rc._sock.closed + _ = rc._sock.rcvtimeo + + with pytest.raises(TimeoutError): + await pool.wait_for_workers_ready(timeout=0.1) + + pool.cleanup() + dummy.close() + zmq_ctx.cleanup()