@@ -258,6 +258,39 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
258258 }
259259}
260260
261+ void llm_graph_input_expert_mask::set_input (const llama_ubatch * ubatch) {
262+ if (mask == nullptr || (cparams.omit_experts .empty () && cparams.force_experts .empty ())) {
263+ return ;
264+ }
265+ GGML_UNUSED (ubatch);
266+
267+ const int64_t n_expert = mask->ne [0 ];
268+
269+ GGML_ASSERT (ggml_backend_buffer_is_host (mask->buffer ));
270+ float * data = (float *) mask->data ;
271+
272+ std::fill (data, data + n_expert, 0 .0f );
273+
274+ for (int32_t expert_idx : cparams.omit_experts ) {
275+ if (expert_idx >= 0 && expert_idx < n_expert) {
276+ data[expert_idx] = -INFINITY;
277+ }
278+ }
279+ for (int32_t expert_idx : cparams.force_experts ) {
280+ if (expert_idx >= 0 && expert_idx < n_expert) {
281+ data[expert_idx] = INFINITY;
282+ }
283+ }
284+ }
285+
286+ bool llm_graph_input_expert_mask::can_reuse (const llm_graph_params & params) {
287+ bool res = true ;
288+ res &= mask->ne [0 ] == params.hparams .n_expert ;
289+ res &= cparams.omit_experts == params.cparams .omit_experts ;
290+ res &= cparams.force_experts == params.cparams .force_experts ;
291+ return res;
292+ }
293+
261294void llm_graph_input_attn_no_cache::set_input (const llama_ubatch * ubatch) {
262295 const int64_t n_kv = ubatch->n_tokens ;
263296 const int64_t n_tokens = ubatch->n_tokens ;
@@ -787,6 +820,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
787820 bool scale_w,
788821 float w_scale,
789822 llama_expert_gating_func_type gating_op,
823+ ggml_tensor * expert_mask,
790824 int il,
791825 ggml_tensor * probs_in) const {
792826 return build_moe_ffn (
@@ -803,6 +837,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
803837 scale_w,
804838 w_scale,
805839 gating_op,
840+ expert_mask,
806841 il,
807842 probs_in
808843 );
@@ -826,6 +861,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
826861 bool scale_w,
827862 float w_scale,
828863 llama_expert_gating_func_type gating_op,
864+ ggml_tensor * expert_mask,
829865 int il,
830866 ggml_tensor * probs_in) const {
831867 const int64_t n_embd = cur->ne [0 ];
@@ -879,6 +915,12 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
879915 selection_probs = logits;
880916 }
881917
918+ // Omit or force specified experts by adding a mask of -INF/INF respectively
919+ if (expert_mask != nullptr ) {
920+ selection_probs = ggml_add (ctx0, selection_probs, expert_mask);
921+ cb (selection_probs, " ffn_moe_probs_masked" , il);
922+ }
923+
882924 // select experts
883925 ggml_tensor * selected_experts = ggml_top_k (ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
884926 cb (selected_experts->src [0 ], " ffn_moe_argsort" , il);
@@ -1352,6 +1394,14 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
13521394 return (llm_graph_input_attn_no_cache *) res->add_input (std::move (inp));
13531395}
13541396
1397+ llm_graph_input_expert_mask * llm_graph_context::build_inp_expert_mask () const {
1398+ auto inp = std::make_unique<llm_graph_input_expert_mask>(cparams);
1399+ auto & cur = inp->mask ;
1400+ cur = ggml_new_tensor_1d (ctx0, GGML_TYPE_F32, hparams.n_expert );
1401+ ggml_set_input (cur);
1402+ return (llm_graph_input_expert_mask *) res->add_input (std::move (inp));
1403+ }
1404+
13551405ggml_tensor * llm_graph_context::build_attn (
13561406 llm_graph_input_attn_no_cache * inp,
13571407 ggml_tensor * wo,
0 commit comments