@@ -89,13 +89,16 @@ class ModelBase:
     block_count: int
     tensor_map: gguf.TensorNameMap

+    # Mistral format specifics
     is_mistral_format: bool = False
+    disable_mistral_community_chat_template: bool = False

     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
-                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
+                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
+                 disable_mistral_community_chat_template: bool = False):
         if type(self) is ModelBase or \
             type(self) is TextModel or \
             type(self) is MmprojModel:
@@ -147,6 +150,9 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)

+        # Mistral specific
+        self.disable_mistral_community_chat_template = disable_mistral_community_chat_template
+
     @classmethod
     def add_prefix_to_filename(cls, path: Path, prefix: str) -> Path:
         stem, suffix = path.stem, path.suffix
@@ -2011,8 +2017,17 @@ def _set_vocab_mistral(self):

         template_dir = Path(__file__).parent / "models/templates/"

-        template = MistralModel.get_community_chat_template(vocab, template_dir)
-        self.gguf_writer.add_chat_template(template)
+        if not self.is_mistral_format or not self.disable_mistral_community_chat_template:
+            # For the Mistral format only, log that official tokenization and detokenization go through `mistral-common`.
+            if self.is_mistral_format:
+                logger.info(
+                    "Using a Mistral community chat template. These templates can be subject to errors in the early days or weeks after a release. "
+                    "Mistral recommends using `mistral-common` for tokenization and detokenization."
+                )
+            template = MistralModel.get_community_chat_template(vocab, template_dir, self.is_mistral_format)
+            self.gguf_writer.add_chat_template(template)
+        else:
+            logger.info("Not using a Mistral community chat template. Make sure to perform tokenization and detokenization via `mistral-common`.")

     def set_vocab(self):
         if self.is_mistral_format:
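
The double negation in the new guard is easy to misread: the opt-out only has an effect for Mistral-format models, while HF-format conversions always embed the template. A minimal standalone sketch of the same gating logic (the helper name is ours, not part of the patch):

```python
def should_embed_community_template(is_mistral_format: bool, disable_flag: bool) -> bool:
    # Mirrors `not self.is_mistral_format or not self.disable_mistral_community_chat_template`:
    # HF-format models always embed the template; Mistral-format models embed it
    # unless the user opted out via --disable-mistral-community-chat-template.
    return not is_mistral_format or not disable_flag

assert should_embed_community_template(False, True)      # HF format: flag has no effect
assert should_embed_community_template(True, False)      # Mistral format, default: embed
assert not should_embed_community_template(True, True)   # Mistral format, opted out: skip
```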
@@ -8422,7 +8437,7 @@ class MistralModel(LlamaModel):
     undo_permute = False

     @staticmethod
-    def get_community_chat_template(vocab: MistralVocab, templates_dir: Path):
+    def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool):
         assert TokenizerVersion is not None, "mistral_common is not installed"
         assert isinstance(vocab.tokenizer, (Tekkenizer, SentencePieceTokenizer)), (
             f"Expected Tekkenizer or SentencePieceTokenizer, got {type(vocab.tokenizer)}"
@@ -8443,7 +8458,13 @@ def get_community_chat_template(vocab: MistralVocab, templates_dir: Path):
         elif vocab.tokenizer.version == TokenizerVersion.v13:
             template_file = "unsloth-mistral-Devstral-Small-2507.jinja"
         else:
-            raise ValueError(f"Unknown tokenizer type: {vocab.tokenizer_type} and version {vocab.tokenizer.version}")
+            err_message = f"Unknown tokenizer type: {vocab.tokenizer_type} and version {vocab.tokenizer.version}"
+            if is_mistral_format:
+                err_message += (
+                    ". Please pass the --disable-mistral-community-chat-template argument to the CLI "
+                    "if you want to skip this error and use Mistral's official `mistral-common` pre-processing library."
+                )
+            raise ValueError(err_message)

         template_path = templates_dir / template_file
         if not template_path.exists():
@@ -8638,6 +8659,13 @@ def parse_args() -> argparse.Namespace:
         "--mistral-format", action="store_true",
         help="Whether the model is stored following the Mistral format.",
     )
+    parser.add_argument(
+        "--disable-mistral-community-chat-template", action="store_true",
+        help=(
+            "Whether to disable the use of Mistral community chat templates. If set, use Mistral's official `mistral-common` library for tokenization and detokenization of Mistral models. "
+            "Using `mistral-common` ensures correctness and zero-day tokenization support for models converted from the Mistral format, but requires manually setting up a tokenization server."
+        )
+    )

     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:
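
The new help text points users at `mistral-common` when the community template is disabled. For reference, a minimal tokenization sketch with that library, assuming a v3 Tekken tokenizer (the message content is a placeholder; models converted from the Mistral format may need a different tokenizer version):

```python
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

# Build the official tokenizer; v3 with Tekken matches recent Mistral releases.
tokenizer = MistralTokenizer.v3(is_tekken=True)

# Encode a chat request directly instead of applying a Jinja chat template.
tokenized = tokenizer.encode_chat_completion(
    ChatCompletionRequest(messages=[UserMessage(content="Hello, world!")])
)
print(tokenized.tokens)  # token ids to feed the model
print(tokenizer.decode(tokenized.tokens))
```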
@@ -8744,6 +8772,7 @@ def main() -> None:
         fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")

     is_mistral_format = args.mistral_format
+    disable_mistral_community_chat_template = args.disable_mistral_community_chat_template

     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
@@ -8770,7 +8799,7 @@ def main() -> None:
                 split_max_tensors=args.split_max_tensors,
                 split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
                 small_first_shard=args.no_tensor_first_split,
-                remote_hf_model_id=hf_repo_id,
+                remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
             )

     if args.vocab_only:
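
To see the flag's effect on the output file, one can check whether the `tokenizer.chat_template` key made it into the GGUF. A small verification sketch using the `gguf` Python package that ships with llama.cpp (the file path is a placeholder):

```python
from gguf import GGUFReader

reader = GGUFReader("Mistral-Small-Q8_0.gguf")  # placeholder path
field = reader.get_field("tokenizer.chat_template")
if field is None:
    print("No chat template embedded; tokenize via mistral-common.")
else:
    print("Community chat template embedded in the GGUF.")
```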