Skip to content

Commit 39cefbd

Browse files
[Refactor] TokenizerRegistry only uses lazy imports (vllm-project#30609)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent ace34e3 commit 39cefbd

File tree

14 files changed

+202
-176
lines changed

14 files changed

+202
-176
lines changed

tests/test_inputs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from vllm.inputs import zip_enc_dec_prompts
88
from vllm.inputs.parse import parse_raw_prompts
99
from vllm.inputs.preprocess import InputPreprocessor
10-
from vllm.tokenizers import init_tokenizer_from_config
10+
from vllm.tokenizers import cached_tokenizer_from_config
1111

1212
pytestmark = pytest.mark.cpu_test
1313

@@ -108,7 +108,7 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
108108
)
109109
def test_preprocessor_always_mm_code_path(model_id, prompt):
110110
model_config = ModelConfig(model=model_id)
111-
tokenizer = init_tokenizer_from_config(model_config)
111+
tokenizer = cached_tokenizer_from_config(model_config)
112112
input_preprocessor = InputPreprocessor(model_config, tokenizer)
113113

114114
# HF processor adds sep token

tests/tokenizers_/test_basic.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,38 +3,39 @@
33
from typing import _get_protocol_attrs # type: ignore
44

55
import pytest
6-
from transformers import PreTrainedTokenizerBase
6+
from transformers import (
7+
PreTrainedTokenizer,
8+
PreTrainedTokenizerBase,
9+
PreTrainedTokenizerFast,
10+
)
711

812
from vllm.tokenizers import TokenizerLike, get_tokenizer
13+
from vllm.tokenizers.mistral import MistralTokenizer
914

1015

1116
def _get_missing_attrs(obj: object, target: type):
1217
return [k for k in _get_protocol_attrs(target) if not hasattr(obj, k)]
1318

1419

20+
def _assert_tokenizer_like(tokenizer: object):
21+
missing_attrs = _get_missing_attrs(tokenizer, TokenizerLike)
22+
assert not missing_attrs, f"Missing attrs: {missing_attrs}"
23+
24+
1525
def test_tokenizer_like_protocol():
16-
assert not (
17-
missing_attrs := _get_missing_attrs(
18-
get_tokenizer("gpt2", use_fast=False),
19-
TokenizerLike,
20-
)
21-
), f"Missing attrs: {missing_attrs}"
22-
23-
assert not (
24-
missing_attrs := _get_missing_attrs(
25-
get_tokenizer("gpt2", use_fast=True),
26-
TokenizerLike,
27-
)
28-
), f"Missing attrs: {missing_attrs}"
29-
30-
assert not (
31-
missing_attrs := _get_missing_attrs(
32-
get_tokenizer(
33-
"mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral"
34-
),
35-
TokenizerLike,
36-
)
37-
), f"Missing attrs: {missing_attrs}"
26+
tokenizer = get_tokenizer("gpt2", use_fast=False)
27+
assert isinstance(tokenizer, PreTrainedTokenizer)
28+
_assert_tokenizer_like(tokenizer)
29+
30+
tokenizer = get_tokenizer("gpt2", use_fast=True)
31+
assert isinstance(tokenizer, PreTrainedTokenizerFast)
32+
_assert_tokenizer_like(tokenizer)
33+
34+
tokenizer = get_tokenizer(
35+
"mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral"
36+
)
37+
assert isinstance(tokenizer, MistralTokenizer)
38+
_assert_tokenizer_like(tokenizer)
3839

3940

4041
@pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"])

tests/tokenizers_/test_registry.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
from pathlib import Path
44

5-
from vllm.tokenizers import TokenizerLike, TokenizerRegistry, get_tokenizer
5+
import pytest
6+
7+
from vllm.tokenizers import TokenizerLike
8+
from vllm.tokenizers.registry import (
9+
TokenizerRegistry,
10+
get_tokenizer,
11+
resolve_tokenizer_args,
12+
)
613

714

815
class TestTokenizer(TokenizerLike):
@@ -40,10 +47,22 @@ def is_fast(self) -> bool:
4047
return True
4148

4249

