
Commit 7015417

[Bugfix] Add missing attributes in mistral tokenizer (vllm-project#8364)

1 parent: aea02f3
File tree

2 files changed: 63 additions, 32 deletions


vllm/entrypoints/chat_utils.py (5 additions, 2 deletions)

@@ -519,11 +519,14 @@ def apply_hf_chat_template(
 def apply_mistral_chat_template(
     tokenizer: MistralTokenizer,
     messages: List[ChatCompletionMessageParam],
-    chat_template: Optional[str],
+    chat_template: Optional[str] = None,
     **kwargs: Any,
 ) -> List[int]:
+    if chat_template is not None:
+        logger.warning(
+            "'chat_template' cannot be overridden for mistral tokenizer.")
+
     return tokenizer.apply_chat_template(
         messages=messages,
-        chat_template=chat_template,
         **kwargs,
     )
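This change makes chat_template optional and, when a template is passed anyway, logs a warning instead of forwarding it, since the Mistral tokenizer carries its own built-in template. Below is a minimal standalone sketch of that warn-and-ignore pattern; the function name render_chat and the placeholder rendering are hypothetical and not part of vLLM.

import logging
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)

def render_chat(messages: List[Dict[str, str]],
                chat_template: Optional[str] = None,
                **kwargs: Any) -> str:
    # Mirror the patched behaviour: an explicit chat_template is only
    # reported, never applied, because the template is fixed by the tokenizer.
    if chat_template is not None:
        logger.warning(
            "'chat_template' cannot be overridden for mistral tokenizer.")
    # Hypothetical placeholder rendering, just to keep the sketch runnable.
    return "\n".join(m["content"] for m in messages)

print(render_chat([{"role": "user", "content": "Hello"}],
                  chat_template="{{ messages }}"))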

vllm/transformers_utils/tokenizers/mistral.py (58 additions, 30 deletions)

@@ -45,26 +45,25 @@ class MistralTokenizer:
     def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
         self.mistral = tokenizer
         self.instruct = tokenizer.instruct_tokenizer
-        self.tokenizer = tokenizer.instruct_tokenizer.tokenizer
 
-        self.vocab_size = len(self.tokenizer.vocab())
-
-        assert isinstance(self.tokenizer,
-                          (Tekkenizer, SentencePieceTokenizer)), type(
-                              self.tokenizer)
-
-        if (is_tekken := isinstance(self.tokenizer, Tekkenizer)):
+        tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
+        if isinstance(tokenizer_, Tekkenizer):
             # Make sure special tokens will not raise
-            self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE
-
-        self._is_tekken = is_tekken
+            tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
+
+            self._vocab = {
+                token: idx
+                for idx, token in enumerate(tokenizer_.vocab())
+            }
+        elif isinstance(tokenizer_, SentencePieceTokenizer):
+            self._vocab = {
+                token: idx
+                for idx, token in enumerate(tokenizer_.vocab())
+            }
+        else:
+            raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
 
-        # the following attributes are set to fit VLLM's design
-        self.is_fast = True
-        self.chat_template = True
-        self.all_special_ids: List[Any] = []
-        self.all_special_tokens: List[Any] = []
-        self.all_special_tokens_extended: List[Any] = []
+        self.tokenizer = tokenizer_
 
     @classmethod
     def from_pretrained(cls,
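The constructor now replaces the cached vocab_size with a token-to-id dictionary, _vocab, built by enumerating the underlying vocab() list so that each token's list position becomes its id. A tiny self-contained sketch of that construction, using a toy vocabulary rather than a real Mistral one:

from typing import Dict, List

vocab_list: List[str] = ["<s>", "</s>", "hello", "world"]
# Position in the vocab list becomes the token id.
vocab: Dict[str, int] = {token: idx for idx, token in enumerate(vocab_list)}

assert vocab["hello"] == 2
assert len(vocab) == len(vocab_list)  # later exposed as vocab_size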
@@ -102,6 +101,38 @@ def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
                                          revision=revision)
         return tokenizer_file
 
+    # the following attributes are set to fit VLLM's design
+    @property
+    def all_special_tokens_extended(self) -> List[str]:
+        return []
+
+    @property
+    def all_special_tokens(self) -> List[str]:
+        return []
+
+    @property
+    def all_special_ids(self) -> List[int]:
+        return []
+
+    @property
+    def bos_token_id(self) -> int:
+        return self.tokenizer.bos_id
+
+    @property
+    def eos_token_id(self) -> int:
+        return self.tokenizer.eos_id
+
+    @property
+    def is_fast(self) -> bool:
+        return True
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self._vocab)
+
+    def __len__(self) -> int:
+        return self.vocab_size
+
     def __call__(
         self,
         prompt: str,
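Attributes that used to be plain instance fields (is_fast, all_special_tokens, vocab_size, and so on) are now read-only properties derived from the wrapped tokenizer. A compact sketch of that shim pattern with toy classes, not the real Tekkenizer or vLLM types:

class _Inner:
    bos_id, eos_id = 1, 2

class Wrapper:
    def __init__(self) -> None:
        self.tokenizer = _Inner()
        self._vocab = {"<s>": 0, "</s>": 1, "hi": 2}

    @property
    def bos_token_id(self) -> int:
        return self.tokenizer.bos_id

    @property
    def eos_token_id(self) -> int:
        return self.tokenizer.eos_id

    @property
    def vocab_size(self) -> int:
        return len(self._vocab)

    def __len__(self) -> int:
        return self.vocab_size

w = Wrapper()
assert (w.bos_token_id, w.eos_token_id, len(w)) == (1, 2, 3)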
@@ -117,9 +148,12 @@ def __call__(
 
         return Encoding(input_ids=input_ids)
 
-    def get_added_vocab(self) -> List[str]:
+    def get_vocab(self) -> Dict[str, int]:
+        return self._vocab
+
+    def get_added_vocab(self) -> Dict[str, int]:
         # Mistral tokenizers have no added vocabulary
-        return []
+        return {}
 
     def encode(self, prompt: str) -> List[int]:
         # `encode` should only be used for prompt completion
@@ -141,7 +175,7 @@ def apply_chat_template(self,
         return encoded.tokens
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        if self._is_tekken:
+        if isinstance(self.tokenizer, Tekkenizer):
             return "".join(tokens)
         else:
             return self.tokenizer.decode(tokens)  # type: ignore[arg-type]
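With the cached _is_tekken flag gone, the Tekken-vs-SentencePiece branch is decided by checking the wrapped tokenizer's type at call time. A toy illustration of that dispatch; the classes below stand in for Tekkenizer and SentencePieceTokenizer and are not the real ones:

from typing import List

class ToyTekken:
    pass

class ToySentencePiece:
    def decode(self, tokens: List[str]) -> str:
        return " ".join(tokens)

def tokens_to_string(tok, tokens: List[str]) -> str:
    if isinstance(tok, ToyTekken):
        return "".join(tokens)  # Tekken pieces are already text fragments
    return tok.decode(tokens)   # SentencePiece-style tokenizers decode themselves

assert tokens_to_string(ToyTekken(), ["he", "llo"]) == "hello"
assert tokens_to_string(ToySentencePiece(), ["he", "llo"]) == "he llo"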
@@ -151,14 +185,11 @@ def decode(self, ids: Union[List[int], int]) -> str:
             ids = [ids]
         return self.tokenizer.decode(ids)
 
-    @property
-    def eos_token_id(self):
-        return self.tokenizer.eos_id
-
     def convert_ids_to_tokens(
-            self,
-            ids: List[int],
-            skip_special_tokens: Optional[bool] = True) -> List[str]:
+        self,
+        ids: List[int],
+        skip_special_tokens: bool = True,
+    ) -> List[str]:
         # TODO(Patrick) - potentially allow special tokens to not be skipped
         assert (
             skip_special_tokens
@@ -170,6 +201,3 @@ def convert_ids_to_tokens(
 
         tokens = [self.tokenizer.id_to_piece(id) for id in ids]
         return tokens
-
-    def __len__(self):
-        return self.vocab_size
