Skip to content

Commit bd111bc

Browse files
majiayu000, kylesayrs, dsikka, HDCharles, and claude
authored
fix: suppress tokenizer parallelism warning in oneshot (vllm-project#2183)
SUMMARY: Suppress the tokenizer parallelism warning that appears during oneshot calibration by setting `TOKENIZERS_PARALLELISM=false` in `Oneshot.__init__`. The warning occurs when FastTokenizer's internal threading conflicts with `dataset.map`'s multiprocessing (`num_proc` parameter). This fix sets the environment variable early in the oneshot lifecycle to prevent the conflict, while respecting any existing user-set value. Closes vllm-project#2007 TEST PLAN: - Added unit tests in `tests/llmcompressor/transformers/oneshot/test_tokenizer_parallelism.py` - Tests verify: 1. `TOKENIZERS_PARALLELISM` is set to `false` when not already set 2. Existing user-set `TOKENIZERS_PARALLELISM` values are respected - All tests pass locally with `pytest tests/llmcompressor/transformers/oneshot/test_tokenizer_parallelism.py -v` --------- Signed-off-by: majiayu000 <1835304752@qq.com> Signed-off-by: lif <1835304752@qq.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com> Co-authored-by: HDCharles <39544797+HDCharles@users.noreply.github.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
1 parent 6cf3d37 commit bd111bc

File tree

5 files changed

+67
-4
lines changed

5 files changed

+67
-4
lines changed

experimental/attention/llama3_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1+
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
12
from datasets import load_dataset
23
from transformers import AutoModelForCausalLM, AutoTokenizer
34

45
from llmcompressor import oneshot
56
from llmcompressor.modifiers.quantization import QuantizationModifier
67
from llmcompressor.utils import dispatch_for_generation
7-
from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs
88

99
# Select model and load it.
1010
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

experimental/attention/llama3_attention_r3_nvfp4.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1+
from compressed_tensors.quantization import QuantizationScheme
2+
from compressed_tensors.quantization.quant_scheme import NVFP4
13
from datasets import load_dataset
24
from transformers import AutoModelForCausalLM, AutoTokenizer
35

46
from llmcompressor import oneshot
57
from llmcompressor.modifiers.quantization import QuantizationModifier
68
from llmcompressor.modifiers.transform import SpinQuantModifier
79
from llmcompressor.utils import dispatch_for_generation
8-
from compressed_tensors.quantization import QuantizationScheme
9-
from compressed_tensors.quantization.quant_scheme import NVFP4
1010

1111
# Select model and load it.
1212
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
from datasets import Dataset, DatasetDict
3232

3333

34+
TOKENIZERS_PARALLELISM_ENV = "TOKENIZERS_PARALLELISM"
35+
36+
3437
class Oneshot:
3538
"""
3639
Class responsible for carrying out one-shot calibration on a pretrained model.
@@ -121,6 +124,19 @@ def __init__(
121124
:param log_dir: Path to save logs during oneshot run.
122125
Nothing is logged to file if None.
123126
"""
127+
# Disable tokenizer parallelism to prevent warning when using
128+
# multiprocessing for dataset preprocessing. The warning occurs because
129+
# FastTokenizer's internal threading conflicts with dataset.map's num_proc.
130+
# See: https://github.com/vllm-project/llm-compressor/issues/2007
131+
if TOKENIZERS_PARALLELISM_ENV not in os.environ:
132+
os.environ[TOKENIZERS_PARALLELISM_ENV] = "false"
133+
logger.warning(
134+
"Disabling tokenizer parallelism due to threading conflict between "
135+
"FastTokenizer and Datasets. Set "
136+
f"{TOKENIZERS_PARALLELISM_ENV}=false to "
137+
"suppress this warning."
138+
)
139+
124140
# Set up file logging (no default files):
125141
# 1) If LLM_COMPRESSOR_LOG_FILE is set, log to that file.
126142
# 2) Else, if an explicit log_dir is provided, create a timestamped file there.
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import os
2+
3+
import pytest
4+
5+
from llmcompressor.entrypoints.oneshot import (
6+
TOKENIZERS_PARALLELISM_ENV as _TOKENIZERS_PARALLELISM_ENV,
7+
)
8+
9+
10+
class TestTokenizerParallelism:
11+
"""Tests for tokenizer parallelism warning suppression (issue #2007)."""
12+
13+
def test_oneshot_sets_tokenizers_parallelism_when_not_set(self, monkeypatch):
14+
"""
15+
Test that Oneshot sets TOKENIZERS_PARALLELISM=false when not already set.
16+
17+
This prevents the warning:
18+
"huggingface/tokenizers: The current process just got forked, after
19+
parallelism has already been used. Disabling parallelism to avoid deadlocks..."
20+
21+
See: https://github.com/vllm-project/llm-compressor/issues/2007
22+
"""
23+
monkeypatch.delenv(_TOKENIZERS_PARALLELISM_ENV, raising=False)
24+
25+
from llmcompressor.entrypoints.oneshot import Oneshot
26+
27+
# Create a minimal Oneshot instance to trigger __init__
28+
# We expect it to fail due to missing model, but the env var should be set
29+
with pytest.raises(Exception):
30+
Oneshot(model="nonexistent-model")
31+
32+
assert os.environ[_TOKENIZERS_PARALLELISM_ENV] == "false"
33+
34+
def test_oneshot_respects_existing_tokenizers_parallelism(self, monkeypatch):
35+
"""
36+
Test that Oneshot respects user's existing TOKENIZERS_PARALLELISM setting.
37+
38+
If a user has explicitly set TOKENIZERS_PARALLELISM, we should not override it.
39+
"""
40+
monkeypatch.setenv(_TOKENIZERS_PARALLELISM_ENV, "true")
41+
42+
from llmcompressor.entrypoints.oneshot import Oneshot
43+
44+
with pytest.raises(Exception):
45+
Oneshot(model="nonexistent-model")
46+
47+
assert os.environ[_TOKENIZERS_PARALLELISM_ENV] == "true"

tools/collect_env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
creating bug reports. See `.github/ISSUE_TEMPLATE/bug_report.md`
44
"""
55

6+
import importlib
67
import platform
78
import sys
8-
import importlib
99

1010

1111
def get_version(pkg_name):

0 commit comments

Comments (0)