Skip to content

Commit d6e5d7b

Browse files
viraatc and claude committed
fix: address PR review comments
- Move ThreadPoolExecutor import to module level (PEP 8)
- Remove unused logger variable
- Use strict=True in zip() calls to catch length mismatches
- Add comment explaining early-stop timing in session loop

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 07972c7 commit d6e5d7b

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

src/inference_endpoint/load_generator/session.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ def _run_test(
9292
data=get_version_info(),
9393
)
9494

95+
# Note: stop_requested is checked after each iteration, so one
96+
# additional sample may be issued after the flag is set. This is
97+
# acceptable — the alternative (checking before next()) would
98+
# require breaking the generator protocol.
9599
for _ in perf_test_generator:
96100
if self.stop_requested:
97101
self.logger.info(

src/inference_endpoint/metrics/reporter.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import sqlite3
2626
from collections import defaultdict
2727
from collections.abc import Callable, Iterable
28+
from concurrent.futures import ThreadPoolExecutor
2829
from enum import Enum
2930
from pathlib import Path
3031
from typing import TYPE_CHECKING, Any
@@ -39,8 +40,6 @@
3940
if TYPE_CHECKING:
4041
from transformers import Tokenizer
4142

42-
logger = logging.getLogger(__name__)
43-
4443

4544
def _parallel_batch_tokenize(tokenizer: Tokenizer, texts: list[str]) -> list[int]:
4645
"""Batch-tokenize texts using all available cores and return token counts.
@@ -49,7 +48,6 @@ def _parallel_batch_tokenize(tokenizer: Tokenizer, texts: list[str]) -> list[int
4948
HuggingFace tokenizers use a Rust backend that releases the GIL,
5049
so threads achieve real parallelism without GIL contention.
5150
"""
52-
from concurrent.futures import ThreadPoolExecutor
5351

5452
n_cores = os.cpu_count() or 1
5553
n_workers = max(1, int(n_cores * 0.95))
@@ -1093,7 +1091,7 @@ def get_output_sequence_lengths(
10931091

10941092
# Parallel batch tokenize across ~95% of cores
10951093
token_counts = _parallel_batch_tokenize(tokenizer, texts)
1096-
rows = list(zip(uuids, token_counts, strict=False))
1094+
rows = list(zip(uuids, token_counts, strict=True))
10971095

10981096
return RollupQueryTable("output_sequence_length", None, rows)
10991097

@@ -1211,7 +1209,7 @@ def derive_TPOT(
12111209
repeats = None
12121210

12131211
for sample_uuid, n_non_first_tokens in zip(
1214-
batch_uuids, token_counts, strict=False
1212+
batch_uuids, token_counts, strict=True
12151213
):
12161214
latency = sample_latency_rollup.filter_uuid(sample_uuid, only_first=True)
12171215
if latency is None:

0 commit comments

Comments (0)