@@ -570,27 +570,27 @@ static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int n
570570 return true ;
571571}
572572
573- // Returns true if nodes [i, i+ops.size()) are the sequence of ggml_ops in ops[]
573+ // Returns true if nodes with indices { node_idxs } are the sequence of ggml_ops in ops[]
574574// and are fusable. Nodes are considered fusable according to this function if:
575575// - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses).
576576// - all nodes except the last are a src of the following node.
577577// - all nodes are the same shape.
578578// TODO: Consider allowing GGML_OP_NONE nodes in between
579- static inline bool ggml_can_fuse (const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) {
580- if (node_idx + num_ops > cgraph->n_nodes ) {
581- return false ;
582- }
583-
579+ static inline bool ggml_can_fuse_ext (const struct ggml_cgraph * cgraph, const int * node_idxs, const enum ggml_op * ops, int num_ops) {
584580 for (int i = 0 ; i < num_ops; ++i) {
585- struct ggml_tensor * node = cgraph->nodes [node_idx + i];
581+ if (node_idxs[i] + num_ops > cgraph->n_nodes ) {
582+ return false ;
583+ }
584+
585+ struct ggml_tensor * node = cgraph->nodes [node_idxs[i]];
586586 if (node->op != ops[i]) {
587587 return false ;
588588 }
589- if (i < num_ops - 1 && !ggml_node_has_n_uses (cgraph, node_idx + i , 1 )) {
589+ if (i < num_ops - 1 && !ggml_node_has_n_uses (cgraph, node_idxs[i] , 1 )) {
590590 return false ;
591591 }
592592 if (i > 0 ) {
593- struct ggml_tensor * prev = cgraph->nodes [node_idx + i - 1 ];
593+ struct ggml_tensor * prev = cgraph->nodes [node_idxs[ i - 1 ] ];
594594 if (node->src [0 ] != prev && node->src [1 ] != prev) {
595595 return false ;
596596 }
@@ -602,6 +602,18 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
602602 return true ;
603603}
604604
605+ // same as above, for sequential indices starting at node_idx
606+ static inline bool ggml_can_fuse (const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) {
607+ assert (num_ops < 32 );
608+
609+ int idxs[32 ];
610+ for (int i = 0 ; i < num_ops; ++i) {
611+ idxs[i] = node_idx + i;
612+ }
613+
614+ return ggml_can_fuse_ext (cgraph, idxs, ops, num_ops);
615+ }
616+
605617#ifdef __cplusplus
606618}
607619#endif
@@ -615,6 +627,11 @@ inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::
615627 return ggml_can_fuse (cgraph, node_idx, ops.begin (), (int )ops.size ());
616628}
617629
630+ inline bool ggml_can_fuse (const struct ggml_cgraph * cgraph, std::initializer_list<int > node_idx, std::initializer_list<enum ggml_op> ops) {
631+ assert (node_idx.size () == ops.size ());
632+ return ggml_can_fuse_ext (cgraph, node_idx.begin (), ops.begin (), (int )ops.size ());
633+ }
634+
618635// expose GGUF internals for test code
619636GGML_API size_t gguf_type_size (enum gguf_type type);
620637GGML_API struct gguf_context * gguf_init_from_file_impl (FILE * file, struct gguf_init_params params);
0 commit comments