50+
@pytest.mark.parametrize("runner_type", ["generate", "pooling"])
51+
def test_resolve_tokenizer_args_idempotent(runner_type):
52+
tokenizer_mode, tokenizer_name, args, kwargs = resolve_tokenizer_args(
53+
"facebook/opt-125m",
54+
runner_type=runner_type,
55+
)
56+
57+
assert (tokenizer_mode, tokenizer_name, args, kwargs) == resolve_tokenizer_args(
58+
tokenizer_name, *args, **kwargs
59+
)
60+
61+
4362
def test_customized_tokenizer():
4463
TokenizerRegistry.register("test_tokenizer", __name__, TestTokenizer.__name__)
4564

46-
tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer", "abc")
65+
tokenizer = TokenizerRegistry.load_tokenizer("test_tokenizer", "abc")
4766
assert isinstance(tokenizer, TestTokenizer)
4867
assert tokenizer.path_or_repo_id == "abc"
4968
assert tokenizer.bos_token_id == 0

vllm/entrypoints/chat_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
5151
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
5252
from vllm.tokenizers import TokenizerLike
53-
from vllm.tokenizers.mistral import MistralTokenizer
5453
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
5554
from vllm.transformers_utils.processor import cached_get_processor
5655
from vllm.utils import random_uuid
@@ -60,6 +59,8 @@
6059

6160
if TYPE_CHECKING:
6261
import torch
62+
63+
from vllm.tokenizers.mistral import MistralTokenizer
6364
else:
6465
torch = LazyLoader("torch", globals(), "torch")
6566

