Skip to content

Commit df98854

Browse files
committed
refactor and docs
Signed-off-by: Olya Kozlova <okozlova@nvidia.com>
1 parent 506a8fe commit df98854

File tree

4 files changed

+77
-29
lines changed

4 files changed

+77
-29
lines changed

examples/llm-api/quickstart_multimodal.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,14 +272,17 @@ def main():
272272
args.prompt = example_medias_and_prompts[args.modality]["prompt"]
273273
if args.media is None:
274274
args.media = example_medias_and_prompts[args.modality]["media"]
275+
276+
#FIXME WAR for mistral-common processors
277+
keep_source_media=(args.model_type=="mistral3" and args.checkpoint_format == "mistral_large_3")
275278
inputs = default_multimodal_input_loader(
276279
tokenizer=llm.tokenizer,
277280
model_dir=str(llm._hf_model_dir),
278281
model_type=model_type,
279282
modality=args.modality,
280283
prompts=args.prompt,
281284
media=args.media,
282-
keep_source_media=(args.checkpoint_format == "mistral_large_3"),
285+
keep_source_media=keep_source_media,
283286
processor=getattr(llm, "input_processor", None),
284287
image_data_format=image_format,
285288
num_frames=args.num_frames,

examples/models/core/mistral_large_3/README.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,22 @@ export mistral_large_3_model_path=<mistral_large_3_model_path>
77
export mistral_large_3_eagle_model_path=<mistral_large_3_eagle_model_path>
88
```
99

10+
## Multimodal run
11+
12+
* Run Mistral Large V3 with `quickstart_multimodal.py`
13+
14+
```bash
15+
mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickstart_multimodal.py \
16+
--model_dir ${mistral_large_3_model_path} \
17+
--tp_size 4 \
18+
--moe_ep_size 4 \
19+
--max_tokens 100 \
20+
--checkpoint_format mistral_large_3 \
21+
--model_type mistral3 \
22+
--kv_cache_fraction 0.25 \
23+
--moe_backend TRTLLM # optional
24+
```
25+
1026
## LLM-only run
1127

1228
* Run Mistral Large V3 with `quickstart_advanced.py`

tensorrt_llm/_torch/models/modeling_mistral.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -336,14 +336,12 @@ def __init__(
336336
if tokenizer is not None:
337337
self._tokenizer = tokenizer
338338
else:
339-
try:
340-
self._tokenizer = AutoTokenizer.from_pretrained(
341-
model_path,
342-
config=config,
343-
use_fast=self.use_fast,
344-
trust_remote_code=trust_remote_code)
345-
except ValueError:
346-
self._tokenizer = MistralTokenizer.from_pretrained(model_path)
339+
self._tokenizer = AutoTokenizer.from_pretrained(
340+
model_path,
341+
config=config,
342+
use_fast=self.use_fast,
343+
trust_remote_code=trust_remote_code)
344+
347345
self._model_path = model_path
348346
if isinstance(self._tokenizer, MistralTokenizer):
349347
self._processor = MistralCommonImageProcessor(
@@ -353,6 +351,9 @@ def __init__(
353351
model_path,
354352
use_fast=self.use_fast,
355353
trust_remote_code=trust_remote_code)
354+
355+
logger.debug(f"Mistral3InputProcessor: using {type(self._processor)} preprocessor")
356+
logger.debug(f"Mistral3InputProcessor: using {type(self._tokenizer)} tokenizer")
356357

357358
@property
358359
def config(self) -> PretrainedConfig:
@@ -443,6 +444,37 @@ def get_mm_special_token_ids(self) -> torch.Tensor:
443444
self.processor.image_end_token_id,
444445
])
445446

447+
class MistralCommonInputProcessor(Mistral3InputProcessor):
448+
def __init__(
449+
self,
450+
model_path: str,
451+
config: PretrainedConfig,
452+
tokenizer: Optional[AutoTokenizer],
453+
trust_remote_code: bool = False,
454+
**kwargs,
455+
):
456+
tokenizer = self.load_tokenizer(model_path, config=config)
457+
super().__init__(model_path=model_path,
458+
config=config,
459+
tokenizer=tokenizer,
460+
**kwargs)
461+
462+
@staticmethod
463+
def load_tokenizer(model_path: str, config: PretrainedConfig, checkpoint_format: Optional[str] = "mistral_large_3"):
464+
if checkpoint_format == "mistral_large_3":
465+
try:
466+
return MistralTokenizer.from_pretrained(model_path)
467+
468+
except ValueError:
469+
logger.info(f"Could not load mistral-common tokenizer from {model_path}, falling back to HuggingFace")
470+
471+
tokenizer = AutoTokenizer.from_pretrained(
472+
model_path,
473+
config=config,
474+
use_fast=True,
475+
trust_remote_code=True)
476+
return tokenizer
477+
446478

447479
class Mistral3Gate(nn.Module):
448480

@@ -478,26 +510,27 @@ def load_weights(self, weights: List[Dict]):
478510
@register_auto_model("Mistral3ForConditionalGeneration")
479511
@register_auto_model("PixtralForConditionalGeneration")
480512
@register_input_processor(
481-
Mistral3InputProcessor,
482-
model_type="mistral3_hf",
513+
MistralCommonInputProcessor,
514+
model_type="mistral3",
483515
placeholder_metadata=MultimodalPlaceholderMetadata(
484516
placeholder_map={
517+
# NOTE: mistral-common uses the tokenizer to set placeholders, this will be ignored
485518
"image": "[IMG]",
486519
},
487-
# NOTE: for mistral3 multimodal models, it does not strictly have to be before the text.
488-
# Ref: https://github.com/mistralai/mistral-common/blob/039465db2bdc0486df36365c9bdb428188482a18/
489-
# src/mistral_common/tokens/tokenizers/base.py#L326
490-
# However, accuracy tests show that the model generates higher quality output when the image
491-
# precedes the text (the relative difference can be as much as ~30% for both vLLM and TRT-LLM).
492520
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
493521
))
494522
@register_input_processor(
495523
Mistral3InputProcessor,
496-
model_type="mistral3",
524+
model_type="mistral3_hf",
497525
placeholder_metadata=MultimodalPlaceholderMetadata(
498526
placeholder_map={
499527
"image": "[IMG]",
500528
},
529+
# NOTE: for mistral3 multimodal models, it does not strictly have to be before the text.
530+
# Ref: https://github.com/mistralai/mistral-common/blob/039465db2bdc0486df36365c9bdb428188482a18/
531+
# src/mistral_common/tokens/tokenizers/base.py#L326
532+
# However, accuracy tests show that the model generates higher quality output when the image
533+
# precedes the text (the relative difference can be as much as ~30% for both vLLM and TRT-LLM).
501534
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
502535
))
503536
class Mistral3VLM(PreTrainedModel):

tensorrt_llm/inputs/registry.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
import torch
1010
from PIL import Image
1111
from torch import Tensor, nn
12-
from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
13-
PreTrainedTokenizerBase)
12+
from transformers import (AutoProcessor, PretrainedConfig, PreTrainedTokenizerBase)
1413

1514
import tensorrt_llm
1615

@@ -595,16 +594,13 @@ def create_input_processor(
595594

596595
# FIXME support both HF and mistral-common paths in a better way
597596
if tokenizer is None:
598-
try:
599-
tokenizer = AutoTokenizer.from_pretrained(
600-
model_path_or_dir,
601-
config=config,
602-
use_fast=True,
603-
trust_remote_code=True)
604-
605-
except ValueError:
606-
from tensorrt_llm.llmapi.tokenizer import MistralTokenizer
607-
tokenizer = MistralTokenizer.from_pretrained(model_path_or_dir)
597+
from tensorrt_llm._torch.models.modeling_mistral import \
598+
MistralCommonInputProcessor
599+
tokenizer = MistralCommonInputProcessor.load_tokenizer(
600+
model_path_or_dir, config=None)
601+
602+
print(f"loaded tokenizer: {type(tokenizer)}")
603+
608604
else:
609605
logger.debug(
610606
f"checkpoint_format={checkpoint_format}; skipping HF config load.")

0 commit comments

Comments
 (0)