Skip to content

Commit 43a3769

Browse files
committed
Move thread-local storage inside TokenizePool instance. Remove unused .tokenize method - only .token_count is used
1 parent 606a5f4 commit 43a3769

File tree

5 files changed

+139
-125
lines changed

5 files changed

+139
-125
lines changed

src/inference_endpoint/async_utils/services/metrics_aggregator/__main__.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import argparse
1919
import asyncio
20+
from contextlib import AbstractContextManager, nullcontext
2021
from pathlib import Path
2122

2223
from inference_endpoint.async_utils.loop_manager import LoopManager
@@ -63,27 +64,31 @@ async def main() -> None:
6364
shutdown_event = asyncio.Event()
6465
loop = LoopManager().default_loop
6566

66-
pool = None
67+
# Using ternary operator causes errors in MyPy object type coalescing
68+
# (coalesces to 'object' not 'AbstractContextManager[TokenizePool | None]')
6769
if args.tokenizer:
68-
pool = TokenizePool(args.tokenizer, n_workers=args.tokenizer_workers)
70+
pool_cm: AbstractContextManager[TokenizePool | None] = TokenizePool(
71+
args.tokenizer, n_workers=args.tokenizer_workers
72+
)
73+
else:
74+
pool_cm = nullcontext()
6975

70-
try:
71-
with ManagedZMQContext.scoped(socket_dir=args.metrics_dir.parent) as zmq_ctx:
72-
emitter = JsonlMetricEmitter(metrics_file, flush_interval=100)
73-
aggregator = MetricsAggregatorService(
74-
args.socket_address,
75-
zmq_ctx,
76-
loop,
77-
topics=None,
78-
emitter=emitter,
79-
tokenize_pool=pool,
80-
shutdown_event=shutdown_event,
81-
)
82-
loop.call_soon_threadsafe(aggregator.start)
83-
await shutdown_event.wait()
84-
finally:
85-
if pool is not None:
86-
pool.close()
76+
with (
77+
pool_cm as pool,
78+
ManagedZMQContext.scoped(socket_dir=args.metrics_dir.parent) as zmq_ctx,
79+
):
80+
emitter = JsonlMetricEmitter(metrics_file, flush_interval=100)
81+
aggregator = MetricsAggregatorService(
82+
args.socket_address,
83+
zmq_ctx,
84+
loop,
85+
topics=None,
86+
emitter=emitter,
87+
tokenize_pool=pool,
88+
shutdown_event=shutdown_event,
89+
)
90+
loop.call_soon_threadsafe(aggregator.start)
91+
await shutdown_event.wait()
8792

8893

8994
if __name__ == "__main__":

src/inference_endpoint/async_utils/services/metrics_aggregator/emitter.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import time
2121
from abc import ABC, abstractmethod
2222
from pathlib import Path
23+
from typing import TextIO
2324

2425
import msgspec
2526

@@ -58,12 +59,14 @@ class JsonlMetricEmitter(MetricEmitter):
5859

5960
def __init__(self, file_path: Path, flush_interval: int = 100) -> None:
6061
self._file_path = file_path.with_suffix(".jsonl")
61-
self._file = self._file_path.open("w")
62+
self._file: TextIO | None = self._file_path.open("w")
6263
self._encoder = msgspec.json.Encoder()
6364
self._flush_interval = flush_interval
6465
self._n_since_flush = 0
6566

