Skip to content

Commit 52d7866

Browse files
am17an authored and pwilkin committed
ggml: add ggml_can_fuse_subgraph (ggml-org#16662)
* ggml: add ggml_can_fuse_subgraph
* ggml-cuda: use ggml_can_fuse_subgraph for topk-moe
* format
* 1. remove inputs from signature as they are transient nodes 2. add check for views: view_src should be part of the subgraph
* - combine check into one loop - check all view_src parents - other minor review comments
* remove redundant if test
* - rename and other minor review comments
* add assert about count < 32
1 parent 7ad39a3 commit 52d7866

File tree

3 files changed

+113
-19
lines changed

3 files changed

+113
-19
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2827,15 +2827,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
28272827
std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
28282828
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
28292829

2830-
if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {
2831-
2832-
if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) {
2833-
return false;
2834-
}
2835-
2836-
for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
2837-
if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
2838-
}
2830+
if (ops.size() == topk_moe_ops_with_norm.size() &&
2831+
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx + 3, node_idx + 8 })) {
28392832
ggml_tensor * softmax = cgraph->nodes[node_idx];
28402833
ggml_tensor * weights = cgraph->nodes[node_idx+8];
28412834

@@ -2844,16 +2837,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
28442837
}
28452838
}
28462839

2847-
if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) {
2848-
2849-
if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) {
2850-
return false;
2851-
}
2852-
2853-
for (size_t i = 0; i < topk_moe_ops.size(); i++) {
2854-
if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
2855-
}
2856-
2840+
if (ops.size() == topk_moe_ops.size() &&
2841+
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops, { node_idx + 3, node_idx + 4 })) {
28572842
ggml_tensor * softmax = cgraph->nodes[node_idx];
28582843
ggml_tensor * weights = cgraph->nodes[node_idx+4];
28592844
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {

ggml/src/ggml-impl.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,36 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
647647
return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
648648
}
649649

650+
GGML_API bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
651+
const int * node_idxs,
652+
int count,
653+
const enum ggml_op * ops,
654+
const int * outputs,
655+
int num_outputs);
656+
657+
// Returns true if the subgraph formed by {node_idxs} can be fused
658+
// checks whethers all nodes which are not part of outputs can be elided
659+
// by checking if their num_uses are confined to the subgraph
660+
static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
661+
int node_idx,
662+
int count,
663+
const enum ggml_op * ops,
664+
const int * outputs,
665+
int num_outputs) {
666+
GGML_ASSERT(count < 32);
667+
if (node_idx + count > cgraph->n_nodes) {
668+
return false;
669+
}
670+
671+
int idxs[32];
672+
673+
for (int i = 0; i < count; ++i) {
674+
idxs[i] = node_idx + i;
675+
}
676+
677+
return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
678+
}
679+
650680
#ifdef __cplusplus
651681
}
652682
#endif
@@ -660,6 +690,13 @@ inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::
660690
return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
661691
}
662692

693+
inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
694+
int start_idx,
695+
std::initializer_list<enum ggml_op> ops,
696+
std::initializer_list<int> outputs = {}) {
697+
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
698+
}
699+
663700
// expose GGUF internals for test code
664701
GGML_API size_t gguf_type_size(enum gguf_type type);
665702
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);

ggml/src/ggml.c

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7063,6 +7063,78 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
70637063
GGML_LOG_INFO("========================================\n");
70647064
}
70657065

7066+
static int ggml_node_list_find_tensor(const struct ggml_cgraph * cgraph,
7067+
const int * idxs,
7068+
int count,
7069+
const struct ggml_tensor * tensor) {
7070+
GGML_ASSERT(cgraph && idxs);
7071+
for (int i = 0; i < count; ++i) {
7072+
const int node_idx = idxs[i];
7073+
7074+
if (node_idx >= cgraph->n_nodes) {
7075+
return -1;
7076+
}
7077+
if (cgraph->nodes[node_idx] == tensor) {
7078+
return i;
7079+
}
7080+
}
7081+
return -1;
7082+
}
7083+
7084+
bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
7085+
const int * node_idxs,
7086+
int count,
7087+
const enum ggml_op * ops,
7088+
const int * outputs,
7089+
int num_outputs) {
7090+
GGML_ASSERT(outputs && num_outputs > 0);
7091+
7092+
for (int i = 0; i < count; ++i) {
7093+
if (node_idxs[i] >= cgraph->n_nodes) {
7094+
return false;
7095+
}
7096+
7097+
const struct ggml_tensor * node = cgraph->nodes[node_idxs[i]];
7098+
7099+
if (node->op != ops[i]) {
7100+
return false;
7101+
}
7102+
7103+
if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) {
7104+
continue;
7105+
}
7106+
7107+
if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
7108+
return false;
7109+
}
7110+
7111+
int subgraph_uses = 0;
7112+
for (int j = i + 1; j < count; ++j) {
7113+
const struct ggml_tensor * other_node = cgraph->nodes[node_idxs[j]];
7114+
for (int src_idx = 0; src_idx < GGML_MAX_SRC; src_idx++) {
7115+
if (other_node->src[src_idx] == node) {
7116+
subgraph_uses++;
7117+
}
7118+
}
7119+
}
7120+
7121+
if (subgraph_uses != ggml_node_get_use_count(cgraph, node_idxs[i])) {
7122+
return false;
7123+
}
7124+
7125+
// if node is a view, check if the view_src and all it's parent view_srcs are within the subgraph
7126+
struct ggml_tensor * view_src = node->view_src;
7127+
while (view_src) {
7128+
if (ggml_node_list_find_tensor(cgraph, node_idxs, count, view_src) == -1) {
7129+
return false;
7130+
}
7131+
view_src = view_src->view_src;
7132+
}
7133+
}
7134+
7135+
return true;
7136+
}
7137+
70667138
// check if node is part of the graph
70677139
static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
70687140
if (cgraph == NULL) {

0 commit comments

Comments
 (0)