
Commit 6ee701d

almost working
1 parent 2ab5409 commit 6ee701d

File tree

1 file changed: 8 additions & 63 deletions


src/llama-graph.cpp

Lines changed: 8 additions & 63 deletions
@@ -814,63 +814,6 @@ ggml_tensor * llm_graph_context::build_ffn(
     return cur;
 }
 
-static void mask_expert_groups(struct ggml_tensor * dst, const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata) {
-    GGML_ASSERT(ggml_are_same_shape(dst, a));
-    GGML_ASSERT(a->type == dst->type);
-    GGML_ASSERT(a->type == GGML_TYPE_F32);
-    GGML_ASSERT(b->type == GGML_TYPE_I32);
-
-    const int n_expert_groups = static_cast<int>(*((uint32_t *) userdata));
-    const int n_group_exp = (int)ggml_nelements(b);
-
-    const float * a_data = ggml_get_data_f32(a);
-    const int32_t * b_data = (const int32_t *)ggml_get_data(b);
-    float * dst_data = ggml_get_data_f32(dst);
-
-    // parallelize by groups
-    const int nr = (int)ggml_nrows(dst);
-    // number of columns
-    const int nc = (int)dst->ne[0];
-    // number of groups per thread
-    const int dg = (n_expert_groups + nth - 1) / nth;
-    // group range for this thread
-    const int ig0 = dg * ith;
-    const int ig1 = std::min(ig0 + dg, n_expert_groups);
-    // number of columns per group
-    const int ncg = nc / n_expert_groups;
-    // number of bytes per group
-    const size_t nbg = ncg * dst->nb[0];
-
-    // this assumes that the tensors are contiguous
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_is_contiguous(a));
-    GGML_ASSERT(ggml_is_contiguous(b));
-
-    unsigned int group_sel_mask = 0;
-    GGML_ASSERT(sizeof(group_sel_mask) * 8 >= static_cast<size_t>(n_expert_groups));
-
-    for (int i = 0; i < n_group_exp; ++i) {
-        const int32_t group_idx = b_data[i];
-        GGML_ASSERT(group_idx >= 0);
-        GGML_ASSERT(n_expert_groups > group_idx);
-        group_sel_mask |= 1 << group_idx;
-    }
-
-    for (int ig = ig0; ig < ig1; ++ig) {
-        const bool group_sel = ig & group_sel_mask;
-
-        for (int ir = 0; ir < nr; ++ir) {
-            const int i = ir * nc + ig * ncg;
-
-            if (group_sel) {
-                memcpy(dst_data + i, a_data + i, nbg);
-            } else {
-                memset(dst_data + i, (uint32_t) -INFINITY, nbg);
-            }
-        }
-    }
-}
-
 ggml_tensor * llm_graph_context::build_moe_ffn(
     ggml_tensor * cur,
     ggml_tensor * gate_inp,
@@ -982,19 +925,21 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
 
         // organize experts into n_expert_groups
-        ggml_tensor * group_scores = ggml_transpose(ctx0, ggml_view_2d(ctx0, selection_probs, hparams.n_expert_groups, n_tokens * n_exp_per_group, selection_probs->nb[1] * n_exp_per_group, 0)); // [n_tokens, n_expert_groups]
-        group_scores = ggml_top_k(ctx0, ggml_cont(ctx0, group_scores), 2); // [2, n_expert_groups]
-        group_scores = ggml_get_rows(ctx0, ggml_reshape_3d(ctx0, ggml_cast(ctx0, group_scores, GGML_TYPE_F32), 1, 2, hparams.n_expert_groups), group_scores); // [1, 2, n_expert_groups]
-        group_scores = ggml_reshape_2d(ctx0, group_scores, 2, hparams.n_expert_groups); // [2, n_expert_groups]
+        ggml_tensor * selection_groups = ggml_view_2d(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups, n_tokens * n_exp_per_group * sizeof(float), 0); // [n_tokens, n_expert_groups]
+        ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups]
 
         // get top n_group_exp expert groups
-        group_scores = ggml_transpose(ctx0, ggml_sum_rows(ctx0, group_scores)); // [n_expert_groups, 1]
+        group_scores = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cast(ctx0, group_scores, GGML_TYPE_F32))); // [n_expert_groups, 1]
         ggml_tensor * expert_groups = ggml_top_k(ctx0, ggml_cont(ctx0, group_scores), hparams.n_group_exp); // [n_group_exp, 1]
         cb(expert_groups->src[0], "ffn_moe_group_argsort", il);
         cb(expert_groups, "ffn_moe_group_topk", il);
 
         // mask out the other groups
-        selection_probs = ggml_map_custom2(ctx0, selection_probs, expert_groups, mask_expert_groups, GGML_N_TASKS_MAX, (void *)(intptr_t)&hparams.n_expert_groups); // [n_expert, n_tokens]
+        selection_probs = ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY);
+        group_scores = ggml_repeat_4d(ctx0, expert_groups, selection_probs->ne[1], 1, 1, 1); // [n_expert_groups, 1]
+        selection_probs = ggml_set_rows(ctx0, selection_probs, selection_groups, group_scores); // [n_tokens, n_expert_groups]
+        selection_probs = ggml_view_2d(ctx0, selection_probs, n_tokens, n_expert, n_tokens * sizeof(float), 0); // [n_tokens, n_expert]
+        selection_probs = ggml_cont(ctx0, ggml_transpose(ctx0, selection_probs)); // [n_expert, n_tokens]
         cb(selection_probs, "ffn_moe_probs_masked", il);
     }
 
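For reference, below is a minimal sketch (not part of the commit) of the group-limited routing that the new graph ops appear to implement: split the n_expert experts into n_expert_groups contiguous groups, score each group per token by the sum of its top-2 selection probabilities, keep the top n_group_exp groups, and mask every expert outside them to -infinity. The plain C++ assumes a row-major probs[n_expert][n_tokens] buffer; the helper name and parameters are illustrative, not llama.cpp API.

// Reference sketch only: plain C++ mirror of the grouped-routing semantics,
// assuming a row-major probs[n_expert][n_tokens] layout.
#include <algorithm>
#include <functional>
#include <limits>
#include <utility>
#include <vector>

static void mask_expert_groups_ref(std::vector<float> & probs,
                                   int n_expert, int n_tokens,
                                   int n_expert_groups, int n_group_exp) {
    const int n_exp_per_group = n_expert / n_expert_groups;

    for (int t = 0; t < n_tokens; ++t) {
        // score each group by the sum of its two largest selection probs
        // (what ggml_top_k(.., 2) followed by ggml_sum_rows computes per group)
        std::vector<std::pair<float, int>> scores(n_expert_groups);
        for (int g = 0; g < n_expert_groups; ++g) {
            std::vector<float> v(n_exp_per_group);
            for (int e = 0; e < n_exp_per_group; ++e) {
                v[e] = probs[(g * n_exp_per_group + e) * n_tokens + t];
            }
            std::partial_sort(v.begin(), v.begin() + 2, v.end(), std::greater<float>());
            scores[g] = { v[0] + v[1], g };
        }

        // keep the n_group_exp best-scoring groups (ggml_top_k over group_scores)
        std::partial_sort(scores.begin(), scores.begin() + n_group_exp, scores.end(),
                          std::greater<std::pair<float, int>>());
        std::vector<char> keep(n_expert_groups, 0);
        for (int k = 0; k < n_group_exp; ++k) {
            keep[scores[k].second] = 1;
        }

        // mask every expert outside the kept groups to -inf
        for (int g = 0; g < n_expert_groups; ++g) {
            if (keep[g]) {
                continue;
            }
            for (int e = 0; e < n_exp_per_group; ++e) {
                probs[(g * n_exp_per_group + e) * n_tokens + t] = -std::numeric_limits<float>::infinity();
            }
        }
    }
}

The deleted mask_expert_groups computed an equivalent mask inside a ggml_map_custom2 callback; the replacement expresses the masking with regular graph ops (ggml_scale_bias, ggml_repeat_4d, ggml_set_rows), presumably so the step is no longer tied to a CPU-only custom op.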