@@ -647,6 +647,36 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
647647 return ggml_can_fuse_ext (cgraph, idxs, ops, num_ops);
648648}
649649
650+ GGML_API bool ggml_can_fuse_subgraph_ext (const struct ggml_cgraph * cgraph,
651+ const int * node_idxs,
652+ int count,
653+ const enum ggml_op * ops,
654+ const int * outputs,
655+ int num_outputs);
656+
657+ // Returns true if the subgraph formed by {node_idxs} can be fused
658+ // checks whethers all nodes which are not part of outputs can be elided
659+ // by checking if their num_uses are confined to the subgraph
660+ static inline bool ggml_can_fuse_subgraph (const struct ggml_cgraph * cgraph,
661+ int node_idx,
662+ int count,
663+ const enum ggml_op * ops,
664+ const int * outputs,
665+ int num_outputs) {
666+ GGML_ASSERT (count < 32 );
667+ if (node_idx + count > cgraph->n_nodes ) {
668+ return false ;
669+ }
670+
671+ int idxs[32 ];
672+
673+ for (int i = 0 ; i < count; ++i) {
674+ idxs[i] = node_idx + i;
675+ }
676+
677+ return ggml_can_fuse_subgraph_ext (cgraph, idxs, count, ops, outputs, num_outputs);
678+ }
679+
650680#ifdef __cplusplus
651681}
652682#endif
@@ -660,6 +690,13 @@ inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::
660690 return ggml_can_fuse (cgraph, node_idx, ops.begin (), (int )ops.size ());
661691}
662692
693+ inline bool ggml_can_fuse_subgraph (const struct ggml_cgraph * cgraph,
694+ int start_idx,
695+ std::initializer_list<enum ggml_op> ops,
696+ std::initializer_list<int > outputs = {}) {
697+ return ggml_can_fuse_subgraph (cgraph, start_idx, ops.size (), ops.begin (), outputs.begin (), outputs.size ());
698+ }
699+
663700// expose GGUF internals for test code
664701GGML_API size_t gguf_type_size (enum gguf_type type);
665702GGML_API struct gguf_context * gguf_init_from_file_impl (FILE * file, struct gguf_init_params params);
0 commit comments