diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 5835a9e8..8ec75145 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -488,6 +488,10 @@ def run_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> None: type(config).__name__, config.model_dump_json(indent=2, exclude_none=True), ) + from inference_endpoint.async_utils.runner import run_async + + from .execute_async import run_benchmark_async + ctx = setup_benchmark(config, test_mode) - report, collector = run_benchmark_threaded(ctx) + report, collector = run_async(run_benchmark_async(ctx)) finalize_benchmark(ctx, report, collector) diff --git a/src/inference_endpoint/commands/benchmark/execute_async.py b/src/inference_endpoint/commands/benchmark/execute_async.py new file mode 100644 index 00000000..9d6ba4b0 --- /dev/null +++ b/src/inference_endpoint/commands/benchmark/execute_async.py @@ -0,0 +1,500 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Async benchmark runner — single uvloop, no threads in the main process. + +Architecture: + - HTTPEndpointClient on the running loop (no separate loop thread) + - ZmqEventRecordPublisher for non-blocking event publishing + - AsyncEventRecorder in a background process for SQLite writes + - Scheduler.__aiter__() for drift-correcting online timing + - Unified receiver: await recv() wakeup + poll() drain +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import random +import signal +import time +import uuid +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any +from urllib.parse import urljoin + +from tqdm import tqdm + +from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext +from inference_endpoint.async_utils.transport.zmq.pubsub import ( + ZmqEventRecordPublisher, +) +from inference_endpoint.config.runtime_settings import RuntimeSettings +from inference_endpoint.config.schema import ( + APIType, + LoadPatternType, + SystemDefaults, +) +from inference_endpoint.core.record import SampleEventType, SessionEventType +from inference_endpoint.core.types import Query, QueryResult, StreamChunk +from inference_endpoint.dataset_manager.dataset import Dataset +from inference_endpoint.endpoint_client.config import HTTPClientConfig +from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient +from inference_endpoint.exceptions import ExecutionError +from inference_endpoint.load_generator.scheduler import ( + Scheduler, + WithoutReplacementSampleOrder, +) +from inference_endpoint.metrics.async_recorder import AsyncEventRecorder +from inference_endpoint.metrics.async_reporter import AsyncEventReporter +from inference_endpoint.metrics.reporter import MetricsReporter + +from .execute import BenchmarkContext, ResponseCollector + +logger = logging.getLogger(__name__) + +_NO_RECORD = os.environ.get("NO_RECORD", "") + + +# ── Runtime state ──────────────────────────────────────────────────────── + + +@dataclass +class _BenchmarkRuntime: + """Mutable state shared across sender/receiver coroutines.""" + + http_client: HTTPEndpointClient + recorder: AsyncEventReporter + scheduler: Scheduler + collector: ResponseCollector + dataloader: Dataset + uuid_to_index: dict[str, int] = field(default_factory=dict) + rng: random.Random = field(default_factory=random.Random) + send_done: bool = False + send_n: int = 0 + stop_requested: bool = False + + def issue_sample(self, s_idx: int, ds: Dataset, uuid_map: dict[str, int]) -> None: + sample_uuid = self.rng.randbytes(16).hex() + sample_data = ds.load_sample(s_idx) + if not _NO_RECORD: + self.recorder.record_event( + SampleEventType.ISSUED, + time.monotonic_ns(), + sample_uuid=sample_uuid, + ) + self.http_client.issue(Query(id=sample_uuid, data=sample_data)) + uuid_map[sample_uuid] = s_idx + + def handle_response(self, result: QueryResult | StreamChunk) -> None: + ts = time.monotonic_ns() + if isinstance(result, StreamChunk): + ev = ( + SampleEventType.RECV_FIRST + if (result.metadata or {}).get("first_chunk", False) + else SampleEventType.RECV_NON_FIRST + ) + self.recorder.record_event(ev, ts, sample_uuid=result.id) + elif isinstance(result, QueryResult): + if result.error is not None: + logger.error(f"Error in request {result.id}: {result.error}") + self.recorder.record_event( + SampleEventType.COMPLETE, + ts, + sample_uuid=result.id, + data=result.response_output, + ) + self.collector.on_complete_hook(result) + self.scheduler.notify_complete() + + +# ── Sender factories ──────────────────────────────────────────────────── + + +def _make_sender( + rt: _BenchmarkRuntime, + load_pattern: LoadPatternType, +) -> Any: + """Return a sender coroutine matched to the load pattern.""" + + if load_pattern == LoadPatternType.MAX_THROUGHPUT: + + async def _sender() -> None: + sent = 0 + for s_idx, _ in rt.scheduler: + if rt.stop_requested: + break + rt.issue_sample(s_idx, rt.dataloader, rt.uuid_to_index) + sent += 1 + if sent % 1000 == 0: + await asyncio.sleep(0) + rt.send_n = sent + rt.send_done = True + + elif load_pattern == LoadPatternType.CONCURRENCY: + + async def _sender() -> None: + sent = 0 + async for s_idx in rt.scheduler: + if rt.stop_requested: + break + rt.issue_sample(s_idx, rt.dataloader, rt.uuid_to_index) + sent += 1 + if sent % 1000 == 0: + await asyncio.sleep(0) + rt.send_n = sent + rt.send_done = True + + else: + # Poisson: scheduler.__aiter__ handles timing + async def _sender() -> None: + sent = 0 + async for s_idx in rt.scheduler: + if rt.stop_requested: + break + rt.issue_sample(s_idx, rt.dataloader, rt.uuid_to_index) + sent += 1 + rt.send_n = sent + rt.send_done = True + + return _sender + + +# ── Receiver ───────────────────────────────────────────────────────────── + + +async def _receiver(rt: _BenchmarkRuntime) -> None: + """Unified receiver: async recv wakeup + sync poll drain.""" + recv_n = 0 + while True: + result = await rt.http_client.recv() + if result is None: + break + if _NO_RECORD: + recv_n += 1 + rt.scheduler.notify_complete() + while rt.http_client.poll() is not None: + recv_n += 1 + rt.scheduler.notify_complete() + else: + rt.handle_response(result) + while (r := rt.http_client.poll()) is not None: + rt.handle_response(r) + if rt.send_done and ( + (_NO_RECORD and recv_n >= rt.send_n) + or (not _NO_RECORD and rt.recorder.n_inflight_samples <= 0) + or rt.stop_requested + ): + break + + +# ── Accuracy phase ─────────────────────────────────────────────────────── + + +async def _run_accuracy_phase( + rt: _BenchmarkRuntime, + acc_ds: Dataset, + acc_scheduler: Scheduler, +) -> None: + """Run one accuracy dataset: max-throughput sender + receiver.""" + acc_uuid_map: dict[str, int] = {} + done = False + + async def sender() -> None: + nonlocal done + sent = 0 + for s_idx, _ in acc_scheduler: + if rt.stop_requested: + break + rt.issue_sample(s_idx, acc_ds, acc_uuid_map) + sent += 1 + if sent % 1000 == 0: + await asyncio.sleep(0) + done = True + + async def receiver() -> None: + while True: + result = await rt.http_client.recv() + if result is None: + break + rt.handle_response(result) + while (r := rt.http_client.poll()) is not None: + rt.handle_response(r) + if done and rt.recorder.n_inflight_samples <= 0: + break + + await asyncio.gather(sender(), receiver()) + + +# ── Resource setup / teardown ──────────────────────────────────────────── + + +def _build_http_config(ctx: BenchmarkContext) -> HTTPClientConfig: + config = ctx.config + api_type: APIType = config.endpoint_config.api_type + return HTTPClientConfig( + endpoint_urls=[ + urljoin(e, api_type.default_route()) + for e in config.endpoint_config.endpoints + ], + api_type=api_type, + num_workers=config.settings.client.workers, + record_worker_events=config.settings.client.record_worker_events, + event_logs_dir=ctx.report_dir, + log_level=config.settings.client.log_level, + cpu_affinity=ctx.affinity_plan, + warmup_connections=config.settings.client.warmup_connections, + max_connections=config.settings.client.max_connections, + api_key=config.endpoint_config.api_key, + ) + + +def _generate_report( + recorder: AsyncEventReporter, + ctx: BenchmarkContext, +) -> Any: + """Read SQLite and generate metrics report.""" + try: + with MetricsReporter( + recorder.connection_name, client_type="sqlite" + ) as reporter: + report = reporter.create_report(ctx.tokenizer) + report.display(fn=print, summary_only=True) + if ctx.report_dir: + report.to_json(save_to=Path(ctx.report_dir) / "result_summary.json") + with open(Path(ctx.report_dir) / "report.txt", "w") as f: + report.display(fn=f.write, summary_only=False, newline="\n") + reporter.dump_to_json(Path(ctx.report_dir) / "events.jsonl") + logger.info(f"Report saved to: {ctx.report_dir}") + return report + except Exception as e: + logger.warning(f"Report generation failed: {e}") + return None + + +def _cleanup( + loop: asyncio.AbstractEventLoop, + pbar: tqdm | None, + recorder: AsyncEventReporter | None, + writer: AsyncEventRecorder | None, + http_client: HTTPEndpointClient | None, + publisher: ZmqEventRecordPublisher | None, + zmq_ctx: ManagedZMQContext | None, + session_ended: bool, +) -> None: + """Best-effort cleanup — each step guarded individually.""" + try: + loop.remove_signal_handler(signal.SIGINT) + except Exception: + pass + + if pbar: + try: + pbar.close() + except Exception: + pass + + if recorder and not session_ended: + try: + recorder.record_event(SessionEventType.ENDED, time.monotonic_ns()) + except Exception: + pass + + if writer: + try: + writer.stop() + except Exception: + pass + + if http_client: + try: + http_client.shutdown() + except Exception: + pass + + try: + os.sched_setaffinity(0, range(os.cpu_count() or 1)) + except (OSError, AttributeError): + pass + os.environ["TOKENIZERS_PARALLELISM"] = "true" + + if publisher: + try: + publisher.close() + except Exception: + pass + + if zmq_ctx: + try: + zmq_ctx.cleanup() + except Exception: + pass + + +# ── Main entry point ───────────────────────────────────────────────────── + + +async def run_benchmark_async( + ctx: BenchmarkContext, +) -> tuple[Any, ResponseCollector]: + """Execute benchmark on a single uvloop — no threads in the main process.""" + loop = asyncio.get_running_loop() + loop.set_task_factory(asyncio.eager_task_factory) # type: ignore[arg-type] + + config = ctx.config + zmq_ctx: ManagedZMQContext | None = None + http_client: HTTPEndpointClient | None = None + publisher: ZmqEventRecordPublisher | None = None + writer: AsyncEventRecorder | None = None + recorder: AsyncEventReporter | None = None + pbar: tqdm | None = None + collector = ResponseCollector(collect_responses=ctx.collect_responses) + session_ended = False + report = None + + try: + # ── Resources ──────────────────────────────────────────────── + zmq_ctx = ManagedZMQContext(io_threads=4) + + logger.info(f"Connecting: {config.endpoint_config.endpoints}") + http_client = await asyncio.to_thread( + HTTPEndpointClient, + _build_http_config(ctx), + loop=loop, + zmq_context=zmq_ctx, + ) + + session_id = f"cli_benchmark_{uuid.uuid4().hex[:8]}" + pub_socket_name = f"ev_pub_{session_id[:8]}" + publisher = ZmqEventRecordPublisher(pub_socket_name, zmq_ctx, loop=loop) + + writer = AsyncEventRecorder( + session_id, publisher.bind_address, sub_settle_s=0.5, stop_timeout=5.0 + ) + writer.start() + + idle_event = asyncio.Event() + recorder = AsyncEventReporter(publisher, session_id, notify_idle=idle_event) + + pbar = tqdm( + desc=f"{config.model_params.name} (Streaming: {ctx.enable_streaming})", + total=ctx.total_samples, + smoothing=0, + ) + collector = ResponseCollector( + collect_responses=ctx.collect_responses, pbar=pbar + ) + + rt = _BenchmarkRuntime( + http_client=http_client, + recorder=recorder, + scheduler=ctx.scheduler, + collector=collector, + dataloader=ctx.dataloader, + ) + + # ── SIGINT ─────────────────────────────────────────────────── + def on_sigint() -> None: + logger.warning("Interrupt received, stopping benchmark...") + rt.stop_requested = True + rt.send_done = True + + loop.add_signal_handler(signal.SIGINT, on_sigint) + + # ── Performance phase ──────────────────────────────────────── + recorder.record_event(SessionEventType.STARTED, time.monotonic_ns()) + + sender_coro = _make_sender(rt, config.settings.load_pattern.type) + logger.info("Running...") + await asyncio.gather(sender_coro(), _receiver(rt)) + loop.remove_signal_handler(signal.SIGINT) + + recorder.record_event( + SessionEventType.STOP_PERFORMANCE_TRACKING, time.monotonic_ns() + ) + logger.info("All performance samples issued") + + # ── Accuracy phases ────────────────────────────────────────── + if ctx.accuracy_datasets and not rt.stop_requested: + for acc_ds in ctx.accuracy_datasets: + ds_name = getattr( + acc_ds.__class__, "DATASET_ID", acc_ds.__class__.__name__ + ) + acc_rt = RuntimeSettings( + metric_target=ctx.rt_settings.metric_target, + reported_metrics=ctx.rt_settings.reported_metrics, + min_duration_ms=0, + max_duration_ms=None, + n_samples_from_dataset=acc_ds.num_samples(), + n_samples_to_issue=acc_ds.num_samples() * acc_ds.repeats, + min_sample_count=acc_ds.num_samples() * acc_ds.repeats, + rng_sched=ctx.rt_settings.rng_sched, + rng_sample_index=ctx.rt_settings.rng_sample_index, + load_pattern=ctx.rt_settings.load_pattern, + ) + acc_sched = ctx.scheduler.__class__( + acc_rt, WithoutReplacementSampleOrder + ) + + logger.info(f"Running accuracy phase: {ds_name}") + await _run_accuracy_phase(rt, acc_ds, acc_sched) + + logger.info("All accuracy samples issued") + + # ── Drain + finalize ───────────────────────────────────────── + recorder.should_check_idle = True + recorder.record_event(SessionEventType.STOP_LOADGEN, time.monotonic_ns()) + + if recorder.n_inflight_samples > 0 and not rt.stop_requested: + try: + await asyncio.wait_for( + idle_event.wait(), + timeout=config.timeout or SystemDefaults.DEFAULT_TIMEOUT, + ) + except TimeoutError: + logger.warning( + f"Timed out waiting for {recorder.n_inflight_samples} inflight samples" + ) + + recorder.record_event(SessionEventType.ENDED, time.monotonic_ns()) + session_ended = True + writer.stop() + + except KeyboardInterrupt: + logger.warning("Benchmark interrupted by user") + except ExecutionError: + raise + except Exception as e: + logger.error(f"Benchmark failed: {e}") + raise ExecutionError(f"Benchmark execution failed: {e}") from e + finally: + _cleanup( + loop, + pbar, + recorder, + writer, + http_client, + publisher, + zmq_ctx, + session_ended, + ) + + # Reset CPU affinity is in _cleanup; generate report here + if recorder: + report = _generate_report(recorder, ctx) + + return (report, collector) diff --git a/src/inference_endpoint/load_generator/scheduler.py b/src/inference_endpoint/load_generator/scheduler.py index ae691d09..c8854b97 100644 --- a/src/inference_endpoint/load_generator/scheduler.py +++ b/src/inference_endpoint/load_generator/scheduler.py @@ -13,14 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import random -import threading from abc import ABC, abstractmethod -from collections.abc import Callable, Iterator +from collections.abc import AsyncIterator, Callable, Iterator from ..config.runtime_settings import RuntimeSettings from ..config.schema import LoadPatternType -from .sample import SampleEvent, SampleEventHandler class SampleOrder(ABC): @@ -280,7 +279,7 @@ def __init__( rng=self.runtime_settings.rng_sample_index, ) ) - self.delay_fn: Callable[[], int] | None = None # Subclasses must set this + self.delay_fn: Callable[[], float] | None = None # Subclasses must set this def __iter__(self): """Iterate over (sample_index, delay_ns) pairs. @@ -290,8 +289,36 @@ def __iter__(self): - sample_index: Index of sample to issue next - delay_ns: Nanoseconds to wait before issuing """ + delay_fn = self.delay_fn + assert delay_fn is not None, "delay_fn must be set by subclass before iteration" for s_idx in self.sample_order: - yield s_idx, self.delay_fn() + yield s_idx, delay_fn() + + async def __aiter__(self) -> AsyncIterator[int]: + """Async iterate over sample indices with precise timing. + + Accumulates absolute target times to avoid drift. Uses asyncio.sleep + for the wait — proven accurate up to 500k QPS. + + Yields: + Sample index to issue next (timing is handled internally). + """ + loop = asyncio.get_running_loop() + target = loop.time() + delay_fn = self.delay_fn + assert delay_fn is not None, "delay_fn must be set by subclass before iteration" + for s_idx in self.sample_order: + delay_ns = delay_fn() + if delay_ns > 0: + target += delay_ns / 1e9 + remaining = target - loop.time() + if remaining > 0: + await asyncio.sleep(remaining) + yield s_idx + + def notify_complete(self) -> None: + """Notify the scheduler that a sample has completed.""" + pass def __init_subclass__(cls, load_pattern: LoadPatternType | None = None, **kwargs): """Auto-register scheduler implementations. @@ -336,14 +363,22 @@ def get_implementation(cls, load_pattern: LoadPatternType) -> type["Scheduler"]: class MaxThroughputScheduler(Scheduler, load_pattern=LoadPatternType.MAX_THROUGHPUT): - """Offline max throughput scheduler (all queries at t=0). + """Offline max throughput scheduler — zero delay, maximum send rate. + + No delay function, no timing logic. Yields sample indices as fast as + the caller can consume them. The caller (benchmark.py sender) controls + yielding to the event loop (sleep(0) every N sends). Auto-registers for LoadPatternType.MAX_THROUGHPUT. """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.delay_fn = uniform_delay_fn(rng=self.runtime_settings.rng_sched) + def __iter__(self): + for s_idx in self.sample_order: + yield s_idx, 0 + + async def __aiter__(self) -> AsyncIterator[int]: + for s_idx in self.sample_order: + yield s_idx class PoissonDistributionScheduler(Scheduler, load_pattern=LoadPatternType.POISSON): @@ -384,37 +419,23 @@ def __init__(self, runtime_settings: RuntimeSettings, sample_order_cls): f"target_concurrency must be > 0 for CONCURRENCY load pattern, got {target_concurrency}" ) - # Use threading.Condition for concurrency control with explicit counter - self._condition = threading.Condition() - self._inflight = 0 self._target_concurrency = target_concurrency - - # Register completion hook - free up slot when query completes - SampleEventHandler.register_hook(SampleEvent.COMPLETE, self._release_slot) - - # Unused (required by Scheduler interface) - returns 0 delay + self._semaphore: asyncio.Semaphore | None = None self.delay_fn = lambda: 0 - def _release_slot(self, result=None): - """Release a concurrency slot and notify waiting threads. + def notify_complete(self) -> None: + """Release a concurrency slot so the next sample can be issued.""" + if self._semaphore is not None: + self._semaphore.release() - Args: - result: QueryResult from completed query (unused, required by hook signature) - """ - with self._condition: - self._inflight -= 1 - self._condition.notify() + async def __aiter__(self) -> AsyncIterator[int]: + """Async iterate with concurrency control via asyncio.Semaphore. - def __iter__(self): + Sender acquires a slot before yielding each sample. Receiver calls + notify_complete() which releases the slot after each QueryResult. """ - Iterate over sample indices to issue. - Yields sample indices until total_samples_to_issue is reached. + self._semaphore = asyncio.Semaphore(self._target_concurrency) - Waits for available concurrency slot before yielding each sample index. - """ for s_idx in self.sample_order: - with self._condition: - while self._inflight >= self._target_concurrency: - self._condition.wait() - self._inflight += 1 - yield s_idx, 0 + await self._semaphore.acquire() + yield s_idx diff --git a/src/inference_endpoint/metrics/async_recorder.py b/src/inference_endpoint/metrics/async_recorder.py new file mode 100644 index 00000000..8a46669d --- /dev/null +++ b/src/inference_endpoint/metrics/async_recorder.py @@ -0,0 +1,373 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Background process for SQLite event recording via ZMQ PUB/SUB. + +Replaces the EventRecorder's writer thread with a separate OS process. +Uses the existing ZmqEventRecordSubscriber infrastructure from async_utils. + +The main process publishes EventRecords via ZmqEventRecordPublisher (ZMQ PUB, +non-blocking). This module runs a subscriber in a separate OS process that +receives events and batch-writes them to SQLite at /dev/shm. + +The SQLite schema is identical to EventRecorder's, ensuring full compatibility +with MetricsReporter. A mapping table converts pub-sub EventType topics to +the old event_type string values expected by MetricsReporter. +""" + +from __future__ import annotations + +import asyncio +import logging +import multiprocessing +import multiprocessing.synchronize +import sqlite3 +from pathlib import Path +from urllib.parse import urlparse + +import msgspec.json +import uvloop + +from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext +from inference_endpoint.async_utils.transport.zmq.pubsub import ( + ZmqEventRecordSubscriber, +) +from inference_endpoint.core.record import ( + EventRecord, + SessionEventType, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# EventType topic -> old event_type value for SQLite (MetricsReporter compat) +# --------------------------------------------------------------------------- + +_TOPIC_TO_SQLITE_EVENT_TYPE: dict[str, str] = { + # Session events + "session.started": "test_started", + "session.ended": "test_ended", + "session.stop_loadgen": "loadgen_stop", + "session.start_performance_tracking": "start_performance_tracking", + "session.stop_performance_tracking": "stop_performance_tracking", + # Sample events + "sample.issued": "loadgen_issue_called", + "sample.complete": "complete", + "sample.recv_first": "first_chunk_received", + "sample.recv_non_first": "non_first_chunk_received", + "sample.client_send": "http_request_issued", + "sample.client_resp_done": "http_response_completed", + "sample.transport_sent": "zmq_response_sent", + "sample.transport_recv": "zmq_request_received", + # Error events + "error.generic": "error", + "error.loadgen": "error", + "error.session": "error", + "error.client": "error", +} + +# SQLite schema (same as EventRecorder) +_CREATE_TABLE = ( + "CREATE TABLE IF NOT EXISTS events (" + "sample_uuid TEXT, event_type TEXT, timestamp_ns INTEGER, data BLOB)" +) +_INSERT = "INSERT INTO events (sample_uuid, event_type, timestamp_ns, data) VALUES (?, ?, ?, ?)" + + +# --------------------------------------------------------------------------- +# SQLite Writer Subscriber +# --------------------------------------------------------------------------- + + +class _SqliteWriterSubscriber(ZmqEventRecordSubscriber): + """Subscriber that writes EventRecords to SQLite, compatible with MetricsReporter. + + Runs in a background process with its own uvloop. Processes received + EventRecords by converting them to the SQLite format and batch-inserting. + """ + + def __init__( + self, + db_path: str, + txn_buffer_size: int, + done_event: asyncio.Event, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.db_path = db_path + self.txn_buffer_size = txn_buffer_size + self._done_event = done_event + + # Open SQLite connection + self._conn = sqlite3.connect(db_path) + self._cur = self._conn.cursor() + self._cur.execute(_CREATE_TABLE) + self._conn.commit() + + self._sql_buffer: list[tuple[str, str, int, bytes]] = [] + + def _commit_buffer(self) -> None: + if self._sql_buffer: + self._cur.executemany(_INSERT, self._sql_buffer) + self._conn.commit() + self._sql_buffer.clear() + + async def process(self, records: list[EventRecord]) -> None: + """Process received EventRecords -- convert and buffer for SQLite.""" + for record in records: + topic: str = record.event_type.topic # type: ignore[attr-defined] + sqlite_event_type = _TOPIC_TO_SQLITE_EVENT_TYPE.get(topic) or topic + + data_bytes = b"" + if record.data is not None: + data_bytes = msgspec.json.encode(record.data) + + self._sql_buffer.append( + (record.sample_uuid, sqlite_event_type, record.timestamp_ns, data_bytes) + ) + + # Check for session ended -> signal done after flush + if record.event_type == SessionEventType.ENDED: + self._commit_buffer() + self._done_event.set() + return + + if len(self._sql_buffer) >= self.txn_buffer_size: + self._commit_buffer() + + def close(self) -> None: + """Flush remaining events, create indexes for fast reads, and close.""" + if not self.is_closed: + self._commit_buffer() + # Create indexes after all writes -- speeds up MetricsReporter queries + try: + self._cur.execute( + "CREATE INDEX IF NOT EXISTS idx_event_type ON events(event_type)" + ) + self._cur.execute( + "CREATE INDEX IF NOT EXISTS idx_sample_uuid ON events(sample_uuid)" + ) + self._cur.execute( + "CREATE INDEX IF NOT EXISTS idx_type_uuid ON events(event_type, sample_uuid)" + ) + self._conn.commit() + except Exception: + pass # non-critical -- queries still work, just slower + self._cur.close() + self._conn.close() + super().close() + + +# --------------------------------------------------------------------------- +# Background process entry point +# --------------------------------------------------------------------------- + + +def _subscriber_main( + publisher_address: str, + db_path: str, + txn_buffer_size: int, + ready_event: multiprocessing.synchronize.Event | None = None, +) -> None: + """Entry point for the background subscriber process. + + Creates a uvloop, connects a subscriber to the publisher's address, + and writes events to SQLite until SESSION_ENDED is received. + """ + import signal + + signal.signal(signal.SIGINT, signal.SIG_IGN) # parent controls lifecycle + + async def _run(): + loop = asyncio.get_running_loop() + loop.set_task_factory(asyncio.eager_task_factory) + + done_event = asyncio.Event() + + # Reset ManagedZMQContext singleton inherited from parent via fork(). + # The forked singleton holds a stale ZMQ context that doesn't work + # correctly in the child process. + ManagedZMQContext._instance = None + + # Parse the publisher address to extract socket_dir and path. + # Publisher address format for IPC: "ipc:///" + # urlparse puts the full filesystem path in parsed.path for non-registered + # URI schemes like "ipc" (netloc is empty). + parsed = urlparse(publisher_address) + scheme = parsed.scheme or "ipc" + + if scheme == "ipc": + # parsed.path contains the full filesystem path, e.g. "/tmp/zmq_xxx/ev_pub_abc" + full_path = parsed.path + socket_dir = str(Path(full_path).parent) + socket_name = Path(full_path).name + zmq_ctx = ManagedZMQContext(io_threads=1, socket_dir=socket_dir) + else: + # TCP: path is host:port in netloc + socket_name = parsed.netloc or parsed.path + zmq_ctx = ManagedZMQContext(io_threads=1) + + subscriber = _SqliteWriterSubscriber( + db_path=db_path, + txn_buffer_size=txn_buffer_size, + done_event=done_event, + path=socket_name, + zmq_context=zmq_ctx, + loop=loop, + topics=None, # Subscribe to all topics + scheme=scheme, + ) + + # Start receiving + subscriber.start() + + # Signal readiness to the main process + if ready_event is not None: + ready_event.set() + + # Wait for SESSION_ENDED or timeout + try: + await asyncio.wait_for(done_event.wait(), timeout=3600) + except TimeoutError: + logger.warning("Subscriber timed out waiting for SESSION_ENDED") + finally: + subscriber.close() + zmq_ctx.cleanup() + + uvloop.run(_run()) + + +# --------------------------------------------------------------------------- +# Process manager (used by the main process) +# --------------------------------------------------------------------------- + + +class AsyncEventRecorder: + """Manages a background process that subscribes to EventPublisherService + and writes events to SQLite. + + Usage (context manager): + with AsyncEventRecorder(session_id, publisher.bind_address) as writer: + # ... benchmark runs, publisher publishes events ... + pass + + Usage (manual): + writer = AsyncEventRecorder(session_id, publisher.bind_address) + writer.start() + # ... benchmark runs ... + writer.stop() + """ + + def __init__( + self, + session_id: str, + publisher_address: str, + txn_buffer_size: int = 1000, + sub_settle_s: float = 0.3, + start_timeout: float = 10.0, + stop_timeout: float = 10.0, + ): + self.session_id = session_id + self.publisher_address = publisher_address + self.txn_buffer_size = txn_buffer_size + self.sub_settle_s = sub_settle_s + self.start_timeout = start_timeout + self.stop_timeout = stop_timeout + self._process: multiprocessing.Process | None = None + + @property + def db_path(self) -> Path: + return Path(f"/dev/shm/mlperf_testsession_{self.session_id}.db") + + def __enter__(self) -> AsyncEventRecorder: + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + # On error, use a short timeout -- the subscriber may not have received + # SESSION_ENDED, so a long wait would just delay the inevitable kill. + timeout = self.stop_timeout if exc_type is None else 2.0 + self.stop(timeout=timeout) + + def start( + self, timeout: float | None = None, sub_settle_s: float | None = None + ) -> None: + """Start the background subscriber process. + + Blocks until the subscriber signals readiness (connected and subscribed), + then waits for ZMQ PUB/SUB subscription to propagate. + + NOTE: This uses blocking waits (not async) because ZMQ's PUB/SUB handshake + requires the I/O threads to run uninterrupted. Called once during setup, + not on the hot path. + + Args: + timeout: Max seconds to wait for subscriber readiness (default: self.start_timeout). + sub_settle_s: Extra blocking sleep for ZMQ subscription propagation (default: self.sub_settle_s). + + Raises: + TimeoutError: If subscriber doesn't become ready in time. + """ + import time + + timeout = timeout if timeout is not None else self.start_timeout + sub_settle_s = sub_settle_s if sub_settle_s is not None else self.sub_settle_s + + ready_event = multiprocessing.Event() + self._process = multiprocessing.Process( + target=_subscriber_main, + args=( + self.publisher_address, + str(self.db_path), + self.txn_buffer_size, + ready_event, + ), + daemon=True, + name=f"EventWriter-{self.session_id[:8]}", + ) + self._process.start() + + # Blocking wait for subprocess to connect and subscribe + if not ready_event.wait(timeout=timeout): + if self._process.is_alive(): + self._process.kill() + raise TimeoutError( + f"AsyncEventRecorder did not become ready within {timeout}s" + ) + + # Blocking sleep for ZMQ PUB/SUB subscription propagation + time.sleep(sub_settle_s) + + logger.debug( + f"AsyncEventRecorder started (pid={self._process.pid}, " + f"db={self.db_path})" + ) + + def stop(self, timeout: float | None = None) -> None: + """Wait for the background process to finish (it stops on SESSION_ENDED).""" + timeout = timeout if timeout is not None else self.stop_timeout + if self._process is not None and self._process.is_alive(): + self._process.join(timeout=timeout) + if self._process.is_alive(): + logger.warning("AsyncEventRecorder did not stop, killing") + self._process.kill() + self._process.join(timeout=2.0) + logger.debug("AsyncEventRecorder stopped") + + @property + def is_alive(self) -> bool: + return self._process is not None and self._process.is_alive() diff --git a/src/inference_endpoint/metrics/async_reporter.py b/src/inference_endpoint/metrics/async_reporter.py new file mode 100644 index 00000000..ce5664a8 --- /dev/null +++ b/src/inference_endpoint/metrics/async_reporter.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lightweight event reporter for the async benchmark runtime. + +Publishes EventRecords via ZmqEventRecordPublisher (sync ZMQ PUB NOBLOCK). +Tracks inflight samples and signals idle when all samples are complete. +""" + +from __future__ import annotations + +import asyncio +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from inference_endpoint.core.record import ( + EventRecord, + EventType, + SampleEventType, +) +from inference_endpoint.core.types import ErrorData, PromptData, TextModelOutput + +if TYPE_CHECKING: + from inference_endpoint.async_utils.transport.zmq.pubsub import ( + ZmqEventRecordPublisher, + ) + +logger = logging.getLogger(__name__) + + +class AsyncEventReporter: + """Event reporter for the async benchmark runtime. + + Publishes EventRecords individually via ZmqEventRecordPublisher + (sync ZMQ PUB NOBLOCK). Tracks inflight samples and signals idle + when all complete. + + Usage:: + + publisher = ZmqEventRecordPublisher(addr, zmq_ctx, loop=loop) + reporter = AsyncEventReporter(publisher, session_id) + reporter.record_event(SessionEventType.STARTED, time.monotonic_ns()) + reporter.record_event(SampleEventType.ISSUED, ts, sample_uuid=uid) + reporter.record_event(SampleEventType.COMPLETE, ts, sample_uuid=uid, data=output) + reporter.record_event(SessionEventType.ENDED, time.monotonic_ns()) + """ + + __slots__ = ( + "_publisher", + "session_id", + "n_inflight_samples", + "should_check_idle", + "notify_idle", + ) + + def __init__( + self, + publisher: ZmqEventRecordPublisher, + session_id: str, + notify_idle: asyncio.Event | None = None, + ): + from inference_endpoint.async_utils.transport.zmq.pubsub import ( + ZmqEventRecordPublisher as _Pub, + ) + + self._publisher: _Pub = publisher + self.session_id = session_id + self.n_inflight_samples: int = 0 + self.should_check_idle: bool = False + self.notify_idle = notify_idle + + @property + def connection_name(self) -> Path: + """SQLite database path (same convention as EventRecorder).""" + return Path(f"/dev/shm/mlperf_testsession_{self.session_id}.db") + + def record_event( + self, + ev_type: EventType, + timestamp_ns: int, + sample_uuid: str = "", + data: Any = None, + ) -> None: + """Record an event by publishing it via ZMQ. + + Args: + ev_type: EventType (SessionEventType, SampleEventType, ErrorEventType). + timestamp_ns: Monotonic nanosecond timestamp. + sample_uuid: UUID of the sample (empty for session-level events). + data: Optional event data. TextModelOutput, PromptData, and ErrorData + are passed through; other types are dropped (data=None). + """ + if ev_type == SampleEventType.ISSUED: + self.n_inflight_samples += 1 + elif ev_type == SampleEventType.COMPLETE: + self.n_inflight_samples -= 1 + + # EventRecord.data accepts OUTPUT_TYPE | PromptData | ErrorData | None. + # Pass through recognized types; drop anything else. + event_data: TextModelOutput | PromptData | ErrorData | None + if data is None or isinstance(data, TextModelOutput | PromptData | ErrorData): + event_data = data + else: + event_data = None + + record = EventRecord( + event_type=ev_type, + timestamp_ns=timestamp_ns, + sample_uuid=sample_uuid, + data=event_data, + ) + self._publisher.publish(record) + + if ( + self.should_check_idle + and self.notify_idle is not None + and self.n_inflight_samples == 0 + ): + self.notify_idle.set() diff --git a/src/inference_endpoint/utils/benchmark_endpoints.py b/src/inference_endpoint/utils/benchmark_endpoints.py new file mode 100644 index 00000000..21dbbec3 --- /dev/null +++ b/src/inference_endpoint/utils/benchmark_endpoints.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +End-to-end benchmark runner. + +Launches a MaxThroughputServer and runs the full benchmark pipeline +to validate the runtime works correctly. Tests offline, poisson, and +concurrency modes with either the async or sync (threaded) runtime. + +Usage:: + + # Async path (default) — offline only + python -m inference_endpoint.utils.benchmark_endpoints + + # Sync (threaded) path — offline only + python -m inference_endpoint.utils.benchmark_endpoints --sync + + # All three load patterns, both paths + python -m inference_endpoint.utils.benchmark_endpoints --all + python -m inference_endpoint.utils.benchmark_endpoints --all --sync + + # Against an external endpoint + python -m inference_endpoint.utils.benchmark_endpoints --endpoint http://host:8080 + + # Custom parameters + python -m inference_endpoint.utils.benchmark_endpoints --samples 1000 --target-qps 500 -w 8 + + # Streaming mode + python -m inference_endpoint.utils.benchmark_endpoints --stream --all +""" + +from __future__ import annotations + +import argparse +import os +import sys +import tempfile +import time + +os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error") + + +def _build_config( + endpoint_url: str, + *, + mode: str, + model: str, + dataset_path: str, + samples: int, + target_qps: float, + target_concurrency: int, + workers: int, + streaming: bool, + report_dir: str, + timeout: float, +): + """Build a BenchmarkConfig for the given mode.""" + from inference_endpoint.config.schema import ( + ClientSettings, + Dataset, + EndpointConfig, + LoadPattern, + LoadPatternType, + ModelParams, + OfflineBenchmarkConfig, + OnlineBenchmarkConfig, + OnlineSettings, + StreamingMode, + ) + + endpoint_config = EndpointConfig(endpoints=[endpoint_url]) + model_params = ModelParams( + name=model, + streaming=StreamingMode.ON if streaming else StreamingMode.OFF, + ) + client = ClientSettings(workers=workers) + perf_dataset = Dataset( + path=dataset_path, + samples=samples, + parser={"prompt": "text_input"}, + ) + + if mode == "offline": + from inference_endpoint.config.schema import OfflineSettings + + return OfflineBenchmarkConfig( + endpoint_config=endpoint_config, + model_params=model_params, + datasets=[perf_dataset], + report_dir=report_dir, + timeout=timeout, + settings=OfflineSettings(client=client), + ) + + if mode == "poisson": + load_pattern = LoadPattern(type=LoadPatternType.POISSON, target_qps=target_qps) + elif mode == "concurrency": + load_pattern = LoadPattern( + type=LoadPatternType.CONCURRENCY, + target_concurrency=target_concurrency, + ) + else: + raise ValueError(f"Unknown mode: {mode}") + + return OnlineBenchmarkConfig( + endpoint_config=endpoint_config, + model_params=model_params, + datasets=[perf_dataset], + report_dir=report_dir, + timeout=timeout, + settings=OnlineSettings(load_pattern=load_pattern, client=client), + ) + + +def _run_async(config, test_mode) -> None: + """Run benchmark via the async runtime (single uvloop, no threads).""" + from inference_endpoint.async_utils.runner import run_async + from inference_endpoint.commands.benchmark.execute import ( + finalize_benchmark, + setup_benchmark, + ) + from inference_endpoint.commands.benchmark.execute_async import ( + run_benchmark_async, + ) + + ctx = setup_benchmark(config, test_mode) + report, collector = run_async(run_benchmark_async(ctx)) + finalize_benchmark(ctx, report, collector) + + +def _run_sync(config, test_mode) -> None: + """Run benchmark via the old threaded runtime (BenchmarkSession + EventRecorder).""" + from inference_endpoint.commands.benchmark.execute import ( + finalize_benchmark, + run_benchmark_threaded, + setup_benchmark, + ) + + ctx = setup_benchmark(config, test_mode) + report, collector = run_benchmark_threaded(ctx) + finalize_benchmark(ctx, report, collector) + + +def _run_one( + endpoint_url: str, + mode: str, + *, + sync: bool, + model: str, + dataset_path: str, + samples: int, + target_qps: float, + target_concurrency: int, + workers: int, + streaming: bool, + timeout: float, +) -> bool: + """Run a single benchmark mode. Returns True on success.""" + from inference_endpoint.config.schema import TestMode + + runner_name = "sync" if sync else "async" + + with tempfile.TemporaryDirectory( + prefix=f"bench_{mode}_{runner_name}_" + ) as report_dir: + try: + config = _build_config( + endpoint_url, + mode=mode, + model=model, + dataset_path=dataset_path, + samples=samples, + target_qps=target_qps, + target_concurrency=target_concurrency, + workers=workers, + streaming=streaming, + report_dir=report_dir, + timeout=timeout, + ) + if sync: + _run_sync(config, TestMode.PERF) + else: + _run_async(config, TestMode.PERF) + return True + except Exception as e: + print(f" [{mode}/{runner_name}] FAIL: {e}", file=sys.stderr) + return False + + +def main(): + parser = argparse.ArgumentParser( + description="E2E benchmark smoke test (async & sync runtimes)" + ) + parser.add_argument( + "--endpoint", + default=None, + help="External endpoint URL. If not set, launches MaxThroughputServer.", + ) + parser.add_argument("--model", default="max-tp") + parser.add_argument( + "--dataset", + default=None, + help="Dataset path. Default: tests/datasets/dummy_1k.pkl", + ) + parser.add_argument("--samples", type=int, default=100) + parser.add_argument("--target-qps", type=float, default=100.0) + parser.add_argument("--target-concurrency", type=int, default=8) + parser.add_argument("-w", "--workers", type=int, default=2) + parser.add_argument("--server-workers", type=int, default=2) + parser.add_argument("--stream", action="store_true") + parser.add_argument("--timeout", type=float, default=60.0) + parser.add_argument( + "--all", + action="store_true", + help="Run all three modes (offline, poisson, concurrency)", + ) + parser.add_argument( + "--sync", + action="store_true", + help="Use the old threaded runtime (BenchmarkSession) instead of async", + ) + args = parser.parse_args() + + dataset_path = args.dataset + if dataset_path is None: + default_path = "tests/datasets/dummy_1k.pkl" + if os.path.exists(default_path): + dataset_path = default_path + else: + print(f"Default dataset not found: {default_path}", file=sys.stderr) + sys.exit(1) + + modes = ["offline", "poisson", "concurrency"] if args.all else ["offline"] + runner_name = "sync" if args.sync else "async" + + server = None + if args.endpoint: + endpoint_url = args.endpoint + print(f"Using external endpoint: {endpoint_url}") + else: + from inference_endpoint.testing.max_throughput_server import ( + MaxThroughputServer, + ) + + server = MaxThroughputServer( + port=0, + num_workers=args.server_workers, + stream=args.stream, + quiet=True, + ) + server.start() + endpoint_url = f"{server.url}/v1/chat/completions" + print(f"MaxThroughputServer @ {server.url}") + + print(f"Runtime: {runner_name}") + + results: dict[str, bool] = {} + try: + for mode in modes: + label = f"{mode}/{runner_name}" + t0 = time.monotonic() + ok = _run_one( + endpoint_url, + mode, + sync=args.sync, + model=args.model, + dataset_path=dataset_path, + samples=args.samples, + target_qps=args.target_qps, + target_concurrency=args.target_concurrency, + workers=args.workers, + streaming=args.stream, + timeout=args.timeout, + ) + elapsed = time.monotonic() - t0 + status = "PASS" if ok else "FAIL" + results[label] = ok + print(f" {label}: {status} ({elapsed:.1f}s)") + finally: + if server: + server.stop() + + # Summary + print("\n--- Results ---") + all_ok = True + for label, ok in results.items(): + status = "PASS" if ok else "FAIL" + print(f" {label}: {status}") + if not ok: + all_ok = False + + sys.exit(0 if all_ok else 1) + + +if __name__ == "__main__": + main() diff --git a/tests/unit/load_generator/test_scheduler.py b/tests/unit/load_generator/test_scheduler.py index 05de0600..c798f7a7 100644 --- a/tests/unit/load_generator/test_scheduler.py +++ b/tests/unit/load_generator/test_scheduler.py @@ -13,12 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import math import random -import threading import pytest -from inference_endpoint.load_generator.sample import SampleEventHandler from inference_endpoint.load_generator.scheduler import ( ConcurrencyScheduler, MaxThroughputScheduler, @@ -86,107 +85,48 @@ def test_max_throughput_scheduler(max_throughput_runtime_settings): ], "Order does not match expected deterministic order" +@pytest.mark.unit +@pytest.mark.asyncio @pytest.mark.parametrize("target_concurrency", [1, 2, 100, 1000], indirect=True) -def test_concurrency_scheduler(concurrency_runtime_settings, target_concurrency): - """Test ConcurrencyScheduler properly gates issuance by completions.""" +async def test_concurrency_scheduler(concurrency_runtime_settings, target_concurrency): + """Test ConcurrencyScheduler properly gates issuance by completions (async).""" total_samples = concurrency_runtime_settings.n_samples_to_issue scheduler = ConcurrencyScheduler( concurrency_runtime_settings, WithReplacementSampleOrder ) - # State tracking - state_lock = threading.RLock() issued_count = 0 - completed_count = 0 - current_inflight = 0 max_inflight = 0 + current_inflight = 0 - # Synchronization: signal when queries can complete and when they're done - can_complete = [threading.Event() for _ in range(total_samples)] - completed = [threading.Event() for _ in range(total_samples)] - # Signal when each query is issued - issued = [threading.Event() for _ in range(total_samples)] - - def completion_worker(): - """Waits for signals to complete queries.""" - nonlocal completed_count, current_inflight - - for position in range(total_samples): - can_complete[position].wait() - - with state_lock: - completed_count += 1 - current_inflight -= 1 - assert current_inflight >= 0, "Inflight count went negative" - - scheduler._release_slot() - completed[position].set() - - threading.Thread(target=completion_worker, daemon=True).start() - - def issue_worker(): - """Issues queries through scheduler.""" + async def sender(): nonlocal issued_count, current_inflight, max_inflight + async for _s_idx in scheduler: + issued_count += 1 + current_inflight += 1 + max_inflight = max(max_inflight, current_inflight) + assert ( + current_inflight <= target_concurrency + ), f"Concurrency {current_inflight} exceeded limit {target_concurrency}" + + async def completer(): + """Simulates completions by calling notify_complete after yielding.""" + nonlocal current_inflight + completed = 0 + while completed < total_samples: + if current_inflight > 0: + current_inflight -= 1 + scheduler.notify_complete() + completed += 1 + else: + await asyncio.sleep(0) - for position, _ in enumerate(scheduler): - with state_lock: - issued_count += 1 - current_inflight += 1 - max_inflight = max(max_inflight, current_inflight) - assert ( - current_inflight <= target_concurrency - ), f"Concurrency {current_inflight} exceeded limit {target_concurrency}" - issued[position].set() - - issue_thread = threading.Thread(target=issue_worker, daemon=True) - issue_thread.start() - - try: - # Phase 1: First target_concurrency queries issue immediately - for position in range(target_concurrency): - issued[position].wait() - - with state_lock: - assert issued_count == target_concurrency - assert completed_count == 0 - assert current_inflight == target_concurrency - - # Phase 2: Verify scheduler blocks when at capacity, unblocks on completion - for position in range(target_concurrency, total_samples): - position_to_complete = position - target_concurrency - - # Verify next query hasn't issued yet (scheduler is blocking) - assert not issued[ - position - ].is_set(), f"Query {position} issued before slot was freed" - - # Free a slot - can_complete[position_to_complete].set() - completed[position_to_complete].wait() - - # Verify next query now issues - issued[position].wait() - - with state_lock: - assert current_inflight == target_concurrency - - # Phase 3: Complete remaining queries and cleanup - for position in range(target_concurrency, total_samples): - can_complete[position].set() - completed[position].wait() - - issue_thread.join() - - # Final validation - with state_lock: - assert issued_count == total_samples - assert completed_count == total_samples - assert current_inflight == 0 - assert max_inflight == target_concurrency + await asyncio.gather(sender(), completer()) - finally: - SampleEventHandler.clear_hooks() + assert issued_count == total_samples + assert current_inflight == 0 + assert max_inflight == target_concurrency @pytest.mark.parametrize("target_qps", [50.0, 100.0, 500.0, 1000.0], indirect=True)