Commit 7e50c4a

Add compute graph handling for MOE based nomic embed v2.
Signed-off-by: Adam Treat <[email protected]>
1 parent 252c0a7 commit 7e50c4a

2 files changed (+42, -13 lines)

src/llama-graph.cpp

Lines changed: 18 additions & 11 deletions
@@ -901,31 +901,38 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     }
 
     cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
-    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(up, "ffn_moe_up", il);
+    ggml_tensor * tmp = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(tmp, "ffn_moe_up", il);
 
-    ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(gate, "ffn_moe_gate", il);
+    ggml_tensor * experts = nullptr;
+    if (gate_exps) {
+        cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate", il);
+    } else {
+        cur = tmp;
+    }
 
     switch (type_op) {
         case LLM_FFN_SILU:
             {
-                gate = ggml_silu(ctx0, gate);
-                cb(gate, "ffn_moe_silu", il);
+                cur = ggml_silu(ctx0, cur);
+                cb(cur, "ffn_moe_silu", il);
             } break;
         case LLM_FFN_GELU:
             {
-                gate = ggml_gelu(ctx0, gate);
-                cb(gate, "ffn_moe_gelu", il);
+                cur = ggml_gelu(ctx0, cur);
+                cb(cur, "ffn_moe_gelu", il);
             } break;
         default:
            GGML_ABORT("fatal error");
     }
 
-    ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
-    cb(par, "ffn_moe_gate_par", il);
+    if (gate_exps) {
+        cur = ggml_mul(ctx0, cur, tmp); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate_par", il);
+    }
 
-    ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
     experts = ggml_mul(ctx0, experts, weights);
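Note: the hunk above lets build_moe_ffn run without a per-expert gate projection, which is what the gate-less MoE layers of nomic embed v2 need. The following is only a minimal sketch of the two dataflows on plain float vectors, not the actual ggml code; the helper name expert_ffn_activation and the explicit GELU formula are illustrative assumptions.

// Sketch only: mirrors the control flow added above on plain vectors.
// With a gate, the activated gate output is multiplied element-wise by the
// up projection ("ffn_moe_gate_par"); without one, the up projection itself
// is activated and the multiply is skipped.
#include <cmath>
#include <vector>

std::vector<float> expert_ffn_activation(const std::vector<float> & up,
                                         const std::vector<float> * gate /* may be null */) {
    std::vector<float> cur = gate ? *gate : up;   // gate_exps present or not
    for (float & v : cur) {
        v = 0.5f * v * (1.0f + std::erf(v / std::sqrt(2.0f)));  // GELU
    }
    if (gate) {
        for (size_t i = 0; i < cur.size(); ++i) {
            cur[i] *= up[i];                      // gate * up
        }
    }
    return cur;  // would then be fed through the experts' down projection
}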

src/llama-model.cpp

Lines changed: 24 additions & 2 deletions
@@ -5312,6 +5312,11 @@ struct llm_build_bert : public llm_graph_context {
                 cur = build_lora_mm(model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
+                if (model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
+                    cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                    cb(cur, "bqkv", il);
+                }
+
                 Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
                 Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
                 Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
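For reference, the Qcur/Kcur/Vcur context lines in the hunk above split one fused QKV row by byte offset after the new optional bqkv bias is added. A toy, self-contained sketch of that split follows; the sizes are made-up example values, not read from the model.

// Toy sketch of splitting one token's fused QKV row into Q, K, V by offset,
// mirroring the offsets 0, n_embd and n_embd + n_embd_gqa used in the
// ggml_view_2d calls above.
#include <cstring>
#include <vector>

int main() {
    const int n_embd = 8, n_embd_gqa = 8;                    // example sizes
    std::vector<float> qkv(n_embd + 2 * n_embd_gqa, 1.0f);   // fused row after wqkv (+ bqkv)
    std::vector<float> Q(n_embd), K(n_embd_gqa), V(n_embd_gqa);
    std::memcpy(Q.data(), qkv.data(),                       sizeof(float) * n_embd);
    std::memcpy(K.data(), qkv.data() + n_embd,              sizeof(float) * n_embd_gqa);
    std::memcpy(V.data(), qkv.data() + n_embd + n_embd_gqa, sizeof(float) * n_embd_gqa);
    return 0;
}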
@@ -5364,29 +5369,46 @@
             cb(ffn_inp, "ffn_inp", il);
 
             // feed-forward network
-            if (model.arch == LLM_ARCH_BERT) {
+            if (hparams.moe_every_n_layers > 0 && il % hparams.moe_every_n_layers == 1) {
+                // MoE branch
+                cur = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        nullptr,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        hparams.n_expert,
+                        hparams.n_expert_used,
+                        LLM_FFN_GELU,
+                        true, false,
+                        0.0f,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
+                cb(cur, "ffn_moe_out", il);
+            } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
                 cur = build_ffn(cur,
                         model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
                         NULL, NULL, NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
             } else if (model.arch == LLM_ARCH_JINA_BERT_V2) {
                 cur = build_ffn(cur,
                         model.layers[il].ffn_up, NULL, NULL,
                         model.layers[il].ffn_gate, NULL, NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
             } else {
                 cur = build_ffn(cur,
                         model.layers[il].ffn_up, NULL, NULL,
                         model.layers[il].ffn_gate, NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
             }
-            cb(cur, "ffn_out", il);
 
             // attentions bypass the intermediate layer
             cur = ggml_add(ctx0, cur, ffn_inp);
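To make the new layer-selection condition concrete: with a hypothetical moe_every_n_layers of 2 (an example value, not read from any model file), the MoE branch is taken on every odd-indexed layer and the dense FFN everywhere else. A small standalone sketch of that check:

// Standalone illustration of:
//   hparams.moe_every_n_layers > 0 && il % hparams.moe_every_n_layers == 1
#include <cstdio>

int main() {
    const int n_layer            = 12;  // example value
    const int moe_every_n_layers = 2;   // example value
    for (int il = 0; il < n_layer; ++il) {
        const bool use_moe = moe_every_n_layers > 0 && il % moe_every_n_layers == 1;
        std::printf("layer %2d -> %s\n", il, use_moe ? "MoE FFN" : "dense FFN");
    }
    return 0;  // prints: layers 1, 3, 5, ... take the MoE branch
}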
