Commit db3430f

[None][feat] Support VLM part for Mistral Large 3 (#10188)
Signed-off-by: bhsueh <[email protected]>
1 parent 7e4cef9 commit db3430f

9 files changed: +558 −41 lines changed


examples/models/core/mistral_large_3/README.md

Lines changed: 15 additions & 3 deletions
@@ -7,6 +7,21 @@ export mistral_large_3_model_path=<mistral_large_3_model_path>
 export mistral_large_3_eagle_model_path=<mistral_large_3_eagle_model_path>
 ```
 
+## Multimodal run
+
+* Run the Mistral Large V3 with `quickstart_multimodal.py`
+
+```bash
+mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickstart_multimodal.py \
+    --model_dir ${mistral_large_3_model_path} \
+    --tp_size 4 \
+    --moe_ep_size 4 \
+    --max_tokens 100 \
+    --checkpoint_format mistral \
+    --model_type mistral_large_3 \
+    --moe_backend TRTLLM
+```
+
 ## LLM-only run
 
 * Run the Mistral Large V3 with `quickstart_advanced.py`
@@ -44,9 +59,6 @@ echo "
 backend: pytorch
 tensor_parallel_size: 4
 moe_expert_parallel_size: 4
-enable_attention_dp: false
-kv_cache_config:
-  enable_block_reuse: true
 checkpoint_format: mistral
 " > serve.yml
 mpirun -n 1 --allow-run-as-root --oversubscribe python3 -m tensorrt_llm.commands.serve serve \
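
Not part of this diff: once the server started by the `tensorrt_llm.commands.serve` command above is up, its OpenAI-compatible chat-completions route can typically be exercised with a small Python client like the sketch below. The host, port 8000, request fields, and the `"model"` value are assumptions, not values taken from this commit.

```python
# Usage sketch (assumptions: server on localhost:8000, /v1/chat/completions route).
import json
import urllib.request

payload = {
    "model": "mistral_large_3",  # placeholder; match the model name your server reports
    "messages": [{"role": "user", "content": "Summarize the Mistral Large 3 example in one sentence."}],
    "max_tokens": 64,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    body = json.loads(resp.read())
    print(body["choices"][0]["message"]["content"])
```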
Lines changed: 304 additions & 0 deletions
@@ -0,0 +1,304 @@
from pathlib import Path
from typing import Union

from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, SpecialTokens
from mistral_common.tokens.tokenizers.sentencepiece import SentencePieceTokenizer
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from transformers.tokenization_mistral_common import (
    MistralCommonTokenizer as TransformersMistralTokenizer,
)

from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
from tensorrt_llm.logger import logger


# Adapted from:
# https://github.com/vllm-project/vllm/blob/8e67b2557aae7204c697d7a5c61e00754da465be/vllm/transformers_utils/tokenizers/mistral.py#L166
class MistralTokenizer(TransformersTokenizer):

    def __init__(self, tokenizer: "TransformersMistralTokenizer"):
        self.transformers_tokenizer = tokenizer
        self.mistral = tokenizer.tokenizer
        self.instruct = self.mistral.instruct_tokenizer
        self.tokenizer = self.instruct.tokenizer

        _mistral_version_str = str(self.tokenizer.version.value)
        self.version: int = int(_mistral_version_str.split("v")[-1])

        self.is_tekken = isinstance(self.tokenizer, Tekkenizer)
        self.is_spm = isinstance(self.tokenizer, SentencePieceTokenizer)
        if not (self.is_tekken or self.is_spm):
            raise TypeError(f"Unsupported tokenizer: {type(self.tokenizer)}")

        # Reverse order to ensure that the lowest token id is kept.
        self._vocab_dict = {
            self.convert_ids_to_tokens([i], skip_special_tokens=False)[0]: i
            for i in range(self.transformers_tokenizer.vocab_size - 1, -1, -1)
        }
        # Sort the dict for convenience.
        self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))

        # Cache special tokens for faster access.
        self._special_token_ids = self._get_special_token_ids()
        self._special_token_ids_set = set(self._special_token_ids)
        self._special_tokens = self._get_special_tokens(self._special_token_ids)
        self._special_tokens_set = set(self._special_tokens)

        # Vocab sorted by token id.
        self._vocab = self.tokenizer._vocab
        self._max_token_id = self.transformers_tokenizer.vocab_size - 1

        self._all_special_tokens_set = set(self.all_special_tokens)

    def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy

        return [
            self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
            for i in all_special_ids
        ]

    # The following attributes are set to fit vLLM's design and are used
    # by the structured output backends.
    @property
    def all_special_tokens_extended(self) -> list[str]:
        return self.all_special_tokens

    @property
    def all_special_tokens(self) -> list[str]:
        return self._special_tokens

    @property
    def all_special_ids(self) -> list[int]:
        return self._special_token_ids

    @classmethod
    def from_pretrained(cls, pretrained_model_dir: str, **kwargs):
        if Path(pretrained_model_dir).is_file():
            tokenizer = TransformersMistralTokenizer(tokenizer_path=pretrained_model_dir)
        else:
            tokenizer = TransformersMistralTokenizer.from_pretrained(pretrained_model_dir)
        return cls(tokenizer)

    def _get_special_token_ids(self) -> list[int]:
        from mistral_common.tokens.tokenizers.sentencepiece import SentencePieceTokenizer
        from mistral_common.tokens.tokenizers.tekken import Tekkenizer

        if self.is_tekken:
            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
            special_ids = {t["rank"] for t in self.tokenizer._all_special_tokens}
        elif self.is_spm:
            assert isinstance(self.tokenizer, SentencePieceTokenizer), type(self.tokenizer)
            special_ids = self.tokenizer._control_tokens
        else:
            raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
        return sorted(special_ids)

    @property
    def bos_token_id(self) -> int:
        return self.tokenizer.bos_id

    @property
    def eos_token_id(self) -> int:
        return self.tokenizer.eos_id

    @property
    def sep_token(self) -> str:
        raise NotImplementedError()

    @property
    def pad_token(self) -> str:
        return self.transformers_tokenizer.pad_token

    @property
    def pad_token_id(self) -> int:
        return self.transformers_tokenizer.pad_token_id

    def __call__(self, text: str, *args, **kwargs) -> any:
        return self.transformers_tokenizer(text=text, *args, **kwargs)

    @property
    def name_or_path(self) -> str:
        return self.transformers_tokenizer.name_or_path

    def batch_encode_plus(self, texts: list[str], *args, **kwargs) -> dict:
        raise NotImplementedError

    def get_chat_template(
        self, chat_template: str | None = None, tools: list[dict] | None = None
    ) -> str:
        raise NotImplementedError

    def clean_up_tokenization(self, out_string: str) -> str:
        raise NotImplementedError

    @property
    def is_fast(self) -> bool:
        return True

    def get_added_vocab(self) -> dict[str, int]:
        # Mistral tokenizers have no added vocabulary.
        return {}

    def _tekken_token_to_id(self, tokenizer: "Tekkenizer", t: str | bytes) -> int:
        assert isinstance(tokenizer, Tekkenizer), type(tokenizer)

        t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t
        shift = tokenizer.num_special_tokens
        try:
            return shift + tokenizer._tekken_token2id_nospecial[t_bytes]
        except KeyError:
            t_str = t_bytes.decode("utf-8")
            if t_str in tokenizer._special_tokens_reverse_vocab:
                return tokenizer._special_tokens_reverse_vocab[t_str]
            logger.warning("Failed to convert token %s to id, replacing with <unk>", t_bytes)
            return tokenizer.unk_id

    def _is_special_token_id(self, token_id: int) -> bool:
        return token_id in self._special_token_ids_set

    def convert_tokens_to_string(
        self,
        tokens: list[str],
        skip_special_tokens: bool = False,
        spaces_between_special_tokens: bool = True,
    ) -> str:
        to_decode_special_tokens = {SpecialTokens.tool_calls}
        if self.is_tekken:
            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
            tokens = [
                t
                for t in tokens
                if (t in to_decode_special_tokens or t not in self._special_tokens_set)
            ]

            if any(isinstance(t, bytes) for t in tokens):
                # We need to encode and decode all tokens again.
                ids = [self._tekken_token_to_id(self.tokenizer, t) for t in tokens]
                # We filtered unwanted special tokens before,
                # so we can decode the rest.
                decoded = self.tokenizer.decode(ids, SpecialTokenPolicy.KEEP)
            else:
                decoded = "".join(tokens)
        else:
            # Make sure certain special tokens like tool calls are
            # not decoded.
            assert isinstance(self.tokenizer, SentencePieceTokenizer), type(self.tokenizer)

            regular_tokens: list[str] = []
            decoded_list: list[str] = []
            decoded = ""

            for token in tokens:
                if token in to_decode_special_tokens:
                    if regular_tokens:
                        decoded_list.append(
                            self.tokenizer.decode(regular_tokens, SpecialTokenPolicy.IGNORE)
                        )
                        regular_tokens = []
                    decoded_list.append(token)
                else:
                    regular_tokens.append(token)

            if regular_tokens:
                decoded_list.append(
                    self.tokenizer.decode(regular_tokens, SpecialTokenPolicy.IGNORE)
                )
            decoded = "".join(decoded_list)

        return decoded

    def encode(
        self,
        text: str,
        truncation: bool | None = None,
        max_length: int | None = None,
        add_special_tokens: bool | None = None,
    ) -> list[int]:
        if add_special_tokens is not None:
            return self.transformers_tokenizer.encode(
                text,
                truncation=truncation,
                max_length=max_length,
                add_special_tokens=add_special_tokens,
            )
        else:
            encoded = self.tokenizer.encode(text, bos=True, eos=False)

            if truncation is not False and max_length is not None:
                return encoded[:max_length]
            else:
                return encoded

    def decode(
        self, token_ids: list[int] | int, skip_special_tokens: bool = True, *args, **kwargs
    ) -> str:
        return self.transformers_tokenizer.decode(
            token_ids, skip_special_tokens=skip_special_tokens
        )

    def convert_ids_to_tokens(
        self,
        ids: list[int],
        skip_special_tokens: bool = True,
    ) -> list[str]:
        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, SpecialTokens
        from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13

        if not skip_special_tokens:
            return [self.tokenizer.id_to_piece(token_id) for token_id in ids]

        non_skip_special_tokens_ids = {
            self.tokenizer.get_control_token(SpecialTokens.tool_calls),
        }
        if isinstance(self.instruct, InstructTokenizerV13):
            if self.instruct.BEGIN_THINK:
                non_skip_special_tokens_ids.add(self.instruct.BEGIN_THINK)
            if self.instruct.END_THINK:
                non_skip_special_tokens_ids.add(self.instruct.END_THINK)

        ids_kept = [
            i for i in ids if i in non_skip_special_tokens_ids or not self._is_special_token_id(i)
        ]

        # We filtered unwanted special tokens so we can decode the rest.
        tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids_kept]

        if any("�" in t for t in tokens) and self.is_tekken:
            # If a decoded token contains the replacement character, then the
            # token has an incomplete UTF-8 character, so we must use bytes.
            # See: https://github.com/vllm-project/vllm/pull/8640
            #      https://github.com/vllm-project/vllm/pull/9625
            # If the underlying tokenizer is sentencepiece, we just add "�".
            # We filtered unwanted special tokens so we can decode the rest.
            tokens = [
                self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
                if token_id not in self._special_token_ids_set
                else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
                for token_id in ids_kept
            ]

        return tokens

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_dict)

    @property
    def clean_up_tokenization_spaces(self):
        return False

    def hf_decode_incrementally(
        self,
        token_ids: list[int],
        prev_text: str | None = None,
        states: dict | None = None,
        *,
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool | None = None,
    ) -> tuple[str, dict]:
        raise NotImplementedError

    def apply_chat_template(
        self, conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]], *args, **kwargs
    ) -> Union[str, list[int], list[str], list[list[int]]]:
        raise NotImplementedError
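
A rough usage sketch for the adapter above (not part of the commit): the new module's import path is not shown on this page, so the import is omitted; `MistralTokenizer` below refers to the class defined above, and the model path comes from the `mistral_large_3_model_path` variable exported in the README.

```python
# Usage sketch: build the tokenizer from a checkpoint directory (or a standalone
# tokenizer file) and round-trip a prompt. Import of MistralTokenizer omitted
# because the module path is not shown on this page.
import os

model_dir = os.environ["mistral_large_3_model_path"]  # exported per the README
tokenizer = MistralTokenizer.from_pretrained(model_dir)

ids = tokenizer.encode("Hello from Mistral Large 3")   # mistral path: BOS added, no EOS
print(ids[:8])                                         # first few token ids
print(tokenizer.decode(ids))                           # skips special tokens by default
print(tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True)[:8])
print(tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.vocab_size)
```

Note that even with `skip_special_tokens=True`, `convert_ids_to_tokens` keeps the tool-call marker (and, for v13 instruct tokenizers, the think markers), which the structured-output backends mentioned in the code appear to rely on.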
