Skip to content

Commit 977a333

Browse files
committed
- combine check into one loop
- check all view_src parents
- other minor review comments
1 parent d853036 commit 977a333

File tree

1 file changed

+15
-26
lines changed

1 file changed

+15
-26
lines changed

ggml/src/ggml.c

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6965,13 +6965,10 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
69656965
}
69666966

69676967
static int ggml_find_tensor_node_list(const struct ggml_cgraph * cgraph,
6968-
const int * idxs,
6968+
const int * idxs,
69696969
int count,
69706970
const struct ggml_tensor * tensor) {
6971-
if (idxs == NULL || cgraph == NULL) {
6972-
return -1;
6973-
}
6974-
6971+
GGML_ASSERT(cgraph && idxs);
69756972
for (int i = 0; i < count; ++i) {
69766973
const int node_idx = idxs[i];
69776974

@@ -6992,8 +6989,6 @@ bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
69926989
const int * outputs,
69936990
int num_outputs) {
69946991
GGML_ASSERT(count < 32 && outputs && num_outputs > 0);
6995-
int interior_nodes_count = 0;
6996-
int interior_nodes[32];
69976992

69986993
for (int i = 0; i < count; ++i) {
69996994
if (node_idxs[i] >= cgraph->n_nodes || cgraph->nodes[node_idxs[i]]->op != ops[i]) {
@@ -7002,42 +6997,36 @@ bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
70026997

70036998
const struct ggml_tensor * node = cgraph->nodes[node_idxs[i]];
70046999

7005-
if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
7006-
return false;
7007-
}
7008-
70097000
if (ggml_find_tensor_node_list(cgraph, outputs, num_outputs, node) != -1) {
70107001
continue;
70117002
}
70127003

7013-
interior_nodes[interior_nodes_count++] = node_idxs[i];
7014-
}
7015-
7016-
for (int i = 0; i < interior_nodes_count; ++i) {
7017-
const int num_uses = ggml_node_get_use_count(cgraph, interior_nodes[i]);
7018-
7019-
const struct ggml_tensor * node = cgraph->nodes[interior_nodes[i]];
7004+
if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
7005+
return false;
7006+
}
70207007

7021-
// if interior-node has n-uses, ensure that all of them lie within in this subgraph
70227008
int subgraph_uses = 0;
7023-
for (int j = 0; j < count; ++j) {
7009+
for (int j = i + 1; j < count; ++j) {
70247010
const struct ggml_tensor * other_node = cgraph->nodes[node_idxs[j]];
70257011
for (int src_idx = 0; src_idx < GGML_MAX_SRC; src_idx++) {
7026-
if (other_node->src[src_idx] && other_node->src[src_idx] == node) {
7012+
if (other_node->src[src_idx] == node) {
70277013
subgraph_uses++;
70287014
}
70297015
}
70307016
}
70317017

7032-
if (subgraph_uses != num_uses) {
7018+
if (subgraph_uses != ggml_node_get_use_count(cgraph, node_idxs[i])) {
70337019
return false;
70347020
}
70357021

7036-
// if node is a view, check if the view src is within the subgraph
7022+
// if node is a view, check if the view_src and all its parent view_srcs are within the subgraph
70377023
if (node->view_src) {
7038-
const struct ggml_tensor * view_src = node->view_src;
7039-
if (ggml_find_tensor_node_list(cgraph, node_idxs, count, view_src) == -1) {
7040-
return false;
7024+
struct ggml_tensor * view_src = node->view_src;
7025+
while (view_src) {
7026+
if (ggml_find_tensor_node_list(cgraph, node_idxs, count, view_src) == -1) {
7027+
return false;
7028+
}
7029+
view_src = view_src->view_src;
70417030
}
70427031
}
70437032
}

0 commit comments

Comments
 (0)