@@ -426,6 +426,8 @@ std::string to_string(ModelType model_type) {
         return "Baichuan7B";
     case MODEL_TYPE_BAICHUAN13B:
         return "Baichuan13B";
+    case MODEL_TYPE_INTERNLM:
+        return "InternLM";
     default:
         CHATGLM_THROW << "unknown model type " << model_type;
     }
@@ -1165,6 +1167,174 @@ void Baichuan13BForCausalLM::load(ModelLoader &loader) {
     ctx_.init_device_context();
 }
 
+// ===== InternLM =====
+
+InternLMTokenizer::InternLMTokenizer(std::string_view serialized_model_proto) {
+    const auto status = sp.LoadFromSerializedProto(serialized_model_proto);
+    CHATGLM_CHECK(status.ok()) << status.ToString();
+}
+
+std::vector<int> InternLMTokenizer::encode(const std::string &text, int max_length) const {
+    std::vector<int> ids;
+    sp.Encode(text, &ids);
+    ids.insert(ids.begin(), {bos_token_id}); // special prefix
+    if ((int)ids.size() > max_length) {
+        // sliding window: drop the oldest history tokens while keeping the special prefix
+        int num_drop = (int)ids.size() - max_length;
+        ids.erase(ids.begin() + 1, ids.begin() + 1 + num_drop);
+    }
+    return ids;
+}
+
+std::string InternLMTokenizer::decode(const std::vector<int> &ids) const {
+    // filter out special tokens
+    std::vector<int> normal_ids(ids);
+    normal_ids.erase(std::remove_if(normal_ids.begin(), normal_ids.end(), [this](int id) { return is_special_id(id); }),
+                     normal_ids.end());
+
+    std::string text;
+    sp.Decode(normal_ids, &text);
+    // remove <eoa> and everything after it
+    size_t eoa_pos = text.find("<eoa>");
+    if (eoa_pos != std::string::npos) {
+        text.erase(eoa_pos);
+    }
+    return text;
+}
+
+std::vector<int> InternLMTokenizer::encode_history(const std::vector<std::string> &history, int max_length) const {
+    std::string prompt = build_prompt(history);
+    std::vector<int> input_ids = encode(prompt, max_length);
+    return input_ids;
+}
+
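+// Worked example of the chat template assembled by build_prompt below: for
+// history {"hi", "Hello!", "bye"} the resulting prompt is
+//   "<|User|>:hi<eoh>\n<|Bot|>:Hello!<eoa>\n<|User|>:bye<eoh>\n<|Bot|>:"
+// i.e. each user turn is wrapped in <|User|>: ... <eoh>, each finished bot turn in
+// <|Bot|>: ... <eoa>, and the trailing "<|Bot|>:" prompts the next reply.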
+std::string InternLMTokenizer::build_prompt(const std::vector<std::string> &history) {
+    CHATGLM_CHECK(history.size() % 2 == 1) << "invalid history size " << history.size();
+
+    std::ostringstream oss_prompt;
+    for (size_t i = 0; i < history.size(); i += 2) {
+        oss_prompt << "<|User|>:" << history[i] << "<eoh>\n<|Bot|>:";
+        if (i < history.size() - 1) {
+            oss_prompt << history[i + 1] << "<eoa>\n";
+        }
+    }
+    return oss_prompt.str();
+}
+
+InternLM7BForCausalLM::InternLM7BForCausalLM(const ModelConfig &config)
+    : BasicModelForCausalLM(MODEL_TYPE_INTERNLM, config, MEM_SIZE, SCRATCH_SIZE) {
+    constexpr size_t tensor_ovhd = GGML_TENSOR_SIZE + GGML_OBJECT_SIZE;
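+    // 3 shared tensors (embed_tokens, final norm, lm_head) plus 9 per layer; the 7B
+    // attention projections carry qkv/o_proj biases, hence 9 rather than 7 entries per layer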
+    const size_t num_weights = 3 + config.num_hidden_layers * 9;
+    const size_t ctx_w_size = num_weights * tensor_ovhd;
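+    // KV cache: one F16 key tensor and one F16 value tensor of [max_length, hidden_size]
+    // per layer, each with ggml tensor overhead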
+    const size_t ctx_kv_size = 2 * config.num_hidden_layers *
+                               (config.max_length * config.hidden_size * ggml_type_size(GGML_TYPE_F16) + tensor_ovhd);
+    ctx_.dtype = config.dtype;
+    ctx_.ctx_w = make_unique_ggml_context(ctx_w_size, nullptr, true);
+    ctx_.ctx_kv = make_unique_ggml_context(ctx_kv_size + 1 * MB, nullptr, false); // 1MB extra for MPS
+
+    transformer = InternLM7BModel(&ctx_, config);
+    lm_head = Linear(&ctx_, config.hidden_size, config.vocab_size, false);
+    CHATGLM_CHECK(ggml_used_mem(ctx_.ctx_w.get()) == ggml_get_mem_size(ctx_.ctx_w.get())) << "corrupted model weights";
+    CHATGLM_CHECK(ggml_used_mem(ctx_.ctx_kv.get()) == ctx_kv_size) << "corrupted kv cache";
+
+    // build state_dict
+    state_dict_.reserve(num_weights);
+    state_dict_.emplace_back("model.embed_tokens.weight", transformer.word_embeddings.weight);
+    for (int i = 0; i < config.num_hidden_layers; i++) {
+        std::string layer_prefix = "model.layers." + std::to_string(i) + '.';
+        state_dict_.emplace_back(layer_prefix + "input_layernorm.weight", transformer.layers[i].input_layernorm.weight);
+        state_dict_.emplace_back(layer_prefix + "self_attn.qkv_proj.weight",
+                                 transformer.layers[i].attention.query_key_value.weight);
+        if (transformer.layers[i].attention.query_key_value.bias) {
+            state_dict_.emplace_back(layer_prefix + "self_attn.qkv_proj.bias",
+                                     transformer.layers[i].attention.query_key_value.bias);
+        }
+        state_dict_.emplace_back(layer_prefix + "self_attn.o_proj.weight",
+                                 transformer.layers[i].attention.dense.weight);
+        if (transformer.layers[i].attention.dense.bias) {
+            state_dict_.emplace_back(layer_prefix + "self_attn.o_proj.bias",
+                                     transformer.layers[i].attention.dense.bias);
+        }
+        state_dict_.emplace_back(layer_prefix + "post_attention_layernorm.weight",
+                                 transformer.layers[i].post_attention_layernorm.weight);
+        state_dict_.emplace_back(layer_prefix + "mlp.gate_proj.weight", transformer.layers[i].mlp.gate_proj.weight);
+        state_dict_.emplace_back(layer_prefix + "mlp.up_proj.weight", transformer.layers[i].mlp.up_proj.weight);
+        state_dict_.emplace_back(layer_prefix + "mlp.down_proj.weight", transformer.layers[i].mlp.down_proj.weight);
+    }
+    state_dict_.emplace_back("model.norm.weight", transformer.final_layernorm.weight);
+    state_dict_.emplace_back("lm_head.weight", lm_head.weight);
+}
+
+void InternLM7BForCausalLM::load(ModelLoader &loader) {
+    for (auto &item : state_dict_) {
+        const std::string &name = item.first;
+        ggml_tensor *tensor = item.second;
+        loader.read_tensor(name, tensor);
+    }
+
+    to_device("model.embed_tokens.weight");
+
+    ctx_.weight_buffer = std::string_view(loader.data, loader.size);
+    ctx_.init_device_context();
+}
+
+InternLM20BForCausalLM::InternLM20BForCausalLM(const ModelConfig &config)
+    : BasicModelForCausalLM(MODEL_TYPE_INTERNLM, config, MEM_SIZE, SCRATCH_SIZE) {
+    constexpr size_t tensor_ovhd = GGML_TENSOR_SIZE + GGML_OBJECT_SIZE;
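+    // 3 shared tensors plus 7 per layer: the 20B layout registers no attention biases,
+    // so the bias branches below are not expected to fire for this variant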
+    const size_t num_weights = 3 + config.num_hidden_layers * 7;
+    const size_t ctx_w_size = num_weights * tensor_ovhd;
+    const size_t ctx_kv_size = 2 * config.num_hidden_layers *
+                               (config.max_length * config.hidden_size * ggml_type_size(GGML_TYPE_F16) + tensor_ovhd);
+    ctx_.dtype = config.dtype;
+    ctx_.ctx_w = make_unique_ggml_context(ctx_w_size, nullptr, true);
+    ctx_.ctx_kv = make_unique_ggml_context(ctx_kv_size + 1 * MB, nullptr, false); // 1MB extra for MPS
+
+    transformer = InternLM20BModel(&ctx_, config);
+    lm_head = Linear(&ctx_, config.hidden_size, config.vocab_size, false);
+    CHATGLM_CHECK(ggml_used_mem(ctx_.ctx_w.get()) == ggml_get_mem_size(ctx_.ctx_w.get())) << "corrupted model weights";
+    CHATGLM_CHECK(ggml_used_mem(ctx_.ctx_kv.get()) == ctx_kv_size) << "corrupted kv cache";
+
+    // build state_dict
+    state_dict_.reserve(num_weights);
+    state_dict_.emplace_back("model.embed_tokens.weight", transformer.word_embeddings.weight);
+    for (int i = 0; i < config.num_hidden_layers; i++) {
+        std::string layer_prefix = "model.layers." + std::to_string(i) + '.';
+        state_dict_.emplace_back(layer_prefix + "input_layernorm.weight", transformer.layers[i].input_layernorm.weight);
+        state_dict_.emplace_back(layer_prefix + "self_attn.qkv_proj.weight",
+                                 transformer.layers[i].attention.query_key_value.weight);
+        if (transformer.layers[i].attention.query_key_value.bias) {
+            state_dict_.emplace_back(layer_prefix + "self_attn.qkv_proj.bias",
+                                     transformer.layers[i].attention.query_key_value.bias);
+        }
+        state_dict_.emplace_back(layer_prefix + "self_attn.o_proj.weight",
+                                 transformer.layers[i].attention.dense.weight);
+        if (transformer.layers[i].attention.dense.bias) {
+            state_dict_.emplace_back(layer_prefix + "self_attn.o_proj.bias",
+                                     transformer.layers[i].attention.dense.bias);
+        }
+        state_dict_.emplace_back(layer_prefix + "post_attention_layernorm.weight",
+                                 transformer.layers[i].post_attention_layernorm.weight);
+        state_dict_.emplace_back(layer_prefix + "mlp.gate_proj.weight", transformer.layers[i].mlp.gate_proj.weight);
+        state_dict_.emplace_back(layer_prefix + "mlp.up_proj.weight", transformer.layers[i].mlp.up_proj.weight);
+        state_dict_.emplace_back(layer_prefix + "mlp.down_proj.weight", transformer.layers[i].mlp.down_proj.weight);
+    }
+    state_dict_.emplace_back("model.norm.weight", transformer.final_layernorm.weight);
+    state_dict_.emplace_back("lm_head.weight", lm_head.weight);
+}
+
+void InternLM20BForCausalLM::load(ModelLoader &loader) {
+    for (auto &item : state_dict_) {
+        const std::string &name = item.first;
+        ggml_tensor *tensor = item.second;
+        loader.read_tensor(name, tensor);
+    }
+
+    to_device("model.embed_tokens.weight");
+
+    ctx_.weight_buffer = std::string_view(loader.data, loader.size);
+    ctx_.init_device_context();
+}
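+
+// Once wired into the pipeline dispatch below, an InternLM GGML file is used like any
+// other supported model. A minimal sketch, assuming the existing Pipeline/chat API and
+// a hypothetical converted weight file "internlm-ggml.bin":
+//     chatglm::Pipeline pipeline("internlm-ggml.bin");
+//     std::string reply = pipeline.chat({"who are you?"}, chatglm::GenerationConfig());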
+
 // ===== pipeline =====
 
 Pipeline::Pipeline(const std::string &path) {
@@ -1241,6 +1411,26 @@ Pipeline::Pipeline(const std::string &path) {
         // load model
         model = std::make_unique<Baichuan13BForCausalLM>(config);
         model->load(loader);
+    } else if (model_type == MODEL_TYPE_INTERNLM) {
+        CHATGLM_CHECK(version == 1) << "only support version 1 for now but got " << version;
+
+        // load config
+        ModelConfig config(loader.read_basic<ConfigRecordV1>());
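+        // norm_eps is not carried by the V1 config record, so set InternLM's RMSNorm epsilon explicitly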
+        config.norm_eps = 1e-6;
+
+        // load tokenizer
+        int proto_size = loader.read_basic<int>();
+        std::string_view serialized_model_proto((char *)mapped_file->data + loader.tell(), proto_size);
+        loader.seek(proto_size, SEEK_CUR);
+        tokenizer = std::make_unique<InternLMTokenizer>(serialized_model_proto);
+
+        // load model
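+        // dispatch on hidden_size: the 7B checkpoints use 4096, so any larger hidden size is treated as 20B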
+        if (config.hidden_size == 4096) {
+            model = std::make_unique<InternLM7BForCausalLM>(config);
+        } else {
+            model = std::make_unique<InternLM20BForCausalLM>(config);
+        }
+        model->load(loader);
     } else {
         CHATGLM_THROW << "invalid model type " << model_type;
     }