Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions src/inference_endpoint/load_generator/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,22 +92,38 @@ def _run_test(
data=get_version_info(),
)

# Note: stop_requested is checked after each iteration, so one
# additional sample may be issued after the flag is set. This is
# acceptable — the alternative (checking before next()) would
# require breaking the generator protocol.
for _ in perf_test_generator:
# Actual issue is done during next(generator). Nothing else to do here, just pass.
pass
if self.stop_requested:
self.logger.info(
"Early stop requested, aborting sample issuance"
)
break

EventRecorder.record_event(
SessionEvent.STOP_PERFORMANCE_TRACKING, time.monotonic_ns()
)
self.logger.info("All performance samples issued")
if self.stop_requested:
self.logger.info("Performance sample issuance aborted early")
else:
self.logger.info("All performance samples issued")

if accuracy_test_generators:
if accuracy_test_generators and not self.stop_requested:
for _, generator in accuracy_test_generators.items():
for _ in generator:
# Actual issue is done during next(generator). Nothing else to do here, just pass.
pass
if self.stop_requested:
break
if self.stop_requested:
break

self.logger.info("All accuracy samples issued")
if accuracy_test_generators:
if self.stop_requested:
self.logger.info("Accuracy sample issuance aborted early")
else:
self.logger.info("All accuracy samples issued")

