-
Notifications
You must be signed in to change notification settings - Fork 13.4k
ggml : extend ggml_can_fuse to work with non-sequential nodes #16123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -570,27 +570,27 @@ static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int n | |
return true; | ||
} | ||
|
||
// Returns true if nodes [i, i+ops.size()) are the sequence of ggml_ops in ops[] | ||
// Returns true if nodes with indices { node_idxs } are the sequence of ggml_ops in ops[] | ||
// and are fusable. Nodes are considered fusable according to this function if: | ||
// - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses). | ||
// - all nodes except the last are a src of the following node. | ||
// - all nodes are the same shape. | ||
// TODO: Consider allowing GGML_OP_NONE nodes in between | ||
static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) { | ||
if (node_idx + num_ops > cgraph->n_nodes) { | ||
return false; | ||
} | ||
|
||
static inline bool ggml_can_fuse_ext(const struct ggml_cgraph * cgraph, const int * node_idxs, const enum ggml_op * ops, int num_ops) { | ||
for (int i = 0; i < num_ops; ++i) { | ||
struct ggml_tensor * node = cgraph->nodes[node_idx + i]; | ||
if (node_idxs[i] + num_ops > cgraph->n_nodes) { | ||
return false; | ||
} | ||
|
||
struct ggml_tensor * node = cgraph->nodes[node_idxs[i]]; | ||
if (node->op != ops[i]) { | ||
return false; | ||
} | ||
if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idx + i, 1)) { | ||
if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) { | ||
return false; | ||
} | ||
if (i > 0) { | ||
struct ggml_tensor * prev = cgraph->nodes[node_idx + i - 1]; | ||
struct ggml_tensor * prev = cgraph->nodes[node_idxs[i - 1]]; | ||
if (node->src[0] != prev && node->src[1] != prev) { | ||
return false; | ||
} | ||
|
@@ -602,6 +602,18 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx | |
return true; | ||
} | ||
|
||
// same as above, for sequential indices starting at node_idx | ||
static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) { | ||
assert(num_ops < 32); | ||
|
||
int idxs[32]; | ||
for (int i = 0; i < num_ops; ++i) { | ||
idxs[i] = node_idx + i; | ||
} | ||
|
||
return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops); | ||
} | ||
|
||
#ifdef __cplusplus | ||
} | ||
#endif | ||
|
@@ -615,6 +627,11 @@ inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std:: | |
return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size()); | ||
} | ||
|
||
inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, std::initializer_list<int> node_idx, std::initializer_list<enum ggml_op> ops) { | ||
|
||
assert(node_idx.size() == ops.size()); | ||
return ggml_can_fuse_ext(cgraph, node_idx.begin(), ops.begin(), (int)ops.size()); | ||
} | ||
|
||
// expose GGUF internals for test code | ||
GGML_API size_t gguf_type_size(enum gguf_type type); | ||
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why add num_ops here? Also, should be >=, i think.