Skip to content

Commit f5eee7d

Browse files
authored
Add Mistral AI instruction template (#2483)
1 parent 46e5207 commit f5eee7d

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

fastchat/conversation.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -840,6 +840,19 @@ def get_conv_template(name: str) -> Conversation:
840840
)
841841
)
842842

843+
# Mistral template
844+
# source: https://docs.mistral.ai/llm/mistral-instruct-v0.1#chat-template
845+
register_conv_template(
    Conversation(
        # Template name; the Mistral model adapter requests it via
        # get_conv_template("mistral").
        name="mistral",
        # Prompt assembly follows the LLAMA2 separator style.
        sep_style=SeparatorStyle.LLAMA2,
        # Instruction open/close tags serve as the two role markers.
        roles=("[INST] ", " [/INST]"),
        # No system-prompt wrapper is configured for this template.
        system_template="",
        # Two separators: an empty one and " </s>".
        # NOTE(review): presumably sep follows user turns and sep2 follows
        # assistant turns -- confirm against the LLAMA2 prompt builder.
        sep="",
        sep2=" </s>",
    )
)
855+
843856
# llama2 template
844857
# reference: https://huggingface.co/blog/codellama#conversational-instructions
845858
# reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212

fastchat/model/model_adapter.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1256,6 +1256,22 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
12561256
return get_conv_template("starchat")
12571257

12581258

1259+
class MistralAdapter(BaseModelAdapter):
    """Model adapter used for Mistral AI checkpoints."""

    def match(self, model_path: str):
        """Return True when ``model_path`` contains "mistral" (case-insensitive)."""
        lowered_path = model_path.lower()
        return "mistral" in lowered_path

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        """Load via the base adapter, then copy the tokenizer's special ids.

        The model config's ``eos_token_id`` and ``pad_token_id`` are
        overwritten with the tokenizer's values before the
        ``(model, tokenizer)`` pair is returned.
        """
        model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
        # Mirror the tokenizer's ids onto the model config, EOS first, then PAD.
        for token_attr in ("eos_token_id", "pad_token_id"):
            setattr(model.config, token_attr, getattr(tokenizer, token_attr))
        return model, tokenizer

    def get_default_conv_template(self, model_path: str) -> Conversation:
        """Return the conversation template registered as "mistral"."""
        template_name = "mistral"
        return get_conv_template(template_name)
1273+
1274+
12591275
class Llama2Adapter(BaseModelAdapter):
12601276
"""The model adapter for Llama-2 (e.g., meta-llama/Llama-2-7b-hf)"""
12611277

@@ -1653,6 +1669,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
16531669
register_model_adapter(InternLMChatAdapter)
16541670
register_model_adapter(StarChatAdapter)
16551671
register_model_adapter(Llama2Adapter)
1672+
register_model_adapter(MistralAdapter)
16561673
register_model_adapter(CuteGPTAdapter)
16571674
register_model_adapter(OpenOrcaAdapter)
16581675
register_model_adapter(WizardCoderAdapter)

0 commit comments

Comments
 (0)