self.event_recorder.should_check_idle = True
EventRecorder.record_event(
Expand Down
81 changes: 69 additions & 12 deletions src/inference_endpoint/metrics/reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import sqlite3
from collections import defaultdict
from collections.abc import Callable, Iterable
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any
Expand All @@ -40,6 +41,43 @@
from transformers import Tokenizer


def _parallel_batch_tokenize(tokenizer: Tokenizer, texts: list[str]) -> list[int]:
"""Batch-tokenize texts using all available cores and return token counts.

Uses a ThreadPoolExecutor to parallelize across ~95% of CPU cores.
HuggingFace tokenizers use a Rust backend that releases the GIL,
so threads achieve real parallelism without GIL contention.
A single tokenizer instance is shared across threads — this is safe for
PreTrainedTokenizerFast (Rust-backed, thread-safe by design).
"""

try:
n_cores = len(os.sched_getaffinity(0))
except (AttributeError, NotImplementedError):
n_cores = os.cpu_count() or 1
n_workers = max(1, int(n_cores * 0.95))

is_fast = getattr(tokenizer, "is_fast", False)
if n_workers <= 1 or len(texts) <= n_workers or not is_fast:
# Few texts, single core, or non-Fast tokenizer — tokenize directly
encoded = tokenizer(texts, add_special_tokens=False)
return [len(ids) for ids in encoded["input_ids"]]

# Split texts into chunks, one per worker
chunk_size = (len(texts) + n_workers - 1) // n_workers
chunks = [texts[i : i + chunk_size] for i in range(0, len(texts), chunk_size)]

def _tokenize_chunk(chunk: list[str]) -> list[int]:
encoded = tokenizer(chunk, add_special_tokens=False)
return [len(ids) for ids in encoded["input_ids"]]

results: list[int] = []
with ThreadPoolExecutor(max_workers=min(n_workers, len(chunks))) as pool:
for chunk_lengths in pool.map(_tokenize_chunk, chunks):
results.extend(chunk_lengths)
return results


class TPOTReportingMode(str, Enum):
"""TPOT (Time Per Output Token) reporting mode.

Expand Down Expand Up @@ -1036,7 +1074,9 @@ def get_output_sequence_lengths(
"""
query_result = self.get_sample_outputs()

rows = []
# Collect all texts for batch tokenization
uuids: list[str] = []
texts: list[str] = []
for sample_uuid, data_bytes in query_result:
output_sequence, reasoning_sequence = output_sequence_from_data(data_bytes)

Expand All @@ -1049,13 +1089,16 @@ def get_output_sequence_lengths(
else:
full_sequence = output_sequence

# Tokenize and calculate length
output_tokens = tokenizer.tokenize(full_sequence)
rows.append((sample_uuid, len(output_tokens)))
uuids.append(sample_uuid)
texts.append(full_sequence)

if not rows:
if not texts:
return None

# Parallel batch tokenize across ~95% of cores
token_counts = _parallel_batch_tokenize(tokenizer, texts)
rows = list(zip(uuids, token_counts, strict=True))

return RollupQueryTable("output_sequence_length", None, rows)

@profile
Expand Down Expand Up @@ -1105,11 +1148,9 @@ def derive_TPOT(
if not query_result:
return None

rows = []
if condense_table and reporting_mode == TPOTReportingMode.TOKEN_WEIGHTED:
repeats = []
else:
repeats = None
# Pass 1: Collect all non-first-chunk texts for batch tokenization
batch_uuids: list[str] = []
batch_texts: list[str] = []

for sample_uuid, data_bytes in query_result:
if data_bytes is None or len(data_bytes) == 0:
Expand Down Expand Up @@ -1157,9 +1198,25 @@ def derive_TPOT(
# Possible malformed output data where empty string is included as a non-first chunk
continue

non_first_tokens = tokenizer.tokenize(non_first_chunk)
n_non_first_tokens = len(non_first_tokens)
batch_uuids.append(sample_uuid)
batch_texts.append(non_first_chunk)

if not batch_texts:
return None

# Parallel batch tokenize across ~95% of cores
token_counts = _parallel_batch_tokenize(tokenizer, batch_texts)

# Pass 2: Compute TPOT using batch-tokenized results
rows = []
if condense_table and reporting_mode == TPOTReportingMode.TOKEN_WEIGHTED:
repeats = []
else:
repeats = None

for sample_uuid, n_non_first_tokens in zip(
batch_uuids, token_counts, strict=True
):
latency = sample_latency_rollup.filter_uuid(sample_uuid, only_first=True)
if latency is None:
raise SampleUUIDNotFoundError(sample_uuid, "events record")
Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,11 @@ class CharacterTokenizer:
def tokenize(self, text: str) -> list[str]:
    """Return one "token" per character of *text*."""
    return [char for char in text]

def __call__(
    self, texts: list[str], **kwargs: object
) -> dict[str, list[list[int]]]:
    """Mimic the HF batch-tokenizer interface: id ``i`` for character ``i``."""
    input_ids: list[list[int]] = []
    for text in texts:
        input_ids.append(list(range(len(text))))
    return {"input_ids": input_ids}


@pytest.fixture
def tokenizer():
Expand Down
5 changes: 5 additions & 0 deletions tests/performance/test_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ class CharTokenizer:
def tokenize(self, text: str) -> list[str]:
    """Split *text* into single-character tokens."""
    tokens: list[str] = []
    tokens.extend(text)
    return tokens

def __call__(
    self, texts: list[str], **kwargs: object
) -> dict[str, list[list[int]]]:
    """Batch interface matching HF tokenizers: ``len(t)`` sequential ids per text."""
    return {"input_ids": [[i for i in range(len(t))] for t in texts]}


def time_fn(fn, *args, **kwargs):
start_time = time.monotonic_ns()
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/metrics/test_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import json
import math
import os

import msgspec.json
import pytest
Expand All @@ -25,6 +26,7 @@
MetricsReporter,
RollupQueryTable,
TPOTReportingMode,
_parallel_batch_tokenize,
output_sequence_from_data,
)

Expand Down Expand Up @@ -1178,3 +1180,21 @@ def test_empty_data(self):
output, reasoning = output_sequence_from_data(b"")
assert output is None
assert reasoning is None


@pytest.mark.unit
def test_parallel_batch_tokenize_threaded_path(tokenizer, monkeypatch):
    """Exercise the threaded branch of _parallel_batch_tokenize.

    Monkeypatches os.sched_getaffinity to report 4 CPUs (so n_workers == 3),
    marks the tokenizer as "fast" (the threaded branch is gated on
    ``tokenizer.is_fast``), and provides 5 texts — more than n_workers — so
    the chunked ThreadPoolExecutor path actually runs. Verifies that result
    ordering and counts match the inputs.
    """
    # Force 4 CPUs so n_workers=3, then provide 5 texts to exceed the
    # direct-tokenize threshold and exercise the threaded chunking path.
    monkeypatch.setattr(
        os, "sched_getaffinity", lambda _pid: {0, 1, 2, 3}, raising=False
    )
    # Bug fix: CharacterTokenizer does not define ``is_fast``, so without
    # this the function silently falls back to the direct (non-threaded)
    # path and the branch this test claims to cover never executes.
    monkeypatch.setattr(tokenizer, "is_fast", True, raising=False)
    texts = ["hello", "ab", "xyz", "a", "test!"]
    result = _parallel_batch_tokenize(tokenizer, texts)
    # CharacterTokenizer returns len(text) as token count
    assert result == [5, 2, 3, 1, 5]
Loading