@@ -2821,15 +2821,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
28212821 std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops (false );
28222822 std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops (true );
28232823
2824- if (ops.size () == topk_moe_ops_with_norm.size () && std::equal (ops.begin (), ops.end (), topk_moe_ops_with_norm.begin ())) {
2825-
2826- if (node_idx + topk_moe_ops_with_norm.size () > (size_t )cgraph->n_nodes ) {
2827- return false ;
2828- }
2829-
2830- for (size_t i = 0 ; i < topk_moe_ops_with_norm.size (); i++) {
2831- if (cgraph->nodes [node_idx + i]->op != topk_moe_ops_with_norm.begin ()[i]) return false ;
2832- }
2824+ if (ops.size () == topk_moe_ops_with_norm.size () &&
2825+ ggml_can_fuse_subgraph (cgraph, node_idx, topk_moe_ops_with_norm, {node_idx}, {node_idx + 3 , node_idx + 8 })
2826+ ) {
28332827 ggml_tensor * softmax = cgraph->nodes [node_idx];
28342828 ggml_tensor * weights = cgraph->nodes [node_idx+8 ];
28352829
@@ -2838,15 +2832,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
28382832 }
28392833 }
28402834
2841- if (ops.size () == topk_moe_ops.size () && std::equal (ops.begin (), ops.end (), topk_moe_ops.begin ())) {
2842-
2843- if (node_idx + topk_moe_ops.size () > (size_t )cgraph->n_nodes ) {
2844- return false ;
2845- }
2846-
2847- for (size_t i = 0 ; i < topk_moe_ops.size (); i++) {
2848- if (cgraph->nodes [node_idx + i]->op != topk_moe_ops.begin ()[i]) return false ;
2849- }
2835+ if (ops.size () == topk_moe_ops.size () && ggml_can_fuse_subgraph (cgraph, node_idx, topk_moe_ops, {node_idx}, {node_idx+3 , node_idx+4 })) {
28502836
28512837 ggml_tensor * softmax = cgraph->nodes [node_idx];
28522838 ggml_tensor * weights = cgraph->nodes [node_idx+4 ];
0 commit comments