diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index da5faf84..747742d6 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -92,22 +92,38 @@ def _run_test( data=get_version_info(), ) + # Note: stop_requested is checked after each iteration, so one + # additional sample may be issued after the flag is set. This is + # acceptable — the alternative (checking before next()) would + # require breaking the generator protocol. for _ in perf_test_generator: - # Actual issue is done during next(generator). Nothing else to do here, just pass. - pass + if self.stop_requested: + self.logger.info( + "Early stop requested, aborting sample issuance" + ) + break EventRecorder.record_event( SessionEvent.STOP_PERFORMANCE_TRACKING, time.monotonic_ns() ) - self.logger.info("All performance samples issued") + if self.stop_requested: + self.logger.info("Performance sample issuance aborted early") + else: + self.logger.info("All performance samples issued") - if accuracy_test_generators: + if accuracy_test_generators and not self.stop_requested: for _, generator in accuracy_test_generators.items(): for _ in generator: - # Actual issue is done during next(generator). Nothing else to do here, just pass. 
- pass + if self.stop_requested: + break + if self.stop_requested: + break - self.logger.info("All accuracy samples issued") + if accuracy_test_generators: + if self.stop_requested: + self.logger.info("Accuracy sample issuance aborted early") + else: + self.logger.info("All accuracy samples issued") self.event_recorder.should_check_idle = True EventRecorder.record_event( diff --git a/src/inference_endpoint/metrics/reporter.py b/src/inference_endpoint/metrics/reporter.py index d6273496..fda8b594 100644 --- a/src/inference_endpoint/metrics/reporter.py +++ b/src/inference_endpoint/metrics/reporter.py @@ -25,6 +25,7 @@ import sqlite3 from collections import defaultdict from collections.abc import Callable, Iterable +from concurrent.futures import ThreadPoolExecutor from enum import Enum from pathlib import Path from typing import TYPE_CHECKING, Any @@ -40,6 +41,43 @@ from transformers import Tokenizer +def _parallel_batch_tokenize(tokenizer: Tokenizer, texts: list[str]) -> list[int]: + """Batch-tokenize texts using all available cores and return token counts. + + Uses a ThreadPoolExecutor to parallelize across ~95% of CPU cores. + HuggingFace tokenizers use a Rust backend that releases the GIL, + so threads achieve real parallelism without GIL contention. + A single tokenizer instance is shared across threads — this is safe for + PreTrainedTokenizerFast (Rust-backed, thread-safe by design). 
+ """ + + try: + n_cores = len(os.sched_getaffinity(0)) + except (AttributeError, NotImplementedError): + n_cores = os.cpu_count() or 1 + n_workers = max(1, int(n_cores * 0.95)) + + is_fast = getattr(tokenizer, "is_fast", False) + if n_workers <= 1 or len(texts) <= n_workers or not is_fast: + # Few texts, single core, or non-Fast tokenizer — tokenize directly + encoded = tokenizer(texts, add_special_tokens=False) + return [len(ids) for ids in encoded["input_ids"]] + + # Split texts into chunks, one per worker + chunk_size = (len(texts) + n_workers - 1) // n_workers + chunks = [texts[i : i + chunk_size] for i in range(0, len(texts), chunk_size)] + + def _tokenize_chunk(chunk: list[str]) -> list[int]: + encoded = tokenizer(chunk, add_special_tokens=False) + return [len(ids) for ids in encoded["input_ids"]] + + results: list[int] = [] + with ThreadPoolExecutor(max_workers=min(n_workers, len(chunks))) as pool: + for chunk_lengths in pool.map(_tokenize_chunk, chunks): + results.extend(chunk_lengths) + return results + + class TPOTReportingMode(str, Enum): """TPOT (Time Per Output Token) reporting mode. 
@@ -1036,7 +1074,9 @@ def get_output_sequence_lengths( """ query_result = self.get_sample_outputs() - rows = [] + # Collect all texts for batch tokenization + uuids: list[str] = [] + texts: list[str] = [] for sample_uuid, data_bytes in query_result: output_sequence, reasoning_sequence = output_sequence_from_data(data_bytes) @@ -1049,13 +1089,16 @@ def get_output_sequence_lengths( else: full_sequence = output_sequence - # Tokenize and calculate length - output_tokens = tokenizer.tokenize(full_sequence) - rows.append((sample_uuid, len(output_tokens))) + uuids.append(sample_uuid) + texts.append(full_sequence) - if not rows: + if not texts: return None + # Parallel batch tokenize across ~95% of cores + token_counts = _parallel_batch_tokenize(tokenizer, texts) + rows = list(zip(uuids, token_counts, strict=True)) + return RollupQueryTable("output_sequence_length", None, rows) @profile @@ -1105,11 +1148,9 @@ def derive_TPOT( if not query_result: return None - rows = [] - if condense_table and reporting_mode == TPOTReportingMode.TOKEN_WEIGHTED: - repeats = [] - else: - repeats = None + # Pass 1: Collect all non-first-chunk texts for batch tokenization + batch_uuids: list[str] = [] + batch_texts: list[str] = [] for sample_uuid, data_bytes in query_result: if data_bytes is None or len(data_bytes) == 0: @@ -1157,9 +1198,25 @@ def derive_TPOT( # Possible malformed output data where empty string is included as a non-first chunk continue - non_first_tokens = tokenizer.tokenize(non_first_chunk) - n_non_first_tokens = len(non_first_tokens) + batch_uuids.append(sample_uuid) + batch_texts.append(non_first_chunk) + + if not batch_texts: + return None + + # Parallel batch tokenize across ~95% of cores + token_counts = _parallel_batch_tokenize(tokenizer, batch_texts) + + # Pass 2: Compute TPOT using batch-tokenized results + rows = [] + if condense_table and reporting_mode == TPOTReportingMode.TOKEN_WEIGHTED: + repeats = [] + else: + repeats = None + for sample_uuid, n_non_first_tokens 
in zip( + batch_uuids, token_counts, strict=True + ): latency = sample_latency_rollup.filter_uuid(sample_uuid, only_first=True) if latency is None: raise SampleUUIDNotFoundError(sample_uuid, "events record") diff --git a/tests/conftest.py b/tests/conftest.py index 0f5bcfbc..9333ea42 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -308,6 +308,11 @@ class CharacterTokenizer: def tokenize(self, text: str) -> list[str]: return list(text) + def __call__( + self, texts: list[str], **kwargs: object + ) -> dict[str, list[list[int]]]: + return {"input_ids": [list(range(len(t))) for t in texts]} + @pytest.fixture def tokenizer(): diff --git a/tests/performance/test_reporter.py b/tests/performance/test_reporter.py index 29dcebbd..c4fbd066 100644 --- a/tests/performance/test_reporter.py +++ b/tests/performance/test_reporter.py @@ -33,6 +33,11 @@ class CharTokenizer: def tokenize(self, text: str) -> list[str]: return list(text) + def __call__( + self, texts: list[str], **kwargs: object + ) -> dict[str, list[list[int]]]: + return {"input_ids": [list(range(len(t))) for t in texts]} + def time_fn(fn, *args, **kwargs): start_time = time.monotonic_ns() diff --git a/tests/unit/metrics/test_reporter.py b/tests/unit/metrics/test_reporter.py index 92416de8..e04d6f03 100644 --- a/tests/unit/metrics/test_reporter.py +++ b/tests/unit/metrics/test_reporter.py @@ -15,6 +15,7 @@ import json import math +import os import msgspec.json import pytest @@ -25,6 +26,7 @@ MetricsReporter, RollupQueryTable, TPOTReportingMode, + _parallel_batch_tokenize, output_sequence_from_data, ) @@ -1178,3 +1180,21 @@ def test_empty_data(self): output, reasoning = output_sequence_from_data(b"") assert output is None assert reasoning is None + + +@pytest.mark.unit +def test_parallel_batch_tokenize_threaded_path(tokenizer, monkeypatch): + """Exercise the threaded branch of _parallel_batch_tokenize. 
+
+    Monkeypatches os.sched_getaffinity to return 4 CPUs (so n_workers=3) and
+    marks the tokenizer as Fast so the threaded chunking path is taken.
+    """
+    monkeypatch.setattr(
+        os, "sched_getaffinity", lambda _pid: {0, 1, 2, 3}, raising=False
+    )
+    # CharacterTokenizer has no is_fast attribute; without this the
+    # `not is_fast` guard would short-circuit to the direct (serial) path.
+    monkeypatch.setattr(tokenizer, "is_fast", True, raising=False)
+    texts = ["hello", "ab", "xyz", "a", "test!"]
+    result = _parallel_batch_tokenize(tokenizer, texts)
+    assert result == [5, 2, 3, 1, 5]