Skip to content

Commit 578d918

Browse files
committed
ggml-cuda: use ggml_can_fuse_subgraph for topk-moe
1 parent b8a3661 commit 578d918

File tree

1 file changed

+4
-18
lines changed

1 file changed

+4
-18
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 4 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -2821,15 +2821,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
2821 2821
std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
2822 2822
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
2823 2823

2824-
if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {
2825-
2826-
if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) {
2827-
return false;
2828-
}
2829-
2830-
for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
2831-
if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
2832-
}
2824+
if (ops.size() == topk_moe_ops_with_norm.size() &&
2825+
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, {node_idx}, {node_idx + 3, node_idx + 8})
2826+
) {
2833 2827
ggml_tensor * softmax = cgraph->nodes[node_idx];
2834 2828
ggml_tensor * weights = cgraph->nodes[node_idx+8];
2835 2829

@@ -2838,15 +2832,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
2838 2832
}
2839 2833
}
2840 2834

2841-
if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) {
2842-
2843-
if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) {
2844-
return false;
2845-
}
2846-
2847-
for (size_t i = 0; i < topk_moe_ops.size(); i++) {
2848-
if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
2849-
}
2835+
if (ops.size() == topk_moe_ops.size() && ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops, {node_idx}, {node_idx+3, node_idx+4})) {
2850 2836

2851 2837
ggml_tensor * softmax = cgraph->nodes[node_idx];
2852 2838
ggml_tensor * weights = cgraph->nodes[node_idx+4];

0 commit comments

Comments
 (0)