Skip to content

Commit b8a3661

Browse files
committed
ggml: add ggml_can_fuse_subgraph
1 parent cec5edb commit b8a3661

File tree

2 files changed

+133
-0
lines changed

2 files changed

+133
-0
lines changed

ggml/src/ggml-impl.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,42 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
647647
return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
648648
}
649649

650+
// Checks whether the subgraph made of the `count` graph nodes listed in
// node_idxs can be fused (implemented in ggml.c).
//
//   cgraph                graph that owns the nodes
//   node_idxs / count     indices into cgraph->nodes that make up the subgraph
//   ops                   ops[i] is the op expected at node_idxs[i]
//   inputs / num_inputs   graph node indices of the subgraph's input nodes
//   outputs / num_outputs graph node indices of the subgraph's output nodes
//
// Returns false if an index is past the end of the graph, an op does not
// match, a subgraph node carries GGML_TENSOR_FLAG_OUTPUT, or a node that is
// neither a listed input nor a listed output has a use count that differs from
// the number of times subgraph nodes reference it as a source.
// The implementation asserts count < 32.
GGML_API bool ggml_can_fuse_subgraph_ext(
651+
const struct ggml_cgraph * cgraph,
652+
const int * node_idxs,
653+
int count,
654+
const enum ggml_op * ops,
655+
const int * inputs,
656+
int num_inputs,
657+
const int * outputs,
658+
int num_outputs);
659+
660+
// Returns true if the subgraph formed by {node_idxs} can be fused
661+
// checks whethers all nodes which are not part of inputs/outputs can be elided
662+
// by checking if their num_uses are confined to the subgraph
663+
static inline bool ggml_can_fuse_subgraph(
664+
const struct ggml_cgraph * cgraph,
665+
int node_idx,
666+
int count,
667+
const enum ggml_op * ops,
668+
const int * inputs,
669+
int num_inputs,
670+
const int * outputs,
671+
int num_outputs) {
672+
673+
if (node_idx + count > cgraph->n_nodes) {
674+
return false;
675+
}
676+
677+
int idxs[32];
678+
679+
for (int i = 0; i < count; ++i) {
680+
idxs[i] = node_idx + i;
681+
}
682+
683+
return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, inputs, num_inputs, outputs, num_outputs);
684+
}
685+
650686
#ifdef __cplusplus
651687
}
652688
#endif
@@ -660,6 +696,23 @@ inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::
660696
return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
661697
}
662698

699+
inline bool ggml_can_fuse_subgraph(
700+
const struct ggml_cgraph * cgraph,
701+
int start_idx,
702+
std::initializer_list<enum ggml_op> ops,
703+
std::initializer_list<int> inputs = {},
704+
std::initializer_list<int> outputs = {}) {
705+
return ggml_can_fuse_subgraph(
706+
cgraph,
707+
start_idx,
708+
ops.size(),
709+
ops.begin(),
710+
inputs.begin(),
711+
inputs.size(),
712+
outputs.begin(),
713+
outputs.size());
714+
}
715+
663716
// expose GGUF internals for test code
664717
GGML_API size_t gguf_type_size(enum gguf_type type);
665718
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);

ggml/src/ggml.c

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6964,6 +6964,86 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
69646964
GGML_LOG_INFO("========================================\n");
69656965
}
69666966

6967+
static int ggml_find_tensor_node_list(const struct ggml_cgraph * cgraph, const int * idxs, int count, const struct ggml_tensor * tensor) {
6968+
if (idxs == NULL || cgraph == NULL) {
6969+
return -1;
6970+
}
6971+
6972+
for(int i = 0; i < count; ++i) {
6973+
const int node_idx = idxs[count];
6974+
6975+
if (node_idx >= cgraph->n_nodes) {
6976+
return -1;
6977+
}
6978+
if (cgraph->nodes[node_idx] == tensor) {
6979+
return i;
6980+
}
6981+
}
6982+
return -1;
6983+
}
6984+
6985+
bool ggml_can_fuse_subgraph_ext(
6986+
const struct ggml_cgraph * cgraph,
6987+
const int * node_idxs,
6988+
int count,
6989+
const enum ggml_op * ops,
6990+
const int * inputs,
6991+
int num_inputs,
6992+
const int * outputs,
6993+
int num_outputs) {
6994+
6995+
GGML_ASSERT(count < 32 && num_inputs > 0 && num_outputs > 0);
6996+
int interior_nodes_count = 0;
6997+
int interior_nodes[32];
6998+
6999+
for(int i = 0 ; i < count; ++i) {
7000+
if (node_idxs[i] >= cgraph->n_nodes || cgraph->nodes[node_idxs[i]]->op != ops[i]) {
7001+
return false;
7002+
}
7003+
7004+
const struct ggml_tensor * node = cgraph->nodes[node_idxs[i]];
7005+
7006+
if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
7007+
return false;
7008+
}
7009+
7010+
if (ggml_find_tensor_node_list(cgraph, inputs, num_inputs, node) != -1) {
7011+
continue;
7012+
}
7013+
7014+
if (ggml_find_tensor_node_list(cgraph, outputs, num_outputs, node) != -1) {
7015+
continue;
7016+
}
7017+
7018+
interior_nodes[interior_nodes_count++] = node_idxs[i];
7019+
}
7020+
7021+
// if interior-node has n-uses, ensure that all of them lie within in this subgraph
7022+
for(int i = 0 ; i < interior_nodes_count; ++i) {
7023+
7024+
const int num_uses = ggml_node_get_use_count(cgraph, interior_nodes[i]);
7025+
7026+
const struct ggml_tensor * node = cgraph->nodes[interior_nodes[i]];
7027+
7028+
int subgraph_uses = 0;
7029+
//check if all uses are within the graph
7030+
for(int j = 0; j < count; ++j) {
7031+
const struct ggml_tensor * other_node = cgraph->nodes[node_idxs[j]];
7032+
for(int src_idx = 0 ; src_idx < GGML_MAX_SRC; src_idx++) {
7033+
if (other_node->src[src_idx] && other_node->src[src_idx] == node) {
7034+
subgraph_uses++;
7035+
}
7036+
}
7037+
}
7038+
7039+
if (subgraph_uses != num_uses) {
7040+
return false;
7041+
}
7042+
}
7043+
7044+
return true;
7045+
}
7046+
69677047
// check if node is part of the graph
69687048
static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
69697049
if (cgraph == NULL) {

0 commit comments

Comments
 (0)