Skip to content

Commit 1143a57

Browse files
committed
add parallel tokenizer
1 parent 86cbf86 commit 1143a57

File tree

3 files changed

+78
-17
lines changed

3 files changed

+78
-17
lines changed

src/inference_endpoint/load_generator/session.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,19 +93,22 @@ def _run_test(
9393
)
9494

9595
for _ in perf_test_generator:
96-
# Actual issue is done during next(generator). Nothing else to do here, just pass.
97-
pass
96+
if self.stop_requested:
97+
self.logger.info(
98+
"Early stop requested, aborting sample issuance"
99+
)
100+
break
98101

99102
EventRecorder.record_event(
100103
SessionEvent.STOP_PERFORMANCE_TRACKING, time.monotonic_ns()
101104
)
102105
self.logger.info("All performance samples issued")
103106

104-
if accuracy_test_generators:
107+
if accuracy_test_generators and not self.stop_requested:
105108
for _, generator in accuracy_test_generators.items():
106109
for _ in generator:
107-
# Actual issue is done during next(generator). Nothing else to do here, just pass.
108-
pass
110+
if self.stop_requested:
111+
break
109112

110113
self.logger.info("All accuracy samples issued")
111114

src/inference_endpoint/metrics/reporter.py

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,40 @@
3939
if TYPE_CHECKING:
4040
from transformers import Tokenizer
4141

42+
logger = logging.getLogger(__name__)
43+
44+
45+
def _parallel_batch_tokenize(tokenizer: Tokenizer, texts: list[str]) -> list[int]:
46+
"""Batch-tokenize texts using all available cores and return token counts.
47+
48+
Uses a ThreadPoolExecutor to parallelize across ~95% of CPU cores.
49+
HuggingFace tokenizers use a Rust backend that releases the GIL,
50+
so threads achieve real parallelism without GIL contention.
51+
"""
52+
from concurrent.futures import ThreadPoolExecutor
53+
54+
n_cores = os.cpu_count() or 1
55+
n_workers = max(1, int(n_cores * 0.95))
56+
57+
if len(texts) <= n_workers:
58+
# Few texts — just tokenize directly, no threading overhead
59+
encoded = tokenizer(texts, add_special_tokens=False)
60+
return [len(ids) for ids in encoded["input_ids"]]
61+
62+
# Split texts into chunks, one per worker
63+
chunk_size = (len(texts) + n_workers - 1) // n_workers
64+
chunks = [texts[i : i + chunk_size] for i in range(0, len(texts), chunk_size)]
65+
66+
def _tokenize_chunk(chunk: list[str]) -> list[int]:
67+
encoded = tokenizer(chunk, add_special_tokens=False)
68+
return [len(ids) for ids in encoded["input_ids"]]
69+
70+
results: list[int] = []
71+
with ThreadPoolExecutor(max_workers=n_workers) as pool:
72+
for chunk_lengths in pool.map(_tokenize_chunk, chunks):
73+
results.extend(chunk_lengths)
74+
return results
75+
4276

4377
class TPOTReportingMode(str, Enum):
4478
"""TPOT (Time Per Output Token) reporting mode.
@@ -1016,7 +1050,9 @@ def get_output_sequence_lengths(
10161050
"""
10171051
query_result = self.get_sample_outputs()
10181052

1019-
rows = []
1053+
# Collect all texts for batch tokenization
1054+
uuids: list[str] = []
1055+
texts: list[str] = []
10201056
for sample_uuid, data_bytes in query_result:
10211057
output_sequence, reasoning_sequence = output_sequence_from_data(data_bytes)
10221058

@@ -1029,13 +1065,16 @@ def get_output_sequence_lengths(
10291065
else:
10301066
full_sequence = output_sequence
10311067

1032-
# Tokenize and calculate length
1033-
output_tokens = tokenizer.tokenize(full_sequence)
1034-
rows.append((sample_uuid, len(output_tokens)))
1068+
uuids.append(sample_uuid)
1069+
texts.append(full_sequence)
10351070

1036-
if not rows:
1071+
if not texts:
10371072
return None
10381073

1074+
# Parallel batch tokenize across ~95% of cores
1075+
token_counts = _parallel_batch_tokenize(tokenizer, texts)
1076+
rows = list(zip(uuids, token_counts, strict=False))
1077+
10391078
return RollupQueryTable("output_sequence_length", None, rows)
10401079

10411080
@profile
@@ -1085,11 +1124,9 @@ def derive_TPOT(
10851124
if not query_result:
10861125
return None
10871126

1088-
rows = []
1089-
if condense_table and reporting_mode == TPOTReportingMode.TOKEN_WEIGHTED:
1090-
repeats = []
1091-
else:
1092-
repeats = None
1127+
# Pass 1: Collect all non-first-chunk texts for batch tokenization
1128+
batch_uuids: list[str] = []
1129+
batch_texts: list[str] = []
10931130

10941131
for sample_uuid, data_bytes in query_result:
10951132
if data_bytes is None or len(data_bytes) == 0:
@@ -1137,9 +1174,25 @@ def derive_TPOT(
11371174
# Possible malformed output data where empty string is included as a non-first chunk
11381175
continue
11391176

1140-
non_first_tokens = tokenizer.tokenize(non_first_chunk)
1141-
n_non_first_tokens = len(non_first_tokens)
1177+
batch_uuids.append(sample_uuid)
1178+
batch_texts.append(non_first_chunk)
1179+
1180+
if not batch_texts:
1181+
return None
1182+
1183+
# Parallel batch tokenize across ~95% of cores
1184+
token_counts = _parallel_batch_tokenize(tokenizer, batch_texts)
1185+
1186+
# Pass 2: Compute TPOT using batch-tokenized results
1187+
rows = []
1188+
if condense_table and reporting_mode == TPOTReportingMode.TOKEN_WEIGHTED:
1189+
repeats = []
1190+
else:
1191+
repeats = None
11421192

1193+
for sample_uuid, n_non_first_tokens in zip(
1194+
batch_uuids, token_counts, strict=False
1195+
):
11431196
latency = sample_latency_rollup.filter_uuid(sample_uuid, only_first=True)
11441197
if latency is None:
11451198
raise SampleUUIDNotFoundError(sample_uuid, "events record")

tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,11 @@ class CharacterTokenizer:
307307
def tokenize(self, text: str) -> list[str]:
    """Split *text* into single-character tokens (one token per character)."""
    return [*text]
309309

310+
def __call__(
    self, texts: list[str], **kwargs: object
) -> dict[str, list[list[int]]]:
    """Mimic a HuggingFace batch-tokenizer call.

    Returns an ``input_ids`` mapping with one pseudo token id per character
    of each text; extra keyword arguments (e.g. ``add_special_tokens``) are
    accepted and ignored.
    """
    input_ids: list[list[int]] = []
    for text in texts:
        input_ids.append(list(range(len(text))))
    return {"input_ids": input_ids}
314+
310315

311316
@pytest.fixture
312317
def tokenizer():

0 commit comments

Comments (0)