Skip to content

Commit 885ca6d

Browse files
[Misc] Fix warnings for mistral model (vllm-project#23552)
Signed-off-by: zjy0516 <[email protected]> Signed-off-by: Jiangyun Zhu <[email protected]> Co-authored-by: Patrick von Platen <[email protected]>
1 parent 2d0afcc commit 885ca6d

File tree

3 files changed

+31
-23
lines changed

3 files changed

+31
-23
lines changed

vllm/model_executor/models/pixtral.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from mistral_common.protocol.instruct.request import ChatCompletionRequest
1616
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
1717
from PIL import Image
18-
from transformers import PixtralVisionConfig, TensorType
18+
from transformers import BatchFeature, PixtralVisionConfig, TensorType
1919
from transformers.image_utils import ImageInput
2020
from transformers.models.pixtral.image_processing_pixtral import (
2121
_num_image_tokens as _get_pixtral_hf_num_image_tokens)
@@ -163,10 +163,12 @@ def __call__(
163163
images_processed.append(image_processed)
164164
images_tokens.append(image_tokens)
165165

166-
return {
167-
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
168-
"images": images_processed,
169-
}
166+
return BatchFeature({
167+
"input_ids":
168+
torch.cat(images_tokens)[None].expand(len(text), -1),
169+
"images":
170+
images_processed,
171+
})
170172

171173

172174
class PixtralProcessingInfo(BaseProcessingInfo):

vllm/model_executor/models/voxtral.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from mistral_common.protocol.instruct.request import ChatCompletionRequest
1818
from mistral_common.protocol.transcription.request import TranscriptionRequest
1919
from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
20-
from transformers import TensorType, WhisperConfig
20+
from transformers import BatchFeature, TensorType, WhisperConfig
2121
from transformers.tokenization_utils_base import TextInput
2222

2323
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
@@ -156,10 +156,12 @@ def __call__(
156156
audios_tokens.append(torch.tensor(audio_tokens))
157157
audios_processed.append(torch.tensor(audio))
158158

159-
return {
160-
"input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1),
161-
"audio_arrays": audios_processed,
162-
}
159+
return BatchFeature({
160+
"input_ids":
161+
torch.cat(audios_tokens)[None].expand(len(text), -1),
162+
"audio_arrays":
163+
audios_processed,
164+
})
163165

164166

165167
class VoxtralProcessingInfo(BaseProcessingInfo):

vllm/transformers_utils/tokenizers/mistral.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -204,18 +204,16 @@ def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
204204
self.version: int = int(_mistral_version_str.split("v")[-1])
205205

206206
tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
207-
from mistral_common.tokens.tokenizers.tekken import (
208-
SpecialTokenPolicy, Tekkenizer)
207+
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
208+
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
209+
209210
self.is_tekken = isinstance(tokenizer_, Tekkenizer)
210211
from mistral_common.tokens.tokenizers.sentencepiece import (
211212
SentencePieceTokenizer)
212213
self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
213-
if self.is_tekken:
214-
# Make sure special tokens will not raise
215-
tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
216-
elif self.is_spm:
217-
pass
218-
else:
214+
self._special_token_policy = (SpecialTokenPolicy.IGNORE
215+
if self.is_tekken else None)
216+
if not (self.is_tekken or self.is_spm):
219217
raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
220218

221219
self._vocab = tokenizer_.vocab()
@@ -430,7 +428,8 @@ def _token_to_id(t: str):
430428
return self.tokenizer.unk_id
431429

432430
ids = [_token_to_id(t) for t in tokens]
433-
decoded = self.tokenizer.decode(ids)
431+
decoded = self.tokenizer.decode(ids,
432+
self._special_token_policy)
434433
else:
435434
decoded = "".join(tokens)
436435
else:
@@ -444,15 +443,17 @@ def _token_to_id(t: str):
444443
if token in special_tokens:
445444
if regular_tokens:
446445
decoded_list.append(
447-
self.tokenizer.decode(regular_tokens))
446+
self.tokenizer.decode(regular_tokens,
447+
self._special_token_policy))
448448
regular_tokens = []
449449
decoded_list.append(token)
450450
else:
451451
regular_tokens.append(token)
452452

453453
if regular_tokens:
454454
decoded_list.append(
455-
self.tokenizer.decode(regular_tokens)) # type: ignore
455+
self.tokenizer.decode(regular_tokens,
456+
self._special_token_policy))
456457

457458
decoded = ''.join(decoded_list)
458459

@@ -470,7 +471,7 @@ def decode(self,
470471

471472
if isinstance(ids, int):
472473
ids = [ids]
473-
return self.tokenizer.decode(ids)
474+
return self.tokenizer.decode(ids, self._special_token_policy)
474475

475476
def convert_ids_to_tokens(
476477
self,
@@ -511,6 +512,9 @@ def convert_ids_to_tokens(
511512
# See: https://github.com/vllm-project/vllm/pull/8640
512513
# https://github.com/vllm-project/vllm/pull/9625
513514
# if underlying tokenizer is sentencepiece, we just add "�"
514-
tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]
515+
tokens = [
516+
self.tokenizer.id_to_byte_piece(id, self._special_token_policy)
517+
for id in ids
518+
]
515519

516520
return tokens

0 commit comments

Comments
 (0)