
Commit e8b115b

try add set_mpt_head api
1 parent 2745373 commit e8b115b

File tree

8 files changed, +48 -0 lines changed

include/llama.h
src/llama-arch.cpp
src/llama-arch.h
src/llama-context.cpp
src/llama-context.h
src/llama-cparams.h
src/llama-hparams.h
src/llama-model.cpp

include/llama.h

Lines changed: 9 additions & 0 deletions

@@ -496,6 +496,12 @@ extern "C" {
     LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
 
+    // If model supports multi-token predict, this returns number of tokens ; returns 0 otherwise
+    LLAMA_API int32_t llama_model_n_mtp(const struct llama_model * model);
+
+    // Get the i-th multi-token predict model (used by speculative decoding)
+    LLAMA_API struct llama_model * llama_model_get_mtp(struct llama_model * model, int32_t i);
+
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
 
@@ -959,6 +965,9 @@ extern "C" {
     // If set to true, the model will only attend to the past tokens
     LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
 
+    // Set whether to use multi-token predict head ; 0 means no MTP
+    LLAMA_API void llama_set_mpt_head(struct llama_context * ctx, int32_t n_mtp);
+
     // Set whether the model is in warmup mode or not
     // If true, all model tensors are activated during llama_decode() to load and cache their weights.
     LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup);
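Taken together, the two new getters let an application probe a loaded model for MTP support before deciding whether to drive speculative decoding with it. Below is a minimal caller-side sketch, assuming a llama_model * obtained through the usual loading path; the helper name print_mtp_info and the 0-based index passed to llama_model_get_mtp are assumptions, not something this commit specifies.

#include <cstdio>
#include "llama.h"

// Hypothetical helper: report the multi-token-predict support a loaded model
// exposes through the new API.
static void print_mtp_info(llama_model * model) {
    const int32_t n_mtp = llama_model_n_mtp(model); // 0 when the model has no MTP heads

    if (n_mtp == 0) {
        printf("model has no multi-token-predict heads\n");
        return;
    }

    printf("model exposes %d MTP head(s)\n", n_mtp);
    for (int32_t i = 0; i < n_mtp; ++i) {
        // each head is surfaced as its own (draft) llama_model, per the header comment;
        // 0-based indexing is assumed here
        llama_model * draft = llama_model_get_mtp(model, i);
        printf("  head %d: %s\n", i, draft ? "available" : "not loaded");
    }
}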

src/llama-arch.cpp

Lines changed: 1 addition & 0 deletions

@@ -121,6 +121,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
     { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
     { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },
+    { LLM_KV_N_MULTI_TOKEN_PREDICT, "%s.n_multi_token_predict" },
 
     { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
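As with the other entries in LLM_KV_NAMES, the %s placeholder is filled in with the model's architecture name when the GGUF metadata key is resolved. A small sketch of the resulting key string, using "llama" purely as an illustrative architecture prefix:

#include <cstdio>

int main() {
    // format string registered above for LLM_KV_N_MULTI_TOKEN_PREDICT
    const char * fmt  = "%s.n_multi_token_predict";
    const char * arch = "llama"; // illustrative; the real prefix is the model's architecture name

    char key[128];
    snprintf(key, sizeof(key), fmt, arch);
    printf("%s\n", key); // prints: llama.n_multi_token_predict
    return 0;
}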

src/llama-arch.h

Lines changed: 1 addition & 0 deletions

@@ -125,6 +125,7 @@ enum llm_kv {
     LLM_KV_EMBEDDING_SCALE,
     LLM_KV_TOKEN_SHIFT_COUNT,
     LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
+    LLM_KV_N_MULTI_TOKEN_PREDICT,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,

src/llama-context.cpp

Lines changed: 17 additions & 0 deletions

@@ -625,6 +625,18 @@ void llama_context::set_causal_attn(bool value) {
     cparams.causal_attn = value;
 }
 
+void llama_context::set_causal_attn(bool value) {
+    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
+
+    cparams.causal_attn = value;
+}
+
+void llama_context::set_mpt_head(int32_t value) {
+    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
+
+    cparams.curr_mtp = value;
+}
+
 void llama_context::set_warmup(bool value) {
     LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
 
@@ -1981,6 +1993,11 @@ void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
     ctx->set_causal_attn(causal_attn);
 }
 
+void llama_set_mpt_head(llama_context * ctx, int32_t n_mtp) {
+    GGML_ASSERT(n_mtp <= llama_model_n_mtp(llama_get_model(ctx)));
+    ctx->set_mpt_head(n_mtp);
+}
+
 void llama_set_warmup(llama_context * ctx, bool warmup) {
     ctx->set_warmup(warmup);
 }
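Because the public wrapper asserts n_mtp <= llama_model_n_mtp(...), passing an out-of-range head index aborts the process rather than failing gracefully. A caller-side guard one might put in front of it is sketched below; the helper name and the bool return convention are mine, not part of the API.

#include "llama.h"

// Hypothetical guard around llama_set_mpt_head(): returns false instead of
// tripping the GGML_ASSERT inside the library when the requested head is invalid.
// Per the llama_cparams comment, 0 disables MTP and 1..N selects the N-th head.
static bool set_mtp_head_checked(llama_context * ctx, int32_t head) {
    const llama_model * model = llama_get_model(ctx);
    const int32_t n_mtp = llama_model_n_mtp(model);

    if (head < 0 || head > n_mtp) {
        return false; // out of range for this model
    }

    llama_set_mpt_head(ctx, head);
    return true;
}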

src/llama-context.h

Lines changed: 1 addition & 0 deletions

@@ -69,6 +69,7 @@ struct llama_context {
 
     void set_embeddings (bool value);
     void set_causal_attn(bool value);
+    void set_mpt_head(int32_t value);
     void set_warmup(bool value);
 
     void set_adapter_lora(

src/llama-cparams.h

Lines changed: 5 additions & 0 deletions

@@ -31,6 +31,11 @@ struct llama_cparams {
     bool no_perf;
     bool warmup;
 
+    // multi-token predict
+    // 0 means not using MTP
+    // N means using the nth MTP head
+    int32_t curr_mtp = 0;
+
     enum llama_pooling_type pooling_type;
 
     ggml_backend_sched_eval_callback cb_eval;

src/llama-hparams.h

Lines changed: 3 additions & 0 deletions

@@ -47,6 +47,9 @@ struct llama_hparams {
     uint32_t n_embd_head_k_mla = 0;
     uint32_t n_embd_head_v_mla = 0;
 
+    // for multi-token predict
+    uint32_t n_mtp = 0;
+
     // for WavTokenizer
     struct llama_hparams_posnet posnet;
     struct llama_hparams_convnext convnext;

src/llama-model.cpp

Lines changed: 11 additions & 0 deletions

@@ -455,6 +455,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         GGML_ASSERT(hparams.n_expert_used == 0);
     }
 
+    // multi-token predict
+    ml.get_key(LLM_KV_N_MULTI_TOKEN_PREDICT, hparams.n_mtp, false);
+
     // zero-out the array hparams
     std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
     std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
 
@@ -4323,6 +4326,10 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
     }
 
+    if (hparams.n_mtp) {
+        LLAMA_LOG_INFO("%s: n_mtp = %u\n", __func__, hparams.n_mtp);
+    }
+
     LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str());
     if (pimpl->n_elements >= 1e12) {
         LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12);
 
@@ -13234,6 +13241,10 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
     return model->hparams.n_head_kv();
 }
 
+int32_t llama_model_n_mtp(const llama_model * model) {
+    return model->hparams.n_mtp;
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const llama_model * model) {
     return llama_model_n_ctx_train(model);
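Since ml.get_key(..., /*required =*/ false) treats the key as optional, hparams.n_mtp simply stays 0 for GGUF files that do not carry it. For completeness, here is a sketch of checking a GGUF file for the key directly with ggml's gguf API; the "llama." prefix and the assumption that the value is stored as a u32 are illustrative only.

#include <cstdio>
#include "gguf.h"

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <model.gguf>\n", argv[0]);
        return 1;
    }

    // read metadata only, do not allocate tensor data
    struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ NULL };
    struct gguf_context * gctx = gguf_init_from_file(argv[1], params);
    if (!gctx) {
        fprintf(stderr, "failed to open %s\n", argv[1]);
        return 1;
    }

    // illustrative key name; the actual prefix is the model's architecture
    const int64_t key_id = gguf_find_key(gctx, "llama.n_multi_token_predict");
    if (key_id < 0) {
        printf("no multi-token-predict metadata (n_mtp stays 0)\n");
    } else {
        // assumes the value was written as a u32, matching uint32_t n_mtp in llama_hparams
        printf("n_multi_token_predict = %u\n", gguf_get_val_u32(gctx, key_id));
    }

    gguf_free(gctx);
    return 0;
}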

0 commit comments
