@@ -6964,7 +6964,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
69646964 GGML_LOG_INFO ("========================================\n" );
69656965}
69666966
6967- static int ggml_find_tensor_node_list (const struct ggml_cgraph * cgraph ,
6967+ static int ggml_node_list_find_tensor (const struct ggml_cgraph * cgraph ,
69686968 const int * idxs ,
69696969 int count ,
69706970 const struct ggml_tensor * tensor ) {
@@ -6988,16 +6988,21 @@ bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
69886988 const enum ggml_op * ops ,
69896989 const int * outputs ,
69906990 int num_outputs ) {
6991- GGML_ASSERT (count < 32 && outputs && num_outputs > 0 );
6991+ GGML_ASSERT (outputs && num_outputs > 0 );
69926992
69936993 for (int i = 0 ; i < count ; ++ i ) {
6994- if (node_idxs [i ] >= cgraph -> n_nodes || cgraph -> nodes [node_idxs [i ]]-> op != ops [i ]) {
6994+
6995+ if (node_idxs [i ] >= cgraph -> n_nodes ) {
69956996 return false;
69966997 }
69976998
69986999 const struct ggml_tensor * node = cgraph -> nodes [node_idxs [i ]];
69997000
7000- if (ggml_find_tensor_node_list (cgraph , outputs , num_outputs , node ) != -1 ) {
7001+ if (node -> op != ops [i ]) {
7002+ return false;
7003+ }
7004+
7005+ if (ggml_node_list_find_tensor (cgraph , outputs , num_outputs , node ) != -1 ) {
70017006 continue ;
70027007 }
70037008
@@ -7022,7 +7027,7 @@ bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
70227027 // if node is a view, check if the view_src and all it's parent view_srcs are within the subgraph
70237028 struct ggml_tensor * view_src = node -> view_src ;
70247029 while (view_src ) {
7025- if (ggml_find_tensor_node_list (cgraph , node_idxs , count , view_src ) == -1 ) {
7030+ if (ggml_node_list_find_tensor (cgraph , node_idxs , count , view_src ) == -1 ) {
70267031 return false;
70277032 }
70287033 view_src = view_src -> view_src ;
0 commit comments