diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m index 02147a0ea4a52..052efb7ace50d 100644 --- a/ggml/src/ggml-metal/ggml-metal-context.m +++ b/ggml/src/ggml-metal/ggml-metal-context.m @@ -567,13 +567,13 @@ void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) { ctx->debug_graph, ctx->debug_fusion); - for (int idx = idx_start; idx < idx_end;) { + for (int idx = 0; idx < ggml_metal_op_n_nodes(ctx_op); ++idx) { const int res = ggml_metal_op_encode(ctx_op, idx); if (res == 0) { break; } - idx += res; + idx += res - 1; } ggml_metal_op_free(ctx_op); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 7b11f36adccdd..458559092e0f6 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -24,22 +24,88 @@ static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) { } struct ggml_metal_op { + ggml_metal_op( + ggml_metal_device_t dev, + ggml_metal_cmd_buf_t cmd_buf, + ggml_cgraph * gf, + int idx_start, + int idx_end, + bool use_fusion, + bool use_concurrency, + bool use_capture, + int debug_graph, + int debug_fusion) { + this->dev = dev; + this->lib = ggml_metal_device_get_library(dev); + this->enc = ggml_metal_encoder_init(cmd_buf, use_concurrency); + this->mem_ranges = ggml_mem_ranges_init(debug_graph); + this->idx_start = idx_start; + this->idx_end = idx_end; + this->use_fusion = use_fusion; + this->use_concurrency = use_concurrency; + this->use_capture = use_capture; + this->debug_graph = debug_graph; + this->debug_fusion = debug_fusion; + this->gf = gf; + + idxs.reserve(gf->n_nodes); + + // filter empty nodes + // TODO: this can be removed when the allocator starts filtering them earlier + // https://github.com/ggml-org/llama.cpp/pull/16130#issuecomment-3327905830 + for (int i = idx_start; i < idx_end; i++) { + if (!ggml_op_is_empty(gf->nodes[i]->op) && !ggml_is_empty(gf->nodes[i])) { + idxs.push_back(i); + } + } + } + + ~ggml_metal_op() { + ggml_metal_encoder_end_encoding(this->enc); + ggml_metal_encoder_free(this->enc); + ggml_mem_ranges_free(this->mem_ranges); + } + + int n_nodes() const { + return idxs.size(); + } + + ggml_tensor * node(int i) const { + assert(i >= 0 && i < (int) idxs.size()); + return ggml_graph_node(gf, idxs[i]); + } + + bool can_fuse(int i0, const ggml_op * ops, int n_ops) const { + assert(use_fusion); + assert(i0 >= 0 && i0 < n_nodes()); + + if (i0 + n_ops > n_nodes()) { + return false; + } + + return ggml_can_fuse_ext(gf, idxs.data() + i0, ops, n_ops); + } + ggml_metal_device_t dev; ggml_metal_library_t lib; ggml_metal_encoder_t enc; ggml_mem_ranges_t mem_ranges; - ggml_cgraph * gf; - - int idx_start; - int idx_end; - bool use_fusion; bool use_concurrency; bool use_capture; int debug_graph; int debug_fusion; + +private: + ggml_cgraph * gf; + + int idx_start; + int idx_end; + + // non-empty node indices + std::vector idxs; }; ggml_metal_op_t ggml_metal_op_init( @@ -53,34 +119,29 @@ ggml_metal_op_t ggml_metal_op_init( bool use_capture, int debug_graph, int debug_fusion) { - ggml_metal_op_t res = new ggml_metal_op(); - - *res = { - /*.dev =*/ dev, - /*.lib =*/ ggml_metal_device_get_library(dev), - /*.enc =*/ ggml_metal_encoder_init(cmd_buf, use_concurrency), - /*.mem_ranges =*/ ggml_mem_ranges_init(debug_graph), - /*.gf =*/ gf, - /*.idx_start =*/ idx_start, - /*.idx_end =*/ idx_end, - /*.use_fusion =*/ use_fusion, - /*.use_concurrency =*/ use_concurrency, - /*.use_capture =*/ use_capture, - /*.debug_graph =*/ debug_graph, - /*.debug_fusion =*/ debug_fusion, - }; + ggml_metal_op_t res = new ggml_metal_op( + dev, + cmd_buf, + gf, + idx_start, + idx_end, + use_fusion, + use_concurrency, + use_capture, + debug_graph, + debug_fusion); return res; } void ggml_metal_op_free(ggml_metal_op_t ctx) { - ggml_metal_encoder_end_encoding(ctx->enc); - ggml_metal_encoder_free(ctx->enc); - ggml_mem_ranges_free(ctx->mem_ranges); - delete ctx; } +int ggml_metal_op_n_nodes(ggml_metal_op_t ctx) { + return ctx->n_nodes(); +} + static bool ggml_metal_op_concurrency_reset(ggml_metal_op_t ctx) { if (!ctx->mem_ranges) { return true; @@ -110,10 +171,7 @@ static bool ggml_metal_op_concurrency_add(ggml_metal_op_t ctx, const ggml_tensor } static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { - struct ggml_cgraph * gf = ctx->gf; - - struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx; - struct ggml_tensor * node = nodes[0]; + struct ggml_tensor * node = ctx->node(idx); //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op)); @@ -129,6 +187,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { case GGML_OP_PERMUTE: { // noop -> next node + if (ctx->debug_graph > 0) { + GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), "(noop)"); + } } return 1; default: { @@ -352,7 +413,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { // update the mem ranges in the encoding context for (int i = 0; i < n_fuse; ++i) { - if (!ggml_metal_op_concurrency_add(ctx, nodes[i])) { + if (!ggml_metal_op_concurrency_add(ctx, ctx->node(idx + i))) { ggml_metal_op_concurrency_reset(ctx); } } @@ -362,11 +423,11 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) { if (ctx->use_capture) { - ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ggml_graph_node(ctx->gf, idx))); + ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ctx->node(idx))); } int res = ggml_metal_op_encode_impl(ctx, idx); - if (idx + res > ctx->idx_end) { + if (idx + res > ctx->n_nodes()) { GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s", "https://github.com/ggml-org/llama.cpp/pull/14849"); } @@ -379,8 +440,7 @@ int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -438,8 +498,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -483,8 +542,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -594,8 +652,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -634,8 +691,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -674,8 +730,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -703,8 +758,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -774,8 +828,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -838,8 +891,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -876,8 +928,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -939,8 +990,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -1030,8 +1080,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -1076,8 +1125,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -1170,8 +1218,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -1212,8 +1259,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -1286,8 +1332,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -1347,8 +1392,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -1589,8 +1633,7 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const ggml_tensor * op) { } int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -1783,8 +1826,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -1856,8 +1898,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) { } int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2176,16 +2217,11 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); - - ggml_tensor ** ops = ggml_graph_nodes(gf) + idx; + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; - const int idx_end = ctx->idx_end; - const bool use_fusion = ctx->use_fusion; const int debug_fusion = ctx->debug_fusion; @@ -2258,22 +2294,25 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops // across splits. idx_end indicates the last node in the current split - for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) { - if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) { + for (n_fuse = 0; n_fuse <= 6; ++n_fuse) { + if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) { break; } - if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) { + ggml_tensor * f0 = ctx->node(idx + n_fuse); + ggml_tensor * f1 = ctx->node(idx + n_fuse + 1); + + if (f0 != f1->src[0]) { break; } // b[0] === b[1] === ... - if (!ggml_are_same_layout(ops[n_fuse]->src[1], ops[n_fuse + 1]->src[1])) { + if (!ggml_are_same_layout(f0->src[1], f1->src[1])) { break; } // only fuse ops if src1 is in the same Metal buffer - ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]); + ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(f1->src[1]); if (bid_fuse.metal != bid_src1.metal) { break; } @@ -2309,10 +2348,10 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { } if (n_fuse > 1) { - bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]); + bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1)); for (int i = 1; i < n_fuse; ++i) { - if (!ggml_metal_op_concurrency_check(ctx, ops[i])) { + if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) { ggml_metal_op_concurrency_reset(ctx); break; @@ -2344,8 +2383,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2393,8 +2431,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2445,20 +2482,15 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; - const int idx_end = ctx->idx_end; - const bool use_fusion = ctx->use_fusion; const int debug_fusion = ctx->debug_fusion; - ggml_tensor ** ops = ggml_graph_nodes(gf) + idx; - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); GGML_TENSOR_LOCALS( int32_t, ne, op, ne); @@ -2499,38 +2531,41 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { fops[1] = GGML_OP_MUL; fops[2] = GGML_OP_ADD; - for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) { - if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) { + for (n_fuse = 0; n_fuse <= 1; ++n_fuse) { + if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) { break; } - if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) { + ggml_tensor * f0 = ctx->node(idx + n_fuse); + ggml_tensor * f1 = ctx->node(idx + n_fuse + 1); + + if (f0 != f1->src[0]) { break; } - if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) { + if (f1->src[1]->ne[0] != op->ne[0]) { break; } - if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) { + if (!ggml_is_contiguous_rows(f1->src[1])) { break; } - if (ops[n_fuse + 1]->type != GGML_TYPE_F32) { + if (f1->type != GGML_TYPE_F32) { break; } - //ctx->fuse_cnt[ops[n_fuse + 1]->op]++; + //ctx->fuse_cnt[f1->op]++; - bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]); + bid_fuse[n_fuse] = ggml_metal_get_buffer_id(f1->src[1]); - args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1]; - args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2]; - args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3]; + args.nef1[n_fuse + 1] = f1->src[1]->ne[1]; + args.nef2[n_fuse + 1] = f1->src[1]->ne[2]; + args.nef3[n_fuse + 1] = f1->src[1]->ne[3]; - args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1]; - args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2]; - args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3]; + args.nbf1[n_fuse + 1] = f1->src[1]->nb[1]; + args.nbf2[n_fuse + 1] = f1->src[1]->nb[2]; + args.nbf3[n_fuse + 1] = f1->src[1]->nb[3]; } ++n_fuse; @@ -2546,10 +2581,10 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { } if (n_fuse > 1) { - bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]); + bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1)); for (int i = 1; i < n_fuse; ++i) { - if (!ggml_metal_op_concurrency_check(ctx, ops[i])) { + if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) { ggml_metal_op_concurrency_reset(ctx); break; @@ -2585,8 +2620,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2681,8 +2715,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2752,8 +2785,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2798,8 +2830,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2852,8 +2883,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2897,8 +2927,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2944,8 +2973,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2985,8 +3013,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -3020,8 +3047,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -3060,8 +3086,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -3103,8 +3128,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index a1151f881b372..8df4c72e7c8cb 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -22,6 +22,8 @@ ggml_metal_op_t ggml_metal_op_init( void ggml_metal_op_free(ggml_metal_op_t ctx); +int ggml_metal_op_n_nodes(ggml_metal_op_t ctx); + int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx); //