@@ -336,14 +336,12 @@ def __init__(
336336 if tokenizer is not None :
337337 self ._tokenizer = tokenizer
338338 else :
339- try :
340- self ._tokenizer = AutoTokenizer .from_pretrained (
341- model_path ,
342- config = config ,
343- use_fast = self .use_fast ,
344- trust_remote_code = trust_remote_code )
345- except ValueError :
346- self ._tokenizer = MistralTokenizer .from_pretrained (model_path )
339+ self ._tokenizer = AutoTokenizer .from_pretrained (
340+ model_path ,
341+ config = config ,
342+ use_fast = self .use_fast ,
343+ trust_remote_code = trust_remote_code )
344+
347345 self ._model_path = model_path
348346 if isinstance (self ._tokenizer , MistralTokenizer ):
349347 self ._processor = MistralCommonImageProcessor (
@@ -353,6 +351,9 @@ def __init__(
353351 model_path ,
354352 use_fast = self .use_fast ,
355353 trust_remote_code = trust_remote_code )
354+
355+ logger .debug (f"Mistral3InputProcessor: using { type (self ._processor )} preprocessor" )
356+ logger .debug (f"Mistral3InputProcessor: using { type (self ._tokenizer )} tokenizer" )
356357
357358 @property
358359 def config (self ) -> PretrainedConfig :
@@ -443,6 +444,37 @@ def get_mm_special_token_ids(self) -> torch.Tensor:
443444 self .processor .image_end_token_id ,
444445 ])
445446
class MistralCommonInputProcessor(Mistral3InputProcessor):
    """Input processor that prefers the mistral-common tokenizer.

    Attempts to load a ``MistralTokenizer`` from ``model_path`` and falls
    back to HuggingFace's ``AutoTokenizer`` when that fails.
    """

    def __init__(
        self,
        model_path: str,
        config: PretrainedConfig,
        tokenizer: Optional[AutoTokenizer],
        trust_remote_code: bool = False,
        **kwargs,
    ):
        """Build the processor, loading a tokenizer only when none is given.

        Fixes vs. the original: a caller-supplied ``tokenizer`` is now
        honored instead of being unconditionally overwritten, and the
        caller's ``trust_remote_code`` flag is forwarded to the fallback
        ``AutoTokenizer`` load instead of being silently hardcoded to True.
        """
        if tokenizer is None:
            tokenizer = self.load_tokenizer(
                model_path,
                config=config,
                trust_remote_code=trust_remote_code)
        super().__init__(model_path=model_path,
                         config=config,
                         tokenizer=tokenizer,
                         **kwargs)

    @staticmethod
    def load_tokenizer(model_path: str,
                       config: PretrainedConfig,
                       checkpoint_format: Optional[str] = "mistral_large_3",
                       trust_remote_code: bool = True):
        """Load a tokenizer for ``model_path``.

        For the ``mistral_large_3`` checkpoint format, the mistral-common
        tokenizer is tried first; on ``ValueError`` (not a mistral-common
        checkpoint) we fall back to ``AutoTokenizer``.

        ``trust_remote_code`` defaults to True to preserve the previous
        hardcoded behavior for existing direct callers; ``__init__`` now
        passes its own flag through explicitly.
        """
        if checkpoint_format == "mistral_large_3":
            try:
                return MistralTokenizer.from_pretrained(model_path)
            except ValueError:
                # Lazy %-args: message is only formatted if the record is emitted.
                logger.info(
                    "Could not load mistral-common tokenizer from %s, "
                    "falling back to HuggingFace", model_path)

        return AutoTokenizer.from_pretrained(
            model_path,
            config=config,
            use_fast=True,
            trust_remote_code=trust_remote_code)
446478
447479class Mistral3Gate (nn .Module ):
448480
@@ -478,26 +510,27 @@ def load_weights(self, weights: List[Dict]):
478510@register_auto_model ("Mistral3ForConditionalGeneration" )
479511@register_auto_model ("PixtralForConditionalGeneration" )
480512@register_input_processor (
481- Mistral3InputProcessor ,
482- model_type = "mistral3_hf " ,
513+ MistralCommonInputProcessor ,
514+ model_type = "mistral3 " ,
483515 placeholder_metadata = MultimodalPlaceholderMetadata (
484516 placeholder_map = {
517+ # NOTE: mistral-common uses the tokenizer to set placeholders, this will be ignored
485518 "image" : "[IMG]" ,
486519 },
487- # NOTE: for mistral3 multimodal models, it does not strictly have to be before the text.
488- # Ref: https://github.com/mistralai/mistral-common/blob/039465db2bdc0486df36365c9bdb428188482a18/
489- # src/mistral_common/tokens/tokenizers/base.py#L326
490- # However, accuracy tests show that the model generates higher quality output when the image
491- # precedes the text (the relative difference can be as much as ~30% for both vLLM and TRT-LLM).
492520 placeholder_placement = MultimodalPlaceholderPlacement .BEFORE_TEXT ,
493521 ))
494522@register_input_processor (
495523 Mistral3InputProcessor ,
496- model_type = "mistral3 " ,
524+ model_type = "mistral3_hf " ,
497525 placeholder_metadata = MultimodalPlaceholderMetadata (
498526 placeholder_map = {
499527 "image" : "[IMG]" ,
500528 },
529+ # NOTE: for mistral3 multimodal models, it does not strictly have to be before the text.
530+ # Ref: https://github.com/mistralai/mistral-common/blob/039465db2bdc0486df36365c9bdb428188482a18/
531+ # src/mistral_common/tokens/tokenizers/base.py#L326
532+ # However, accuracy tests show that the model generates higher quality output when the image
533+ # precedes the text (the relative difference can be as much as ~30% for both vLLM and TRT-LLM).
501534 placeholder_placement = MultimodalPlaceholderPlacement .BEFORE_TEXT ,
502535 ))
503536class Mistral3VLM (PreTrainedModel ):
0 commit comments