@@ -583,27 +583,27 @@ static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int n
583
583
return true ;
584
584
}
585
585
586
- // Returns true if nodes [i, i+ops.size()) are the sequence of ggml_ops in ops[]
586
+ // Returns true if nodes with indices { node_idxs } are the sequence of ggml_ops in ops[]
587
587
// and are fusable. Nodes are considered fusable according to this function if:
588
588
// - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_n_uses).
589
589
// - all nodes except the last are a src of the following node.
590
590
// - all nodes are the same shape.
591
591
// TODO: Consider allowing GGML_OP_NONE nodes in between
592
- static inline bool ggml_can_fuse (const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) {
593
- if (node_idx + num_ops > cgraph->n_nodes ) {
594
- return false ;
595
- }
596
-
592
+ static inline bool ggml_can_fuse_ext (const struct ggml_cgraph * cgraph, const int * node_idxs, const enum ggml_op * ops, int num_ops) {
597
593
for (int i = 0 ; i < num_ops; ++i) {
598
- struct ggml_tensor * node = cgraph->nodes [node_idx + i];
594
+ if (node_idxs[i] >= cgraph->n_nodes ) {
595
+ return false ;
596
+ }
597
+
598
+ struct ggml_tensor * node = cgraph->nodes [node_idxs[i]];
599
599
if (node->op != ops[i]) {
600
600
return false ;
601
601
}
602
- if (i < num_ops - 1 && !ggml_node_has_n_uses (cgraph, node_idx + i , 1 )) {
602
+ if (i < num_ops - 1 && !ggml_node_has_n_uses (cgraph, node_idxs[i] , 1 )) {
603
603
return false ;
604
604
}
605
605
if (i > 0 ) {
606
- struct ggml_tensor * prev = cgraph->nodes [node_idx + i - 1 ];
606
+ struct ggml_tensor * prev = cgraph->nodes [node_idxs[ i - 1 ] ];
607
607
if (node->src [0 ] != prev && node->src [1 ] != prev) {
608
608
return false ;
609
609
}
@@ -615,6 +615,22 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
615
615
return true ;
616
616
}
617
617
618
+ // same as above, for sequential indices starting at node_idx
619
+ static inline bool ggml_can_fuse (const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) {
620
+ assert (num_ops < 32 );
621
+
622
+ if (node_idx + num_ops > cgraph->n_nodes ) {
623
+ return false ;
624
+ }
625
+
626
+ int idxs[32 ];
627
+ for (int i = 0 ; i < num_ops; ++i) {
628
+ idxs[i] = node_idx + i;
629
+ }
630
+
631
+ return ggml_can_fuse_ext (cgraph, idxs, ops, num_ops);
632
+ }
633
+
618
634
#ifdef __cplusplus
619
635
}
620
636
#endif
0 commit comments