31 changes: 27 additions & 4 deletions convert_hf_to_gguf.py
@@ -89,13 +89,16 @@ class ModelBase:
block_count: int
tensor_map: gguf.TensorNameMap

+    # Mistral format specifics
is_mistral_format: bool = False
+    use_mistral_community_chat_template: bool = False

def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
-                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
+                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
+                 use_mistral_community_chat_template: bool = False):
if type(self) is ModelBase or \
type(self) is TextModel or \
type(self) is MmprojModel:
@@ -147,6 +150,9 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
+
+        # Mistral specific
+        self.use_mistral_community_chat_template = use_mistral_community_chat_template

@classmethod
def add_prefix_to_filename(cls, path: Path, prefix: str) -> Path:
stem, suffix = path.stem, path.suffix
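A minimal standalone sketch of the plumbing the two hunks above add (only the attribute and parameter names come from the diff; the stripped-down class is a stand-in for illustration, not the real converter class):

class ModelBase:
    # Class-level default, as in the diff: the community template is opt-in,
    # so this stays False unless the caller asks for it.
    use_mistral_community_chat_template: bool = False

    def __init__(self, *, use_mistral_community_chat_template: bool = False):
        # The instance attribute shadows the class default when the keyword
        # argument (ultimately the CLI flag) is passed.
        self.use_mistral_community_chat_template = use_mistral_community_chat_template

assert ModelBase().use_mistral_community_chat_template is False
assert ModelBase(use_mistral_community_chat_template=True).use_mistral_community_chat_template is True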
@@ -2011,8 +2017,17 @@ def _set_vocab_mistral(self):

template_dir = Path(__file__).parent / "models/templates/"

-        template = MistralModel.get_community_chat_template(vocab, template_dir)
-        self.gguf_writer.add_chat_template(template)
+        if not self.is_mistral_format or self.use_mistral_community_chat_template:
+            # For the Mistral format only, log that official tokenization and detokenization go through `mistral-common`.
+            if self.is_mistral_format:
+                logger.info(
+                    "Using a Mistral community chat template. These templates are subject to errors, especially in the first days or weeks after a release. "
+                    "The official way of using Mistral models is via `mistral-common`."
+                )
+            template = MistralModel.get_community_chat_template(vocab, template_dir)
+            self.gguf_writer.add_chat_template(template)
+        else:
+            logger.info("Not using a Mistral community chat template. Be sure to follow the official tokenization and detokenization process via `mistral-common`.")

def set_vocab(self):
if self.is_mistral_format:
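The branch added to _set_vocab_mistral reduces to a simple predicate: HF-format conversions always embed the community chat template, while Mistral-format conversions embed it only when explicitly opted in. A hedged sketch (the two booleans come from the diff; the helper name is invented for illustration):

def writes_community_chat_template(is_mistral_format: bool,
                                   use_mistral_community_chat_template: bool) -> bool:
    # Mirrors `if not self.is_mistral_format or self.use_mistral_community_chat_template:`
    return not is_mistral_format or use_mistral_community_chat_template

assert writes_community_chat_template(False, False)      # HF format: always written
assert writes_community_chat_template(False, True)       # HF format: always written
assert writes_community_chat_template(True, True)        # Mistral format, opted in
assert not writes_community_chat_template(True, False)   # Mistral format, default: skipped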
@@ -8638,6 +8653,13 @@ def parse_args() -> argparse.Namespace:
"--mistral-format", action="store_true",
help="Whether the model is stored following the Mistral format.",
)
+    parser.add_argument(
+        "--use-mistral-community-chat-template", action="store_true",
+        help=(
+            "Whether to store a Mistral community chat template in the GGUF file for the Mistral format. These templates are not official and may contain errors, "
+            "especially in the first days or weeks after a release. The official way to tokenize and detokenize Mistral models is via `mistral-common`."
+        )
+    )

args = parser.parse_args()
if not args.print_supported_models and args.model is None:
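Both options are plain `store_true` flags, so they default to False and flip only when passed on the command line. A small self-contained sketch of that parsing behaviour (the wiring is reduced to just the two flags from the diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--mistral-format", action="store_true")
parser.add_argument("--use-mistral-community-chat-template", action="store_true")

# Dashes become underscores on the parsed namespace.
args = parser.parse_args(["--mistral-format"])
assert args.mistral_format is True
assert args.use_mistral_community_chat_template is False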
@@ -8744,6 +8766,7 @@ def main() -> None:
fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")

is_mistral_format = args.mistral_format
+    use_mistral_community_chat_template = args.use_mistral_community_chat_template

with torch.inference_mode():
output_type = ftype_map[args.outtype]
@@ -8770,7 +8793,7 @@
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split,
-            remote_hf_model_id=hf_repo_id,
+            remote_hf_model_id=hf_repo_id, use_mistral_community_chat_template=use_mistral_community_chat_template
)

if args.vocab_only:
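Taken together, an opt-in conversion of a Mistral-format checkpoint would look something like `python convert_hf_to_gguf.py --mistral-format --use-mistral-community-chat-template /path/to/model` (the model path is illustrative). Without the second flag, the GGUF is written with no community chat template, and chat formatting is expected to go through `mistral-common`.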