@@ -455,6 +455,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
455455 GGML_ASSERT(hparams.n_expert_used == 0);
456456 }
457457
458+ // multi-token predict
459+ ml.get_key(LLM_KV_N_MULTI_TOKEN_PREDICT, hparams.n_mtp, false);
460+
458461 // zero-out the array hparams
459462 std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
460463 std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
@@ -4323,6 +4326,10 @@ void llama_model::print_info() const {
43234326 LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
43244327 }
43254328
4329+ if (hparams.n_mtp) {
4330+ LLAMA_LOG_INFO("%s: n_mtp = %u\n", __func__, hparams.n_mtp);
4331+ }
4332+
43264333 LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str());
43274334 if (pimpl->n_elements >= 1e12) {
43284335 LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12);
@@ -13234,6 +13241,10 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
1323413241 return model->hparams.n_head_kv();
1323513242}
1323613243
13244+ int32_t llama_model_n_mtp(const llama_model * model) {
13245+ return model->hparams.n_mtp;
13246+ }
13247+
1323713248// deprecated
1323813249int32_t llama_n_ctx_train(const llama_model * model) {
1323913250 return llama_model_n_ctx_train(model);
0 commit comments