6667
def emit(self, sample_uuid: str, metric_name: str, value: int | float) -> None:
68+
if self._file is None:
69+
return
6770
record = _MetricRecord(
6871
sample_uuid=sample_uuid,
6972
metric_name=metric_name,
@@ -89,4 +92,4 @@ def close(self) -> None:
8992
# File may already be closed or I/O error on close (e.g. disk full).
9093
pass
9194
finally:
92-
self._file = None # type: ignore[assignment]
95+
self._file = None

src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,29 +28,6 @@
2828
from transformers import PreTrainedTokenizerBase
2929

3030

31-
# Create a thread-local storage for the tokenizer so each thread contains its own instance.
32-
_thread_local = threading.local()
33-
34-
35-
def _get_thread_tokenizer(tokenizer_name: str) -> PreTrainedTokenizerBase:
36-
"""Return the tokenizer for the current thread, loading it if needed."""
37-
if not hasattr(_thread_local, "tokenizer") or _thread_local.tokenizer is None:
38-
_thread_local.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
39-
return _thread_local.tokenizer
40-
41-
42-
def _tokenize_worker(tokenizer_name: str, text: str) -> list[str]:
43-
"""Worker entry: load tokenizer for this thread and tokenize."""
44-
tokenizer = _get_thread_tokenizer(tokenizer_name)
45-
return tokenizer.tokenize(text)
46-
47-
48-
def _token_count_worker(tokenizer_name: str, text: str) -> int:
49-
"""Worker entry: return the number of tokens in text."""
50-
tokenizer = _get_thread_tokenizer(tokenizer_name)
51-
return len(tokenizer.encode(text))
52-
53-
5431
class TokenizePool:
5532
"""A pool of worker threads, each with its own HuggingFace AutoTokenizer.
5633
@@ -64,35 +41,56 @@ class TokenizePool:
6441
- The ThreadPoolExecutor itself is thread-safe (submit/shutdown are synchronized).
6542
- Each worker thread has its own tokenizer via thread-local storage, so there
6643
is no shared mutable state during tokenization.
67-
- The blocking `tokenize()` / `token_count()` methods are safe to call from
68-
multiple threads concurrently.
69-
- In an async context, use the `_async` variants to avoid blocking the event loop.
70-
These use `loop.run_in_executor(None, ...)` to offload to the default executor,
71-
which then submits to the TokenizePool's own ThreadPoolExecutor.
44+
- The blocking `token_count()` method is safe to call from multiple threads
45+
concurrently.
46+
- In an async context, use `token_count_async` to avoid blocking the event loop.
7247
"""
7348

7449
def __init__(self, tokenizer_name: str, n_workers: int) -> None:
7550
if n_workers < 1:
7651
raise ValueError("n_workers must be at least 1")
7752
self._tokenizer_name = tokenizer_name
7853
self._n_workers = n_workers
54+
self._thread_local = threading.local()
7955
self._executor: ThreadPoolExecutor | None = ThreadPoolExecutor(
8056
max_workers=n_workers,
8157
thread_name_prefix="TokenizePool",
8258
)
83-
84-
def tokenize(self, text: str) -> list[str]:
85-
"""Tokenize the input string via the worker pool (blocking)."""
86-
if self._executor is None:
87-
raise RuntimeError("TokenizePool is closed")
88-
future = self._executor.submit(_tokenize_worker, self._tokenizer_name, text)
89-
return future.result()
59+
# Pre-load a tokenizer on every worker thread so the first real
60+
# token_count call doesn't pay the AutoTokenizer.from_pretrained cost.
61+
# Submitting n_workers tasks is guaranteed to hit every thread because
62+
# AutoTokenizer.from_pretrained blocks long enough that no thread
63+
# completes before all tasks are submitted.
64+
# **IMPORTANT**: This is not a guarantee - for instance when using a mock
65+
# object in tests for the tokenizer, the mock object *must* block in the 100ms
66+
# range to simulate proper .from_pretrained behavior.
67+
# It is not super impactful if a thread is not pre-initialized - it will just
68+
# have to pay the cost of .from_pretrained on the first pool.token_count call
69+
# for that thread.
70+
futures = [
71+
self._executor.submit(self._get_thread_tokenizer) for _ in range(n_workers)
72+
]
73+
for f in futures:
74+
f.result()
75+
76+
def _get_thread_tokenizer(self) -> PreTrainedTokenizerBase:
77+
"""Return the tokenizer for the current thread, loading it if needed."""
78+
if getattr(self._thread_local, "tokenizer", None) is None:
79+
self._thread_local.tokenizer = AutoTokenizer.from_pretrained(
80+
self._tokenizer_name
81+
)
82+
return self._thread_local.tokenizer
83+
84+
def _token_count_worker(self, text: str) -> int:
85+
"""Worker entry: return the number of tokens in text."""
86+
tokenizer = self._get_thread_tokenizer()
87+
return len(tokenizer.tokenize(text))
9088

9189
def token_count(self, text: str) -> int:
9290
"""Return the number of tokens in the input string (blocking)."""
9391
if self._executor is None:
9492
raise RuntimeError("TokenizePool is closed")
95-
future = self._executor.submit(_token_count_worker, self._tokenizer_name, text)
93+
future = self._executor.submit(self._token_count_worker, text)
9694
return future.result()
9795

9896
async def token_count_async(
@@ -106,7 +104,7 @@ async def token_count_async(
106104
if self._executor is None:
107105
raise RuntimeError("TokenizePool is closed")
108106
return await loop.run_in_executor(
109-
self._executor, _token_count_worker, self._tokenizer_name, text
107+
self._executor, self._token_count_worker, text
110108
)
111109

112110
def close(self) -> None:

tests/unit/async_utils/services/event_logger/test_event_logger.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def _make_stub(*args, **kwargs) -> tuple[StubEventLoggerService, list[FakeWriter
117117

118118
@pytest.mark.unit
119119
class TestWriteDispatch:
120-
@pytest.mark.asyncio(mode="strict")
120+
@pytest.mark.asyncio
121121
@pytest.mark.parametrize(
122122
"case_desc, records",
123123
[
@@ -146,14 +146,14 @@ async def test_records_written_to_all_writers(self, case_desc, records):
146146
for writer in writers:
147147
assert writer.written == records
148148

149-
@pytest.mark.asyncio(mode="strict")
149+
@pytest.mark.asyncio
150150
async def test_empty_batch(self):
151151
service, writers = _make_stub()
152152
await service.process([])
153153
for writer in writers:
154154
assert len(writer.written) == 0
155155

156-
@pytest.mark.asyncio(mode="strict")
156+
@pytest.mark.asyncio
157157
async def test_multiple_batches_accumulate(self):
158158
service, writers = _make_stub()
159159
await service.process([_record(SampleEventType.ISSUED, uuid="s1")])
@@ -169,15 +169,15 @@ async def test_multiple_batches_accumulate(self):
169169

170170
@pytest.mark.unit
171171
class TestShutdownBehavior:
172-
@pytest.mark.asyncio(mode="strict")
172+
@pytest.mark.asyncio
173173
async def test_session_ended_triggers_flush_and_close(self):
174174
service, writers = _make_stub()
175175
await service.process([_record(SessionEventType.ENDED, ts=100)])
176176
for writer in writers:
177177
assert writer.flush_count == 1
178178
assert writer.closed
179179

180-
@pytest.mark.asyncio(mode="strict")
180+
@pytest.mark.asyncio
181181
@pytest.mark.parametrize(
182182
"case_desc, trailing_record",
183183
[
@@ -202,13 +202,13 @@ async def test_events_after_ended_same_batch(self, case_desc, trailing_record):
202202
assert len(writer.written) == 1
203203
assert writer.written[0].event_type == SessionEventType.ENDED
204204

205-
@pytest.mark.asyncio(mode="strict")
205+
@pytest.mark.asyncio
206206
async def test_writers_cleared_after_shutdown(self):
207207
service, _ = _make_stub()
208208
await service.process([_record(SessionEventType.ENDED)])
209209
assert service.writers == []
210210

211-
@pytest.mark.asyncio(mode="strict")
211+
@pytest.mark.asyncio
212212
async def test_records_before_ended_are_written(self):
213213
service, writers = _make_stub()
214214
await service.process(
@@ -235,15 +235,15 @@ async def test_records_before_ended_are_written(self):
235235

236236
@pytest.mark.unit
237237
class TestClose:
238-
@pytest.mark.asyncio(mode="strict")
238+
@pytest.mark.asyncio
239239
async def test_close_closes_all_writers(self):
240240
service, writers = _make_stub()
241241
service.close()
242242
for writer in writers:
243243
assert writer.closed
244244
assert service.writers == []
245245

246-
@pytest.mark.asyncio(mode="strict")
246+
@pytest.mark.asyncio
247247
async def test_close_idempotent(self):
248248
service, _ = _make_stub()
249249
service.close()
@@ -257,7 +257,7 @@ async def test_close_idempotent(self):
257257

258258
@pytest.mark.unit
259259
class TestIntegrationWithRealWriters:
260-
@pytest.mark.asyncio(mode="strict")
260+
@pytest.mark.asyncio
261261
async def test_jsonl_writer_integration(self, tmp_path):
262262
"""EventLoggerService with a real JSONLWriter persists records to disk."""
263263
writer = JSONLWriter(tmp_path / "events", flush_interval=1)
@@ -281,7 +281,7 @@ async def test_jsonl_writer_integration(self, tmp_path):
281281
assert records[1].event_type == SampleEventType.RECV_FIRST
282282
assert records[2].event_type == SampleEventType.COMPLETE
283283

284-
@pytest.mark.asyncio(mode="strict")
284+
@pytest.mark.asyncio
285285
async def test_sql_writer_integration(self, tmp_path):
286286
"""EventLoggerService with a real SQLWriter persists records to SQLite."""
287287
from sqlalchemy import create_engine, select
@@ -312,7 +312,7 @@ async def test_sql_writer_integration(self, tmp_path):
312312
]
313313
engine.dispose()
314314

315-
@pytest.mark.asyncio(mode="strict")
315+
@pytest.mark.asyncio
316316
async def test_dual_writer_integration(self, tmp_path):
317317
"""Both JSONL and SQL writers receive the same records."""
318318
jsonl_writer = JSONLWriter(tmp_path / "events", flush_interval=1)
@@ -343,7 +343,7 @@ async def test_dual_writer_integration(self, tmp_path):
343343
assert rows[0].sample_uuid == "dual-1"
344344
engine.dispose()
345345

346-
@pytest.mark.asyncio(mode="strict")
346+
@pytest.mark.asyncio
347347
async def test_ended_closes_real_writers(self, tmp_path):
348348
"""ENDED triggers close on real writers, flushing data to disk."""
349349
jsonl_writer = JSONLWriter(tmp_path / "events", flush_interval=100)
@@ -361,7 +361,7 @@ async def test_ended_closes_real_writers(self, tmp_path):
361361
lines = [line for line in content.split("\n") if line]
362362
assert len(lines) == 2
363363

364-
@pytest.mark.asyncio(mode="strict")
364+
@pytest.mark.asyncio
365365
async def test_events_after_ended_not_persisted_to_jsonl(self, tmp_path):
366366
"""All events after ENDED (including errors) are dropped from JSONL."""
367367
writer = JSONLWriter(tmp_path / "events", flush_interval=100)
@@ -380,7 +380,7 @@ async def test_events_after_ended_not_persisted_to_jsonl(self, tmp_path):
380380
assert len(lines) == 1
381381
assert "LateError" not in lines[0]
382382

383-
@pytest.mark.asyncio(mode="strict")
383+
@pytest.mark.asyncio
384384
async def test_full_lifecycle(self, tmp_path):
385385
"""Full session lifecycle: started -> samples -> ended."""
386386
writer = JSONLWriter(tmp_path / "events", flush_interval=1)
@@ -418,7 +418,7 @@ async def test_full_lifecycle(self, tmp_path):
418418

419419
@pytest.mark.unit
420420
class TestEdgeCases:
421-
@pytest.mark.asyncio(mode="strict")
421+
@pytest.mark.asyncio
422422
@pytest.mark.parametrize(
423423
"case_desc, event_enum, make_record",
424424
[
@@ -439,7 +439,7 @@ async def test_all_event_types_written(self, case_desc, event_enum, make_record)
439439
for writer in writers:
440440
assert len(writer.written) == len(list(event_enum))
441441

442-
@pytest.mark.asyncio(mode="strict")
442+
@pytest.mark.asyncio
443443
async def test_ended_only_triggers_once(self):
444444
"""Multiple ENDED in a batch: shutdown path runs once, second ENDED is dropped."""
445445
service, writers = _make_stub()

0 commit comments

Comments (0)