Skip to content

Commit da2fce0

Browse files
committed
C and C++ versions of can_fuse
1 parent e6f3c06 commit da2fce0

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

ggml/src/ggml-impl.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,6 @@
1212
#include <stdint.h>
1313
#include <string.h>
1414

15-
#ifdef __cplusplus
16-
#include <initializer_list>
17-
#endif
18-
1915
#ifdef __ARM_FEATURE_SVE
2016
#include <arm_sve.h>
2117
#endif // __ARM_FEATURE_SVE
@@ -471,7 +467,6 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
471467
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
472468
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
473469

474-
#ifdef __cplusplus
475470
// return true if the node's results are only used by N other nodes
476471
// and can be fused into their calculations.
477472
static inline bool ggml_node_has_N_uses(const struct ggml_tensor * node, int32_t N) {
@@ -500,15 +495,14 @@ static inline bool ggml_node_has_N_uses(const struct ggml_tensor * node, int32_t
500495
// - all nodes except the last are src[0] of the following node.
501496
// - all nodes are the same shape.
502497
// TODO: Consider allowing GGML_OP_NONE nodes in between
503-
static bool ggml_can_fuse(struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
504-
int num_ops = (int)ops.size();
498+
static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op *ops, int num_ops) {
505499
if (node_idx + num_ops > cgraph->n_nodes) {
506500
return false;
507501
}
508502

509503
for (int i = 0; i < num_ops; ++i) {
510504
struct ggml_tensor *node = cgraph->nodes[node_idx + i];
511-
if (node->op != ops.begin()[i]) {
505+
if (node->op != ops[i]) {
512506
return false;
513507
}
514508
if (i < num_ops && !ggml_node_has_N_uses(node, 1)) {
@@ -526,9 +520,17 @@ static bool ggml_can_fuse(struct ggml_cgraph * cgraph, int node_idx, std::initia
526520
}
527521
return true;
528522
}
523+
524+
#ifdef __cplusplus
525+
}
529526
#endif
530527

531528
#ifdef __cplusplus
529+
#include <initializer_list>
530+
531+
// nicer C++ syntax for ggml_can_fuse
532+
inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
533+
return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
532534
}
533535
#endif
534536

0 commit comments

Comments
 (0)