Skip to content

Commit b2caf67

Browse files
authored
convert : make Mistral community chat templates optional via parameter (ggml-org#15420)
* Make Mistral community chat templates optional * Change the flag arg to disable instead of enable community chat templates * Improve error message * Improve help message * Tone down the logger messages
1 parent 2f3dbff commit b2caf67

File tree

1 file changed

+35
-6
lines changed

1 file changed

+35
-6
lines changed

convert_hf_to_gguf.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,16 @@ class ModelBase:
8989
block_count: int
9090
tensor_map: gguf.TensorNameMap
9191

92+
# Mistral format specifics
9293
is_mistral_format: bool = False
94+
disable_mistral_community_chat_template: bool = False
9395

9496
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
9597
use_temp_file: bool = False, eager: bool = False,
9698
metadata_override: Path | None = None, model_name: str | None = None,
9799
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
98-
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
100+
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
101+
disable_mistral_community_chat_template: bool = False):
99102
if type(self) is ModelBase or \
100103
type(self) is TextModel or \
101104
type(self) is MmprojModel:
@@ -147,6 +150,9 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
147150
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
148151
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
149152

153+
# Mistral specific
154+
self.disable_mistral_community_chat_template = disable_mistral_community_chat_template
155+
150156
@classmethod
151157
def add_prefix_to_filename(cls, path: Path, prefix: str) -> Path:
152158
stem, suffix = path.stem, path.suffix
@@ -2011,8 +2017,17 @@ def _set_vocab_mistral(self):
20112017

20122018
template_dir = Path(__file__).parent / "models/templates/"
20132019

2014-
template = MistralModel.get_community_chat_template(vocab, template_dir)
2015-
self.gguf_writer.add_chat_template(template)
2020+
if not self.is_mistral_format or not self.disable_mistral_community_chat_template:
2021+
# Log only for Mistral format that the official tokenization and detokenization is via `mistral-common`.
2022+
if self.is_mistral_format:
2023+
logger.info(
2024+
"Using a Mistral community chat template. These templates can be subject to errors in early days or weeks after a release. "
2025+
"Mistral recommends to use `mistral-common` to perform tokenization and detokenization."
2026+
)
2027+
template = MistralModel.get_community_chat_template(vocab, template_dir, self.is_mistral_format)
2028+
self.gguf_writer.add_chat_template(template)
2029+
else:
2030+
logger.info("Not using a Mistral community chat template. Ensure to perform the tokenization and detokenization via `mistral-common`.")
20162031

20172032
def set_vocab(self):
20182033
if self.is_mistral_format:
@@ -8422,7 +8437,7 @@ class MistralModel(LlamaModel):
84228437
undo_permute = False
84238438

84248439
@staticmethod
8425-
def get_community_chat_template(vocab: MistralVocab, templates_dir: Path):
8440+
def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool):
84268441
assert TokenizerVersion is not None, "mistral_common is not installed"
84278442
assert isinstance(vocab.tokenizer, (Tekkenizer, SentencePieceTokenizer)), (
84288443
f"Expected Tekkenizer or SentencePieceTokenizer, got {type(vocab.tokenizer)}"
@@ -8443,7 +8458,13 @@ def get_community_chat_template(vocab: MistralVocab, templates_dir: Path):
84438458
elif vocab.tokenizer.version == TokenizerVersion.v13:
84448459
template_file = "unsloth-mistral-Devstral-Small-2507.jinja"
84458460
else:
8446-
raise ValueError(f"Unknown tokenizer type: {vocab.tokenizer_type} and version {vocab.tokenizer.version}")
8461+
err_message = f"Unknown tokenizer type: {vocab.tokenizer_type} and version {vocab.tokenizer.version}"
8462+
if is_mistral_format:
8463+
err_message += (
8464+
" . Please pass --disable-mistral-community-chat-template argument to the CLI "
8465+
"if you want to skip this error and use the Mistral official `mistral-common` pre-processing library."
8466+
)
8467+
raise ValueError(err_message)
84478468

84488469
template_path = templates_dir / template_file
84498470
if not template_path.exists():
@@ -8638,6 +8659,13 @@ def parse_args() -> argparse.Namespace:
86388659
"--mistral-format", action="store_true",
86398660
help="Whether the model is stored following the Mistral format.",
86408661
)
8662+
parser.add_argument(
8663+
"--disable-mistral-community-chat-template", action="store_true",
8664+
help=(
8665+
"Whether to disable usage of Mistral community chat templates. If set, use the Mistral official `mistral-common` library for tokenization and detokenization of Mistral models. "
8666+
"Using `mistral-common` ensure correctness and zero-day support of tokenization for models converted from the Mistral format but requires to manually setup the tokenization server."
8667+
)
8668+
)
86418669

86428670
args = parser.parse_args()
86438671
if not args.print_supported_models and args.model is None:
@@ -8744,6 +8772,7 @@ def main() -> None:
87448772
fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
87458773

87468774
is_mistral_format = args.mistral_format
8775+
disable_mistral_community_chat_template = args.disable_mistral_community_chat_template
87478776

87488777
with torch.inference_mode():
87498778
output_type = ftype_map[args.outtype]
@@ -8770,7 +8799,7 @@ def main() -> None:
87708799
split_max_tensors=args.split_max_tensors,
87718800
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
87728801
small_first_shard=args.no_tensor_first_split,
8773-
remote_hf_model_id=hf_repo_id,
8802+
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
87748803
)
87758804

87768805
if args.vocab_only:

0 commit comments

Comments
 (0)