@@ -1140,100 +1140,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     return moe_out;
 }
 
-ggml_tensor * llm_graph_context::build_mergez_moe_ffn(
-        ggml_tensor * cur,
-        ggml_tensor * hidden_state,
-        ggml_tensor * gate_inp,
-        ggml_tensor * exp_probs_b,
-        ggml_tensor * up_exps,
-        ggml_tensor * gate_exps,
-        ggml_tensor * down_exps,
-            int64_t   n_expert,
-            int64_t   n_expert_used,
-                int   il) const {
-    const int64_t n_embd   = cur->ne[0];
-    const int64_t n_tokens = cur->ne[1];
-
-    ggml_tensor * logits = build_lora_mm(gate_inp, hidden_state); // [n_expert, n_tokens]
-    cb(logits, "ffn_moe_logits", il);
-
-    ggml_tensor * normalized_logits = nullptr;
-    ggml_tensor * probs             = nullptr;
-    if (exp_probs_b) {
-        // For Megrez: sigmoid THEN add bias (not the other way around!)
-        normalized_logits = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
-        cb(normalized_logits, "ffn_moe_logits_normalize", il);
-        probs = ggml_add(ctx0, normalized_logits, exp_probs_b); // add bias AFTER sigmoid
-        cb(probs, "ffn_moe_probs", il);
-    } else {
-        probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
-    }
-
-    // select experts
-    ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_expert_used, n_tokens]
-    cb(selected_experts->src[0], "ffn_moe_argsort", il);
-    cb(selected_experts, "ffn_moe_topk", il);
-
-    ggml_tensor * weights = nullptr;
-    if (exp_probs_b) {
-        ggml_tensor * weight0s = ggml_get_rows(ctx0,
-                ggml_reshape_3d(ctx0, normalized_logits, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
-        cb(weight0s, "ffn_moe_weights0", il);
-        weight0s = ggml_reshape_2d(ctx0, weight0s, n_expert_used, n_tokens);
-        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weight0s); // [1, n_tokens]
-        cb(weights_sum, "ffn_moe_weights0_sum", il);
-        weights = ggml_div(ctx0, weight0s, weights_sum); // [n_expert_used, n_tokens]
-        cb(weights, "ffn_moe_weights_norm", il);
-        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
-    } else {
-        weights = ggml_get_rows(ctx0,
-                ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
-        cb(weights, "ffn_moe_weights", il);
-    }
-
-    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
-
-    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(up, "ffn_moe_up", il);
-
-    ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(gate, "ffn_moe_gate", il);
-
-    gate = ggml_silu(ctx0, gate);
-    cb(gate, "ffn_moe_silu", il);
-
-    ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
-    cb(par, "ffn_moe_gate_par", il);
-
-    ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
-    cb(experts, "ffn_moe_down", il);
-
-    experts = ggml_mul(ctx0, experts, weights);
-    cb(experts, "ffn_moe_weighted", il);
-
-    // aggregate experts
-    ggml_tensor * moe_out = nullptr;
-    for (int i = 0; i < n_expert_used; ++i) {
-        ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
-                experts->nb[2], i*experts->nb[1]);
-
-        if (i == 0) {
-            moe_out = cur_expert;
-        } else {
-            moe_out = ggml_add(ctx0, moe_out, cur_expert);
-        }
-    }
-
-    if (n_expert_used == 1) {
-        // avoid returning a non-contiguous tensor
-        moe_out = ggml_cont(ctx0, moe_out);
-    }
-
-    cb(moe_out, "ffn_moe_out", il);
-
-    return moe_out;
-}
-
 // input embeddings with optional lora
 ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
     const int64_t n_embd = hparams.n_embd;
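
The subtle point in the removed `build_mergez_moe_ffn` is the routing order: the expert-selection score is `sigmoid(logits) + exp_probs_b`, but the mixing weights are the *pre-bias* sigmoid scores of the selected experts, renormalized to sum to 1, so the bias influences which experts are picked without distorting how their outputs are blended. Below is a minimal standalone C++ sketch of that per-token routing math, written as plain scalar code rather than ggml graph ops; `route_token` and the toy inputs are hypothetical names for illustration, not part of the source.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

// Sketch (not ggml code): pick the top-k experts for one token and return
// (expert index, mixing weight) pairs, following the removed function's logic.
static std::vector<std::pair<int, float>> route_token(
        const std::vector<float> & logits,  // [n_expert] router logits
        const std::vector<float> & bias,    // [n_expert] exp_probs_b
        int n_expert_used) {
    const int n_expert = (int) logits.size();

    std::vector<float> scores(n_expert);    // sigmoid(logits): pre-bias weights
    std::vector<float> sel(n_expert);       // sigmoid(logits) + bias: selection only
    for (int e = 0; e < n_expert; ++e) {
        scores[e] = 1.0f / (1.0f + std::exp(-logits[e]));
        sel[e]    = scores[e] + bias[e];    // bias added AFTER the sigmoid
    }

    // top-k by the biased score (the ggml_top_k step)
    std::vector<int> idx(n_expert);
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
            [&](int a, int b) { return sel[a] > sel[b]; });

    // renormalize the PRE-bias scores of the selected experts
    // (the ggml_get_rows / ggml_sum_rows / ggml_div steps)
    float sum = 0.0f;
    for (int i = 0; i < n_expert_used; ++i) {
        sum += scores[idx[i]];
    }
    std::vector<std::pair<int, float>> out;
    for (int i = 0; i < n_expert_used; ++i) {
        out.push_back({idx[i], scores[idx[i]] / sum});
    }
    return out;
}

int main() {
    // toy example: 4 experts, pick 2; the bias promotes expert 1 into the
    // top-k, but its mixing weight still comes from the unbiased sigmoid
    const std::vector<float> logits = { 0.5f, -1.0f, 2.0f, 0.0f };
    const std::vector<float> bias   = { 0.0f,  1.5f, 0.0f, 0.0f };
    for (auto [e, w] : route_token(logits, bias, 2)) {
        std::printf("expert %d, weight %.3f\n", e, w);
    }
    return 0;
}
```

With these toy numbers, expert 1 is selected only because of its bias, yet its weight (≈0.23) is computed from the unbiased sigmoid, illustrating why the diff's comment insists on "sigmoid THEN add bias".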