@@ -814,63 +814,6 @@ ggml_tensor * llm_graph_context::build_ffn(
814814 return cur;
815815}
816816
817- static void mask_expert_groups (struct ggml_tensor * dst, const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata) {
818- GGML_ASSERT (ggml_are_same_shape (dst, a));
819- GGML_ASSERT (a->type == dst->type );
820- GGML_ASSERT (a->type == GGML_TYPE_F32);
821- GGML_ASSERT (b->type == GGML_TYPE_I32);
822-
823- const int n_expert_groups = static_cast <int >(*((uint32_t *) userdata));
824- const int n_group_exp = (int )ggml_nelements (b);
825-
826- const float * a_data = ggml_get_data_f32 (a);
827- const int32_t * b_data = (const int32_t *)ggml_get_data (b);
828- float * dst_data = ggml_get_data_f32 (dst);
829-
830- // parallelize by groups
831- const int nr = (int )ggml_nrows (dst);
832- // number of columns
833- const int nc = (int )dst->ne [0 ];
834- // number of groups per thread
835- const int dg = (n_expert_groups + nth - 1 ) / nth;
836- // group range for this thread
837- const int ig0 = dg * ith;
838- const int ig1 = std::min (ig0 + dg, n_expert_groups);
839- // number of columns per group
840- const int ncg = nc / n_expert_groups;
841- // number of bytes per group
842- const size_t nbg = ncg * dst->nb [0 ];
843-
844- // this assumes that the tensors are contiguous
845- GGML_ASSERT (ggml_is_contiguous (dst));
846- GGML_ASSERT (ggml_is_contiguous (a));
847- GGML_ASSERT (ggml_is_contiguous (b));
848-
849- unsigned int group_sel_mask = 0 ;
850- GGML_ASSERT (sizeof (group_sel_mask) * 8 >= static_cast <size_t >(n_expert_groups));
851-
852- for (int i = 0 ; i < n_group_exp; ++i) {
853- const int32_t group_idx = b_data[i];
854- GGML_ASSERT (group_idx >= 0 );
855- GGML_ASSERT (n_expert_groups > group_idx);
856- group_sel_mask |= 1 << group_idx;
857- }
858-
859- for (int ig = ig0; ig < ig1; ++ig) {
860- const bool group_sel = ig & group_sel_mask;
861-
862- for (int ir = 0 ; ir < nr; ++ir) {
863- const int i = ir * nc + ig * ncg;
864-
865- if (group_sel) {
866- memcpy (dst_data + i, a_data + i, nbg);
867- } else {
868- memset (dst_data + i, (uint32_t ) -INFINITY, nbg);
869- }
870- }
871- }
872- }
873-
874817ggml_tensor * llm_graph_context::build_moe_ffn (
875818 ggml_tensor * cur,
876819 ggml_tensor * gate_inp,
@@ -982,19 +925,21 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
982925 const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups ;
983926
984927 // organize experts into n_expert_groups
985- ggml_tensor * group_scores = ggml_transpose (ctx0, ggml_view_2d (ctx0, selection_probs, hparams.n_expert_groups , n_tokens * n_exp_per_group, selection_probs->nb [1 ] * n_exp_per_group, 0 )); // [n_tokens, n_expert_groups]
986- group_scores = ggml_top_k (ctx0, ggml_cont (ctx0, group_scores), 2 ); // [2, n_expert_groups]
987- group_scores = ggml_get_rows (ctx0, ggml_reshape_3d (ctx0, ggml_cast (ctx0, group_scores, GGML_TYPE_F32), 1 , 2 , hparams.n_expert_groups ), group_scores); // [1, 2, n_expert_groups]
988- group_scores = ggml_reshape_2d (ctx0, group_scores, 2 , hparams.n_expert_groups ); // [2, n_expert_groups]
928+ ggml_tensor * selection_groups = ggml_view_2d (ctx0, ggml_cont (ctx0, ggml_transpose (ctx0, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups , n_tokens * n_exp_per_group * sizeof (float ), 0 ); // [n_tokens, n_expert_groups]
929+ ggml_tensor * group_scores = ggml_top_k (ctx0, selection_groups, 2 ); // [2, n_expert_groups]
989930
990931 // get top n_group_exp expert groups
991- group_scores = ggml_transpose (ctx0, ggml_sum_rows (ctx0, group_scores)); // [n_expert_groups, 1]
932+ group_scores = ggml_transpose (ctx0, ggml_sum_rows (ctx0, ggml_cast (ctx0, group_scores, GGML_TYPE_F32) )); // [n_expert_groups, 1]
992933 ggml_tensor * expert_groups = ggml_top_k (ctx0, ggml_cont (ctx0, group_scores), hparams.n_group_exp ); // [n_group_exp, 1]
993934 cb (expert_groups->src [0 ], " ffn_moe_group_argsort" , il);
994935 cb (expert_groups, " ffn_moe_group_topk" , il);
995936
996937 // mask out the other groups
997- selection_probs = ggml_map_custom2 (ctx0, selection_probs, expert_groups, mask_expert_groups, GGML_N_TASKS_MAX, (void *)(intptr_t )&hparams.n_expert_groups ); // [n_expert, n_tokens]
938+ selection_probs = ggml_scale_bias (ctx0, selection_groups, 0 .0f , -INFINITY);
939+ group_scores = ggml_repeat_4d (ctx0, expert_groups, selection_probs->ne [1 ], 1 , 1 , 1 ); // [n_expert_groups, 1]
940+ selection_probs = ggml_set_rows (ctx0, selection_probs, selection_groups, group_scores); // [n_tokens, n_expert_groups]
941+ selection_probs = ggml_view_2d (ctx0, selection_probs, n_tokens, n_expert, n_tokens * sizeof (float ), 0 ); // [n_tokens, n_expert]
942+ selection_probs = ggml_cont (ctx0, ggml_transpose (ctx0, selection_probs)); // [n_expert, n_tokens]
998943 cb (selection_probs, " ffn_moe_probs_masked" , il);
999944 }
1000945
0 commit comments