@@ -1256,6 +1256,22 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("starchat")


+class MistralAdapter(BaseModelAdapter):
+    """The model adapter for Mistral AI models"""
+
+    def match(self, model_path: str):
+        return "mistral" in model_path.lower()
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
+        model.config.eos_token_id = tokenizer.eos_token_id
+        model.config.pad_token_id = tokenizer.pad_token_id
+        return model, tokenizer
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("mistral")
+
+
 class Llama2Adapter(BaseModelAdapter):
     """The model adapter for Llama-2 (e.g., meta-llama/Llama-2-7b-hf)"""

@@ -1653,6 +1669,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 register_model_adapter(InternLMChatAdapter)
 register_model_adapter(StarChatAdapter)
 register_model_adapter(Llama2Adapter)
+register_model_adapter(MistralAdapter)
 register_model_adapter(CuteGPTAdapter)
 register_model_adapter(OpenOrcaAdapter)
 register_model_adapter(WizardCoderAdapter)
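
A minimal usage sketch (not part of the diff) of how the new adapter is expected to be picked up, assuming FastChat's get_conversation_template helper and that a "mistral" conversation template is registered in conversation.py; the model path below is a hypothetical example.

from fastchat.model.model_adapter import get_conversation_template

# Hypothetical model path; MistralAdapter.match() fires on any path containing "mistral".
conv = get_conversation_template("mistralai/Mistral-7B-Instruct-v0.1")
print(conv.name)  # expected: "mistral"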