@@ -88,6 +88,7 @@ class ModelType:
8888 internlm_7b_chat_8k = 'internlm-7b-chat-8k'
8989 internlm_20b = 'internlm-20b'
9090 internlm_20b_chat = 'internlm-20b-chat'
91+ # internlm2
9192 internlm2_7b_base = 'internlm2-7b-base'
9293 internlm2_7b = 'internlm2-7b'
9394 internlm2_7b_sft_chat = 'internlm2-7b-sft-chat'
@@ -96,6 +97,11 @@ class ModelType:
9697 internlm2_20b = 'internlm2-20b'
9798 internlm2_20b_sft_chat = 'internlm2-20b-sft-chat'
9899 internlm2_20b_chat = 'internlm2-20b-chat'
100+ # internlm2-math
101+ internlm2_math_7b_chat = 'internlm2-math-7b-chat'
102+ internlm2_math_7b = 'internlm2-math-7b'
103+ internlm2_math_20b_chat = 'internlm2-math-20b-chat'
104+ internlm2_math_20b = 'internlm2-math-20b'
99105 # deepseek
100106 deepseek_7b = 'deepseek-7b'
101107 deepseek_7b_chat = 'deepseek-7b-chat'
@@ -120,6 +126,7 @@ class ModelType:
120126 baichuan_7b = 'baichuan-7b'
121127 baichuan_13b = 'baichuan-13b'
122128 baichuan_13b_chat = 'baichuan-13b-chat'
129+ # baichuan2
123130 baichuan2_7b = 'baichuan2-7b'
124131 baichuan2_7b_chat = 'baichuan2-7b-chat'
125132 baichuan2_7b_chat_int4 = 'baichuan2-7b-chat-int4'
@@ -911,6 +918,32 @@ def get_model_tokenizer_with_flash_attn(model_dir: str,
911918 load_model , model_config , ** kwargs )
912919
913920
921+ @register_model (
922+ ModelType .internlm2_math_7b ,
923+ 'Shanghai_AI_Laboratory/internlm2-math-base-7b' ,
924+ LoRATM .internlm2 ,
925+ TemplateType .default_generation_bos ,
926+ support_flash_attn = True )
927+ @register_model (
928+ ModelType .internlm2_math_20b ,
929+ 'Shanghai_AI_Laboratory/internlm2-math-base-20b' ,
930+ LoRATM .internlm2 ,
931+ TemplateType .default_generation_bos ,
932+ support_flash_attn = True )
933+ @register_model (
934+ ModelType .internlm2_math_7b_chat ,
935+ 'Shanghai_AI_Laboratory/internlm2-math-7b' ,
936+ LoRATM .internlm2 ,
937+ TemplateType .internlm2 ,
938+ eos_token = '<|im_end|>' ,
939+ support_flash_attn = True )
940+ @register_model (
941+ ModelType .internlm2_math_20b_chat ,
942+ 'Shanghai_AI_Laboratory/internlm2-math-20b' ,
943+ LoRATM .internlm2 ,
944+ TemplateType .internlm2 ,
945+ eos_token = '<|im_end|>' ,
946+ support_flash_attn = True )
914947@register_model (
915948 ModelType .internlm2_7b_sft_chat ,
916949 'Shanghai_AI_Laboratory/internlm2-chat-7b-sft' ,
@@ -986,9 +1019,7 @@ def get_model_tokenizer_internlm2(model_dir: str,
9861019 if getattr (tokenizer .__class__ .eos_token_id , 'fset' , None ) is None :
9871020 del tokenizer .__class__ .eos_token_id
9881021 tokenizer .eos_token = eos_token
989- if model is not None and use_flash_attn :
990- # fix AttributeError: no attribute 'attention_dropout'
991- model .model .layers [0 ].attention .__class__ .attention_dropout = 0.
1022+
9921023 return model , tokenizer
9931024
9941025
0 commit comments