@@ -6965,13 +6965,10 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
69656965}
69666966
69676967static int ggml_find_tensor_node_list (const struct ggml_cgraph * cgraph ,
6968- const int * idxs ,
6968+ const int * idxs ,
69696969 int count ,
69706970 const struct ggml_tensor * tensor ) {
6971- if (idxs == NULL || cgraph == NULL ) {
6972- return -1 ;
6973- }
6974-
6971+ GGML_ASSERT (cgraph && idxs );
69756972 for (int i = 0 ; i < count ; ++ i ) {
69766973 const int node_idx = idxs [i ];
69776974
@@ -6992,8 +6989,6 @@ bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
69926989 const int * outputs ,
69936990 int num_outputs ) {
69946991 GGML_ASSERT (count < 32 && outputs && num_outputs > 0 );
6995- int interior_nodes_count = 0 ;
6996- int interior_nodes [32 ];
69976992
69986993 for (int i = 0 ; i < count ; ++ i ) {
69996994 if (node_idxs [i ] >= cgraph -> n_nodes || cgraph -> nodes [node_idxs [i ]]-> op != ops [i ]) {
@@ -7002,42 +6997,36 @@ bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
70026997
70036998 const struct ggml_tensor * node = cgraph -> nodes [node_idxs [i ]];
70046999
7005- if (node -> flags & GGML_TENSOR_FLAG_OUTPUT ) {
7006- return false;
7007- }
7008-
70097000 if (ggml_find_tensor_node_list (cgraph , outputs , num_outputs , node ) != -1 ) {
70107001 continue ;
70117002 }
70127003
7013- interior_nodes [interior_nodes_count ++ ] = node_idxs [i ];
7014- }
7015-
7016- for (int i = 0 ; i < interior_nodes_count ; ++ i ) {
7017- const int num_uses = ggml_node_get_use_count (cgraph , interior_nodes [i ]);
7018-
7019- const struct ggml_tensor * node = cgraph -> nodes [interior_nodes [i ]];
7004+ if (node -> flags & GGML_TENSOR_FLAG_OUTPUT ) {
7005+ return false;
7006+ }
70207007
7021- // if interior-node has n-uses, ensure that all of them lie within in this subgraph
70227008 int subgraph_uses = 0 ;
7023- for (int j = 0 ; j < count ; ++ j ) {
7009+ for (int j = i + 1 ; j < count ; ++ j ) {
70247010 const struct ggml_tensor * other_node = cgraph -> nodes [node_idxs [j ]];
70257011 for (int src_idx = 0 ; src_idx < GGML_MAX_SRC ; src_idx ++ ) {
7026- if (other_node -> src [src_idx ] && other_node -> src [ src_idx ] == node ) {
7012+ if (other_node -> src [src_idx ] == node ) {
70277013 subgraph_uses ++ ;
70287014 }
70297015 }
70307016 }
70317017
7032- if (subgraph_uses != num_uses ) {
7018+ if (subgraph_uses != ggml_node_get_use_count ( cgraph , node_idxs [ i ]) ) {
70337019 return false;
70347020 }
70357021
7036- // if node is a view, check if the view src is within the subgraph
7022+ // if node is a view, check if the view_src and all its parent view_srcs are within the subgraph
70377023 if (node -> view_src ) {
7038- const struct ggml_tensor * view_src = node -> view_src ;
7039- if (ggml_find_tensor_node_list (cgraph , node_idxs , count , view_src ) == -1 ) {
7040- return false;
7024+ struct ggml_tensor * view_src = node -> view_src ;
7025+ while (view_src ) {
7026+ if (ggml_find_tensor_node_list (cgraph , node_idxs , count , view_src ) == -1 ) {
7027+ return false;
7028+ }
7029+ view_src = view_src -> view_src ;
70417030 }
70427031 }
70437032 }
0 commit comments