@@ -6964,6 +6964,86 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
69646964 GGML_LOG_INFO ("========================================\n" );
69656965}
69666966
6967+ static int ggml_find_tensor_node_list (const struct ggml_cgraph * cgraph , const int * idxs , int count , const struct ggml_tensor * tensor ) {
6968+ if (idxs == NULL || cgraph == NULL ) {
6969+ return -1 ;
6970+ }
6971+
6972+ for (int i = 0 ; i < count ; ++ i ) {
6973+ const int node_idx = idxs [count ];
6974+
6975+ if (node_idx >= cgraph -> n_nodes ) {
6976+ return -1 ;
6977+ }
6978+ if (cgraph -> nodes [node_idx ] == tensor ) {
6979+ return i ;
6980+ }
6981+ }
6982+ return -1 ;
6983+ }
6984+
6985+ bool ggml_can_fuse_subgraph_ext (
6986+ const struct ggml_cgraph * cgraph ,
6987+ const int * node_idxs ,
6988+ int count ,
6989+ const enum ggml_op * ops ,
6990+ const int * inputs ,
6991+ int num_inputs ,
6992+ const int * outputs ,
6993+ int num_outputs ) {
6994+
6995+ GGML_ASSERT (count < 32 && num_inputs > 0 && num_outputs > 0 );
6996+ int interior_nodes_count = 0 ;
6997+ int interior_nodes [32 ];
6998+
6999+ for (int i = 0 ; i < count ; ++ i ) {
7000+ if (node_idxs [i ] >= cgraph -> n_nodes || cgraph -> nodes [node_idxs [i ]]-> op != ops [i ]) {
7001+ return false;
7002+ }
7003+
7004+ const struct ggml_tensor * node = cgraph -> nodes [node_idxs [i ]];
7005+
7006+ if (node -> flags & GGML_TENSOR_FLAG_OUTPUT ) {
7007+ return false;
7008+ }
7009+
7010+ if (ggml_find_tensor_node_list (cgraph , inputs , num_inputs , node ) != -1 ) {
7011+ continue ;
7012+ }
7013+
7014+ if (ggml_find_tensor_node_list (cgraph , outputs , num_outputs , node ) != -1 ) {
7015+ continue ;
7016+ }
7017+
7018+ interior_nodes [interior_nodes_count ++ ] = node_idxs [i ];
7019+ }
7020+
7021+ // if interior-node has n-uses, ensure that all of them lie within in this subgraph
7022+ for (int i = 0 ; i < interior_nodes_count ; ++ i ) {
7023+
7024+ const int num_uses = ggml_node_get_use_count (cgraph , interior_nodes [i ]);
7025+
7026+ const struct ggml_tensor * node = cgraph -> nodes [interior_nodes [i ]];
7027+
7028+ int subgraph_uses = 0 ;
7029+ //check if all uses are within the graph
7030+ for (int j = 0 ; j < count ; ++ j ) {
7031+ const struct ggml_tensor * other_node = cgraph -> nodes [node_idxs [j ]];
7032+ for (int src_idx = 0 ; src_idx < GGML_MAX_SRC ; src_idx ++ ) {
7033+ if (other_node -> src [src_idx ] && other_node -> src [src_idx ] == node ) {
7034+ subgraph_uses ++ ;
7035+ }
7036+ }
7037+ }
7038+
7039+ if (subgraph_uses != num_uses ) {
7040+ return false;
7041+ }
7042+ }
7043+
7044+ return true;
7045+ }
7046+
69677047// check if node is part of the graph
69687048static bool ggml_graph_find (const struct ggml_cgraph * cgraph , const struct ggml_tensor * node ) {
69697049 if (cgraph == NULL ) {
0 commit comments