if TYPE_CHECKING:
    # Imported only for static type checking; never executed at runtime.
    # NOTE(review): recent ``transformers`` releases do not export a top-level
    # ``Tokenizer`` name (the usual hint is ``PreTrainedTokenizerBase``) —
    # confirm this import resolves against the pinned version.
    from transformers import Tokenizer

# Module-scoped logger, per the standard one-logger-per-module convention.
logger = logging.getLogger(__name__)
45+ def _parallel_batch_tokenize (tokenizer : Tokenizer , texts : list [str ]) -> list [int ]:
46+ """Batch-tokenize texts using all available cores and return token counts.
47+
48+ Uses a ThreadPoolExecutor to parallelize across ~95% of CPU cores.
49+ HuggingFace tokenizers use a Rust backend that releases the GIL,
50+ so threads achieve real parallelism without GIL contention.
51+ """
52+ from concurrent .futures import ThreadPoolExecutor
53+
54+ n_cores = os .cpu_count () or 1
55+ n_workers = max (1 , int (n_cores * 0.95 ))
56+
57+ if len (texts ) <= n_workers :
58+ # Few texts — just tokenize directly, no threading overhead
59+ encoded = tokenizer (texts , add_special_tokens = False )
60+ return [len (ids ) for ids in encoded ["input_ids" ]]
61+
62+ # Split texts into chunks, one per worker
63+ chunk_size = (len (texts ) + n_workers - 1 ) // n_workers
64+ chunks = [texts [i : i + chunk_size ] for i in range (0 , len (texts ), chunk_size )]
65+
66+ def _tokenize_chunk (chunk : list [str ]) -> list [int ]:
67+ encoded = tokenizer (chunk , add_special_tokens = False )
68+ return [len (ids ) for ids in encoded ["input_ids" ]]
69+
70+ results : list [int ] = []
71+ with ThreadPoolExecutor (max_workers = n_workers ) as pool :
72+ for chunk_lengths in pool .map (_tokenize_chunk , chunks ):
73+ results .extend (chunk_lengths )
74+ return results
75+
4276
4377class TPOTReportingMode (str , Enum ):
4478 """TPOT (Time Per Output Token) reporting mode.
@@ -1016,7 +1050,9 @@ def get_output_sequence_lengths(
10161050 """
10171051 query_result = self .get_sample_outputs ()
10181052
1019- rows = []
1053+ # Collect all texts for batch tokenization
1054+ uuids : list [str ] = []
1055+ texts : list [str ] = []
10201056 for sample_uuid , data_bytes in query_result :
10211057 output_sequence , reasoning_sequence = output_sequence_from_data (data_bytes )
10221058
@@ -1029,13 +1065,16 @@ def get_output_sequence_lengths(
10291065 else :
10301066 full_sequence = output_sequence
10311067
1032- # Tokenize and calculate length
1033- output_tokens = tokenizer .tokenize (full_sequence )
1034- rows .append ((sample_uuid , len (output_tokens )))
1068+ uuids .append (sample_uuid )
1069+ texts .append (full_sequence )
10351070
1036- if not rows :
1071+ if not texts :
10371072 return None
10381073
1074+ # Parallel batch tokenize across ~95% of cores
1075+ token_counts = _parallel_batch_tokenize (tokenizer , texts )
1076+ rows = list (zip (uuids , token_counts , strict = False ))
1077+
10391078 return RollupQueryTable ("output_sequence_length" , None , rows )
10401079
10411080 @profile
@@ -1085,11 +1124,9 @@ def derive_TPOT(
10851124 if not query_result :
10861125 return None
10871126
1088- rows = []
1089- if condense_table and reporting_mode == TPOTReportingMode .TOKEN_WEIGHTED :
1090- repeats = []
1091- else :
1092- repeats = None
1127+ # Pass 1: Collect all non-first-chunk texts for batch tokenization
1128+ batch_uuids : list [str ] = []
1129+ batch_texts : list [str ] = []
10931130
10941131 for sample_uuid , data_bytes in query_result :
10951132 if data_bytes is None or len (data_bytes ) == 0 :
@@ -1137,9 +1174,25 @@ def derive_TPOT(
11371174 # Possible malformed output data where empty string is included as a non-first chunk
11381175 continue
11391176
1140- non_first_tokens = tokenizer .tokenize (non_first_chunk )
1141- n_non_first_tokens = len (non_first_tokens )
1177+ batch_uuids .append (sample_uuid )
1178+ batch_texts .append (non_first_chunk )
1179+
1180+ if not batch_texts :
1181+ return None
1182+
1183+ # Parallel batch tokenize across ~95% of cores
1184+ token_counts = _parallel_batch_tokenize (tokenizer , batch_texts )
1185+
1186+ # Pass 2: Compute TPOT using batch-tokenized results
1187+ rows = []
1188+ if condense_table and reporting_mode == TPOTReportingMode .TOKEN_WEIGHTED :
1189+ repeats = []
1190+ else :
1191+ repeats = None
11421192
1193+ for sample_uuid , n_non_first_tokens in zip (
1194+ batch_uuids , token_counts , strict = False
1195+ ):
11431196 latency = sample_latency_rollup .filter_uuid (sample_uuid , only_first = True )
11441197 if latency is None :
11451198 raise SampleUUIDNotFoundError (sample_uuid , "events record" )
0 commit comments