src/llama-graph.cpp: 27 changes (10 additions, 17 deletions)
@@ -933,30 +933,23 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
  const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;

  // organize experts into n_expert_groups
- ggml_tensor * selection_groups = ggml_view_2d(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups, n_tokens * n_exp_per_group * sizeof(float), 0); // [n_tokens * n_exp_per_group, n_expert_groups]
- #if 0
- ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups]
- group_scores = ggml_get_rows(ctx0, ggml_reshape_3d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 2, n_expert_groups]
+ ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]

- // get top n_group_used expert groups
- group_scores = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]))); // [n_expert_groups, 1]
- #else
- // Replace top_k(2) with argmax due to backend limitations, ideally we should use something like argmax2 instead
- ggml_tensor * group_scores = ggml_reshape_2d(ctx0, ggml_argmax(ctx0, selection_groups), 1, selection_groups->ne[1]); // [1, n_expert_groups]
- group_scores = ggml_get_rows(ctx0, ggml_reshape_3d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 1, n_expert_groups]
+ ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
+ group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]

  // get top n_group_used expert groups
- group_scores = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2])); // [n_expert_groups, 1]
- #endif
- ggml_tensor * expert_groups = ggml_top_k(ctx0, ggml_cont(ctx0, group_scores), hparams.n_group_used); // [n_group_used, 1]
+ group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
+ group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]

+ ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
  cb(expert_groups->src[0], "ffn_moe_group_argsort", il);
  cb(expert_groups, "ffn_moe_group_topk", il);

  // mask out the other groups
- selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_tokens * n_exp_per_group, n_group_used]
- selection_probs = ggml_set_rows(ctx0, ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_tokens * n_exp_per_group, n_expert_groups]
- selection_probs = ggml_view_2d(ctx0, selection_probs, n_tokens, n_expert, n_tokens * sizeof(float), 0); // [n_tokens, n_expert]
- selection_probs = ggml_cont(ctx0, ggml_transpose(ctx0, selection_probs)); // [n_expert, n_tokens]
+ selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
+ selection_probs = ggml_set_rows(ctx0, ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
+ selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
  cb(selection_probs, "ffn_moe_probs_masked", il);
  }
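For reference, the routing the new graph builds per token can be read as: score each expert group by the sum of its top-2 expert scores, keep the hparams.n_group_used highest-scoring groups, and mask the experts of all other groups to -inf before the usual top-k expert selection runs. Below is a minimal standalone sketch of that logic on plain arrays, not using ggml; the function name mask_expert_groups and its layout (one row of n_expert scores per token) are illustrative assumptions, not part of llama.cpp.

#include <algorithm>
#include <cstdio>
#include <limits>
#include <utility>
#include <vector>

// probs holds n_tokens rows of n_expert scores (row-major, one row per token).
// For each token: score each group by the sum of its top-2 expert scores, keep
// the n_group_used best groups, and set every expert outside those groups to
// -inf so a later top-k over experts cannot pick them.
static void mask_expert_groups(std::vector<float> & probs,
                               int n_tokens, int n_expert,
                               int n_expert_groups, int n_group_used) {
    const int n_exp_per_group = n_expert / n_expert_groups;

    for (int t = 0; t < n_tokens; ++t) {
        float * row = probs.data() + (size_t) t * n_expert;

        // group score = sum of the top-2 expert scores inside the group
        std::vector<std::pair<float, int>> group_scores(n_expert_groups);
        for (int g = 0; g < n_expert_groups; ++g) {
            float best1 = -std::numeric_limits<float>::infinity();
            float best2 = -std::numeric_limits<float>::infinity();
            for (int e = 0; e < n_exp_per_group; ++e) {
                const float v = row[g * n_exp_per_group + e];
                if (v > best1) { best2 = best1; best1 = v; }
                else if (v > best2) { best2 = v; }
            }
            group_scores[g] = { best1 + best2, g };
        }

        // keep the n_group_used highest-scoring groups
        std::partial_sort(group_scores.begin(),
                          group_scores.begin() + n_group_used,
                          group_scores.end(),
                          [](const auto & a, const auto & b) { return a.first > b.first; });

        std::vector<bool> keep(n_expert_groups, false);
        for (int k = 0; k < n_group_used; ++k) {
            keep[group_scores[k].second] = true;
        }

        // mask out experts belonging to the discarded groups
        for (int g = 0; g < n_expert_groups; ++g) {
            if (keep[g]) continue;
            for (int e = 0; e < n_exp_per_group; ++e) {
                row[g * n_exp_per_group + e] = -std::numeric_limits<float>::infinity();
            }
        }
    }
}

int main() {
    // toy example: 1 token, 8 experts, 4 groups of 2, keep the best 2 groups
    std::vector<float> probs = { 0.1f, 0.9f, 0.2f, 0.3f, 0.8f, 0.7f, 0.0f, 0.1f };
    mask_expert_groups(probs, /*n_tokens=*/1, /*n_expert=*/8,
                       /*n_expert_groups=*/4, /*n_group_used=*/2);
    for (float v : probs) {
        printf("%g ", v); // expected: 0.1 0.9 -inf -inf 0.8 0.7 -inf -inf
    }
    printf("\n");
    return 0;
}

The diff achieves the same effect in graph form: ggml_reshape_3d groups the contiguous n_expert dimension into [n_exp_per_group, n_expert_groups, n_tokens], ggml_top_k plus ggml_get_rows plus ggml_sum_rows produce the per-group top-2 sums, and ggml_set_rows over a tensor pre-filled with -INFINITY performs the masking, now batched over all tokens at once instead of the earlier per-token 2D views.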
