Skip to content

Commit 4f324a5

Browse files
authored
ggml : extend ggml_can_fuse to work with non-sequential nodes (ggml-org#16123)
* ggml : extend ggml_can_fuse to work with non-sequential nodes in the graph
* cont : fix wrong bounds check condition
* cont : remove unnecessary overload
1 parent a71ae3b commit 4f324a5

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

ggml/src/ggml-impl.h

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -583,27 +583,27 @@ static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int n
583583
return true;
584584
}
585585

586-
// Returns true if nodes [i, i+ops.size()) are the sequence of ggml_ops in ops[]
586+
// Returns true if nodes with indices { node_idxs } are the sequence of ggml_ops in ops[]
587587
// and are fusable. Nodes are considered fusable according to this function if:
588588
// - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_n_uses).
589589
// - all nodes except the last are a src of the following node.
590590
// - all nodes are the same shape.
591591
// TODO: Consider allowing GGML_OP_NONE nodes in between
592-
static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) {
593-
if (node_idx + num_ops > cgraph->n_nodes) {
594-
return false;
595-
}
596-
592+
static inline bool ggml_can_fuse_ext(const struct ggml_cgraph * cgraph, const int * node_idxs, const enum ggml_op * ops, int num_ops) {
597593
for (int i = 0; i < num_ops; ++i) {
598-
struct ggml_tensor * node = cgraph->nodes[node_idx + i];
594+
if (node_idxs[i] >= cgraph->n_nodes) {
595+
return false;
596+
}
597+
598+
struct ggml_tensor * node = cgraph->nodes[node_idxs[i]];
599599
if (node->op != ops[i]) {
600600
return false;
601601
}
602-
if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idx + i, 1)) {
602+
if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) {
603603
return false;
604604
}
605605
if (i > 0) {
606-
struct ggml_tensor * prev = cgraph->nodes[node_idx + i - 1];
606+
struct ggml_tensor * prev = cgraph->nodes[node_idxs[i - 1]];
607607
if (node->src[0] != prev && node->src[1] != prev) {
608608
return false;
609609
}
@@ -615,6 +615,22 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
615615
return true;
616616
}
617617

618+
// same as above, for sequential indices starting at node_idx: checks nodes [node_idx, node_idx + num_ops) against ops[] by delegating to ggml_can_fuse_ext
619+
static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) {
620+
assert(num_ops < 32); // idxs[] below has 32 slots. NOTE(review): if this is the standard <assert.h> macro it is compiled out under NDEBUG - confirm callers never pass num_ops >= 32
621+
622+
if (node_idx + num_ops > cgraph->n_nodes) { // the whole window must lie inside the graph's node array
623+
return false;
624+
}
625+
626+
int idxs[32]; // scratch list of consecutive node indices handed to ggml_can_fuse_ext
627+
for (int i = 0; i < num_ops; ++i) {
628+
idxs[i] = node_idx + i; // i-th op is matched against the i-th node after node_idx
629+
}
630+
631+
return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops); // op/use-count/src checks are all done by the index-list variant
632+
}
633+
618634
#ifdef __cplusplus
619635
}
620636
#endif

0 commit comments

Comments
 (0)