Commit e7b29e1

use code from ik
1 parent: 5d67144


llama.cpp/llama.cpp

Lines changed: 43 additions & 57 deletions
@@ -10904,7 +10904,8 @@ struct llm_build_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        ggml_tensor * inp_attn = build_inp_KQ_mask();
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
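
Note: the renamed KQ_mask is built once per graph by build_inp_KQ_mask() as a single-head additive mask of shape [n_kv, n_tokens]; during attention it is added to every head's score matrix before the softmax, which is what "broadcasted to all heads" refers to. A minimal standalone sketch of that broadcast (hypothetical helper, not llama.cpp code):

    #include <cstddef>
    #include <vector>

    // scores: [n_head][n_tokens][n_kv] attention logits (row-major);
    // kq_mask: [n_tokens][n_kv] additive mask (0 = visible, -INFINITY = masked).
    void apply_kq_mask(std::vector<float> & scores, const std::vector<float> & kq_mask,
                       int n_head, int n_tokens, int n_kv) {
        for (int h = 0; h < n_head; ++h)
            for (int t = 0; t < n_tokens; ++t)
                for (int k = 0; k < n_kv; ++k)
                    // the same single-head mask is reused (broadcast) for every head
                    scores[((size_t)h*n_tokens + t)*n_kv + k] += kq_mask[(size_t)t*n_kv + k];
    }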
@@ -10927,49 +10928,41 @@ struct llm_build_context {
             ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
             cb(Vcur, "Vcur", il);
 
-            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
-            Qcur = llm_build_norm(ctx0, Qcur, hparams,
-                    model.layers[il].attn_q_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
             cb(Qcur, "Qcur_normed", il);
 
             Qcur = ggml_rope_ext(
-                ctx0, Qcur, inp_pos, nullptr,
-                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow
-            );
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+            );
+            cb(Qcur, "Qcur", il);
 
-            Kcur = llm_build_norm(ctx0, Kcur, hparams,
-                    model.layers[il].attn_k_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
             cb(Kcur, "Kcur_normed", il);
 
             Kcur = ggml_rope_ext(
-                ctx0, Kcur, inp_pos, nullptr,
-                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow
-            );
-
-            cb(Qcur, "Qcur", il);
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+            );
             cb(Kcur, "Kcur", il);
-            cb(Vcur, "Vcur", il);
 
             cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                     model.layers[il].wo, model.layers[il].bo,
-                    Kcur, Vcur, Qcur, inp_attn, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                    Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
         }
 
         if (il == n_layer - 1) {
             // skip computing output for unused tokens
-            ggml_tensor * inp_out_ids = build_inp_out_ids();
+            struct ggml_tensor * inp_out_ids = build_inp_out_ids();
             cur = ggml_get_rows(ctx0, cur, inp_out_ids);
             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
         }
 
-        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
         cb(ffn_inp, "ffn_inp", il);
 
         // feed-forward network
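
Note: both builders now follow the Qwen3 QK-norm order taken from ik's code: reshape Q/K to [n_embd_head, n_head(_kv), n_tokens], RMS-normalise each head with attn_q_norm / attn_k_norm, and only then apply RoPE. A minimal sketch of the per-head RMS norm, assuming the usual RMSNorm definition (standalone illustration, not the ggml kernel; the eps value is illustrative):

    #include <cmath>

    // x: one head's n_embd_head channels for one token, normalised in place;
    // weight: the attn_q_norm / attn_k_norm scale vector of the same length.
    void rms_norm_head(float * x, const float * weight, int n, float eps = 1e-6f) {
        float sum_sq = 0.0f;
        for (int i = 0; i < n; ++i) sum_sq += x[i] * x[i];
        const float inv_rms = 1.0f / std::sqrt(sum_sq / n + eps);
        for (int i = 0; i < n; ++i) x[i] = x[i] * inv_rms * weight[i];
    }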
@@ -11016,8 +11009,10 @@ struct llm_build_context {
     struct ggml_cgraph * build_qwen3moe() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
-        const int64_t n_embd_head = hparams.n_embd_head_v;
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
 
+        const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
         GGML_ASSERT(n_embd_head == hparams.n_rot);
 
@@ -11029,7 +11024,8 @@ struct llm_build_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        ggml_tensor * inp_attn = build_inp_KQ_mask();
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -11052,45 +11048,37 @@ struct llm_build_context {
             ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
             cb(Vcur, "Vcur", il);
 
-            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
-            Qcur = llm_build_norm(ctx0, Qcur, hparams,
-                    model.layers[il].attn_q_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il);
             cb(Qcur, "Qcur_normed", il);
 
             Qcur = ggml_rope_ext(
-                ctx0, Qcur, inp_pos, nullptr,
-                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow
-            );
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+            );
+            cb(Qcur, "Qcur", il);
 
-            Kcur = llm_build_norm(ctx0, Kcur, hparams,
-                    model.layers[il].attn_k_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
             cb(Kcur, "Kcur_normed", il);
 
             Kcur = ggml_rope_ext(
-                ctx0, Kcur, inp_pos, nullptr,
-                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow
-            );
-
-            cb(Qcur, "Qcur", il);
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+            );
             cb(Kcur, "Kcur", il);
-            cb(Vcur, "Vcur", il);
 
-            // inp_attn
             cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                     model.layers[il].wo, model.layers[il].bo,
-                    Kcur, Vcur, Qcur, inp_attn, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                    Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
         }
 
         if (il == n_layer - 1) {
             // skip computing output for unused tokens
-            ggml_tensor * inp_out_ids = build_inp_out_ids();
+            struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+            n_tokens = n_outputs;
             cur = ggml_get_rows(ctx0, cur, inp_out_ids);
             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
         }
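
Note: the mutable n_tokens introduced in build_qwen3moe exists because, on the last layer, ggml_get_rows keeps only the rows listed in inp_out_ids, so the effective row count drops from n_tokens to n_outputs and later shapes must use the reduced value. A rough sketch of that gather step (hypothetical standalone code, not the ggml implementation):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // rows: [n_tokens][n_embd] activations; out_ids: indices of the tokens whose
    // outputs are actually needed. Returns only those rows, mirroring what
    // ggml_get_rows(ctx0, cur, inp_out_ids) does on the last layer.
    std::vector<float> gather_rows(const std::vector<float> & rows, int n_embd,
                                   const std::vector<int32_t> & out_ids) {
        std::vector<float> out;
        out.reserve(out_ids.size() * n_embd);
        for (int32_t id : out_ids)
            out.insert(out.end(), rows.begin() + (size_t)id * n_embd,
                                  rows.begin() + (size_t)(id + 1) * n_embd);
        return out; // the caller then treats n_tokens as n_outputs == out_ids.size()
    }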
@@ -11104,19 +11092,17 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            ggml_tensor * moe_out =
-                    llm_build_moe_ffn(ctx0, lctx, cur,
+            cur =
+                    llm_build_moe_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_gate_inp,
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
                         n_expert, n_expert_used,
                         LLM_FFN_SILU, true,
                         false, 0.0,
-                        cb,
-                        il);
-            cb(moe_out, "ffn_moe_out", il);
-            cur = moe_out;
+                        cb, il);
+            cb(cur, "ffn_moe_out", il);
 
             cur = ggml_add(ctx0, cur, ffn_inp);
 
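Note: writing the result of llm_build_moe_ffn straight into cur just drops the intermediate moe_out tensor; with LLM_FFN_SILU the call wires a SiLU-gated FFN per routed expert and mixes the results with the router weights, so conceptually each selected expert contributes down(silu(gate·x) * (up·x)). A simplified per-token sketch under those assumptions (hypothetical dense code, not the batched ggml graph):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    static float silu(float v) { return v / (1.0f + std::exp(-v)); }

    // x: [n_embd] token activation. For each selected expert e: gate/up are
    // [n_ff][n_embd], down is [n_embd][n_ff], w[e] is the router weight.
    std::vector<float> moe_ffn_token(const std::vector<float> & x, int n_embd, int n_ff,
                                     const std::vector<const float *> & gate,
                                     const std::vector<const float *> & up,
                                     const std::vector<const float *> & down,
                                     const std::vector<float> & w) {
        std::vector<float> out(n_embd, 0.0f);
        for (size_t e = 0; e < w.size(); ++e) {
            std::vector<float> act(n_ff);
            for (int j = 0; j < n_ff; ++j) {
                float g = 0.0f, u = 0.0f;
                for (int i = 0; i < n_embd; ++i) {
                    g += gate[e][(size_t)j*n_embd + i] * x[i];
                    u += up  [e][(size_t)j*n_embd + i] * x[i];
                }
                act[j] = silu(g) * u;               // SiLU-gated hidden state
            }
            for (int i = 0; i < n_embd; ++i) {
                float d = 0.0f;
                for (int j = 0; j < n_ff; ++j) d += down[e][(size_t)i*n_ff + j] * act[j];
                out[i] += w[e] * d;                 // weighted sum over selected experts
            }
        }
        return out;
    }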