@@ -1832,7 +1833,7 @@ def apply_hf_chat_template(
18321833

18331834

18341835
def apply_mistral_chat_template(
1835-
tokenizer: MistralTokenizer,
1836+
tokenizer: "MistralTokenizer",
18361837
messages: list[ChatCompletionMessageParam],
18371838
chat_template: str | None,
18381839
tools: list[dict[str, Any]] | None,

vllm/tokenizers/__init__.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4-
from .deepseekv32 import DeepseekV32Tokenizer
5-
from .hf import HfTokenizer
6-
from .mistral import MistralTokenizer
74
from .protocol import TokenizerLike
85
from .registry import (
96
TokenizerRegistry,
@@ -15,12 +12,9 @@
1512

1613
__all__ = [
1714
"TokenizerLike",
18-
"HfTokenizer",
19-
"MistralTokenizer",
2015
"TokenizerRegistry",
2116
"cached_get_tokenizer",
2217
"get_tokenizer",
2318
"cached_tokenizer_from_config",
2419
"init_tokenizer_from_config",
25-
"DeepseekV32Tokenizer",
2620
]

vllm/tokenizers/deepseekv32.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,18 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
from pathlib import Path
5+
from typing import Any
56

67
from transformers import BatchEncoding
78

8-
from .deepseek_v32_encoding import encode_messages
9-
from .hf import HfTokenizer, TokenizerLike
10-
from .registry import TokenizerRegistry
9+
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
1110

11+
from .deepseek_v32_encoding import encode_messages
12+
from .hf import CachedHfTokenizer
13+
from .protocol import TokenizerLike
1214

13-
@TokenizerRegistry.register("deepseek_v32")
14-
class DeepseekV32Tokenizer(HfTokenizer):
15-
def __init__(self, tokenizer: TokenizerLike):
16-
self.tokenizer = tokenizer
17-
self.name_or_path = (
18-
tokenizer.name_or_path if hasattr(tokenizer, "name_or_path") else ""
19-
)
20-
self._added_vocab = self.tokenizer.get_added_vocab()
21-
self._added_vocab_size = len(self._added_vocab)
2215

16+
class DeepseekV32Tokenizer(CachedHfTokenizer):
2317
@classmethod
2418
def from_pretrained(
2519
cls,
@@ -40,7 +34,21 @@ def from_pretrained(
4034
)
4135
return DeepseekV32Tokenizer(tokenizer)
4236

43-
def apply_chat_template(self, messages, tools=None, **kwargs):
37+
def __init__(self, tokenizer: TokenizerLike) -> None:
38+
super().__init__()
39+
40+
self.tokenizer = tokenizer
41+
self.name_or_path = getattr(tokenizer, "name_or_path", "")
42+
43+
self._added_vocab = self.tokenizer.get_added_vocab()
44+
self._added_vocab_size = len(self._added_vocab)
45+
46+
def apply_chat_template(
47+
self,
48+
messages: list["ChatCompletionMessageParam"],
49+
tools: list[dict[str, Any]] | None = None,
50+
**kwargs,
51+
) -> str | list[int]:
4452
thinking = kwargs.get("thinking", False)
4553
thinking_mode = "thinking"
4654
if not thinking:
@@ -49,13 +57,24 @@ def apply_chat_template(self, messages, tools=None, **kwargs):
4957
messages = conversation.copy()
5058
if tools is not None and len(tools) > 0:
5159
messages.insert(0, {"role": "system"})
52-
messages[0]["tools"] = tools
60+
messages[0]["tools"] = tools # type: ignore[typeddict-unknown-key]
5361

5462
# Historical reasoning content is dropped when a new user message is introduced
5563
drop_thinking = messages[-1]["role"] == "user"
5664

5765
encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking)
5866
prompt_str = encode_messages(messages, **encode_config) # type: ignore
67+
68+
if kwargs.get("tokenize", True):
69+
tokenizer_kwargs = {
70+
k: kwargs[k] for k in ("truncation", "max_length") if k in kwargs
71+
}
72+
return self.encode(
73+
prompt_str,
74+
add_special_tokens=False,
75+
**tokenizer_kwargs,
76+
)
77+
5978
return prompt_str
6079

6180
def num_special_tokens_to_add(self) -> int:

vllm/tokenizers/hf.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,18 @@
33
import contextlib
44
import copy
55
from pathlib import Path
6-
from typing import TYPE_CHECKING
6+
from typing import TypeAlias
77

8-
from transformers import AutoTokenizer
8+
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
99

1010
from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config
1111

1212
from .protocol import TokenizerLike
13-
from .registry import TokenizerRegistry
1413

15-
if TYPE_CHECKING:
16-
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
14+
HfTokenizer: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast
1715

1816

19-
def get_cached_tokenizer(
20-
tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast",
21-
) -> TokenizerLike:
17+
def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
2218
"""
2319
By default, transformers will recompute multiple tokenizer properties
2420
each time they are called, leading to a significant slowdown.
@@ -65,11 +61,10 @@ def __reduce__(self):
6561
CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
6662

6763
cached_tokenizer.__class__ = CachedTokenizer
68-
return cached_tokenizer # type: ignore
64+
return cached_tokenizer
6965

7066

71-
@TokenizerRegistry.register("hf")
72-
class HfTokenizer(TokenizerLike):
67+
class CachedHfTokenizer(TokenizerLike):
7368
@classmethod
7469
def from_pretrained(
7570
cls,
@@ -79,7 +74,7 @@ def from_pretrained(
7974
revision: str | None = None,
8075
download_dir: str | None = None,
8176
**kwargs,
82-
) -> "TokenizerLike":
77+
) -> HfTokenizer:
8378
try:
8479
tokenizer = AutoTokenizer.from_pretrained(
8580
path_or_repo_id,

vllm/tokenizers/mistral.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from pathlib import Path
44
from typing import TYPE_CHECKING, Any, cast
55

6+
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
7+
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
68
from vllm.logger import init_logger
79

810
from .protocol import TokenizerLike
9-
from .registry import TokenizerRegistry
1011

1112
if TYPE_CHECKING:
1213
from mistral_common.protocol.instruct.request import (
@@ -15,9 +16,6 @@
1516
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
1617
from transformers import BatchEncoding
1718

18-
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
19-
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
20-
2119
try:
2220
# Transformers v5
2321
from transformers.tokenization_mistral_common import MistralCommonBackend
@@ -201,7 +199,6 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
201199
return tokenizer.unk_id
202200

203201

204-
@TokenizerRegistry.register("mistral")
205202
class MistralTokenizer(TokenizerLike):
206203
@classmethod
207204
def from_pretrained(

vllm/tokenizers/protocol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def apply_chat_template(
9797
messages: list["ChatCompletionMessageParam"],
9898
tools: list[dict[str, Any]] | None = None,
9999
**kwargs,
100-
) -> list[int]:
100+
) -> str | list[int]:
101101
raise NotImplementedError
102102

103103
def convert_tokens_to_string(self, tokens: list[str]) -> str:

0 commit comments

Comments
 (0)