@@ -928,24 +928,24 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
928928 cb (selection_probs, " ffn_moe_probs_biased" , il);
929929 }
930930
931- // select top n_group_exp expert groups
931+ // select top n_group_used expert groups
932932 if (arch == LLM_ARCH_BAILINGMOE2) {
933933 const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups ;
934934
935935 // organize experts into n_expert_groups
936- ggml_tensor * selection_groups = ggml_view_2d (ctx0, ggml_cont (ctx0, ggml_transpose (ctx0, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups , n_tokens * n_exp_per_group * sizeof (float ), 0 ); // [n_tokens, n_expert_groups]
936+ ggml_tensor * selection_groups = ggml_view_2d (ctx0, ggml_cont (ctx0, ggml_transpose (ctx0, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups , n_tokens * n_exp_per_group * sizeof (float ), 0 ); // [n_tokens * n_exp_per_group , n_expert_groups]
937937 ggml_tensor * group_scores = ggml_top_k (ctx0, selection_groups, 2 ); // [2, n_expert_groups]
938+ group_scores = ggml_get_rows (ctx0, ggml_reshape_3d (ctx0, selection_groups, 1 , selection_groups->ne [0 ], selection_groups->ne [1 ]), group_scores); // [1, 2, n_expert_groups]
938939
939- // get top n_group_exp expert groups
940- group_scores = ggml_transpose (ctx0, ggml_sum_rows (ctx0, ggml_cast (ctx0, group_scores, GGML_TYPE_F32 ))); // [n_expert_groups, 1]
941- ggml_tensor * expert_groups = ggml_top_k (ctx0, ggml_cont (ctx0, group_scores), hparams.n_group_exp ); // [n_group_exp , 1]
940+ // get top n_group_used expert groups
941+ group_scores = ggml_transpose (ctx0, ggml_sum_rows (ctx0, ggml_reshape_2d (ctx0, group_scores, group_scores-> ne [ 1 ], group_scores-> ne [ 2 ] ))); // [n_expert_groups, 1]
942+ ggml_tensor * expert_groups = ggml_top_k (ctx0, ggml_cont (ctx0, group_scores), hparams.n_group_used ); // [n_group_used , 1]
942943 cb (expert_groups->src [0 ], " ffn_moe_group_argsort" , il);
943944 cb (expert_groups, " ffn_moe_group_topk" , il);
944945
945946 // mask out the other groups
946- selection_probs = ggml_scale_bias (ctx0, selection_groups, 0 .0f , -INFINITY);
947- group_scores = ggml_repeat_4d (ctx0, expert_groups, selection_probs->ne [1 ], 1 , 1 , 1 ); // [n_expert_groups, 1]
948- selection_probs = ggml_set_rows (ctx0, selection_probs, selection_groups, group_scores); // [n_tokens, n_expert_groups]
947+ selection_probs = ggml_get_rows (ctx0, selection_groups, expert_groups); // [n_tokens * n_exp_per_group, n_group_used]
948+ selection_probs = ggml_set_rows (ctx0, ggml_scale_bias (ctx0, selection_groups, 0 .0f , -INFINITY), selection_probs, expert_groups); // [n_tokens * n_exp_per_group, n_expert_groups]
949949 selection_probs = ggml_view_2d (ctx0, selection_probs, n_tokens, n_expert, n_tokens * sizeof (float ), 0 ); // [n_tokens, n_expert]
950950 selection_probs = ggml_cont (ctx0, ggml_transpose (ctx0, selection_probs)); // [n_expert, n_tokens]
951951 cb (selection_probs, " ffn_moe_probs_masked" , il);
0 commit comments