@@ -103,6 +103,8 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
 struct ggml_backend_vk_context;

 #define MAX_PARAMETER_COUNT 8
+// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.
+#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 2)

 struct vk_pipeline_struct {
     std::string name;
@@ -368,6 +370,7 @@ struct vk_device_struct {
     bool float_controls_rte_fp16;
     bool subgroup_add;
     bool subgroup_shuffle;
+    bool multi_add;

     bool integer_dot_product;

@@ -449,6 +452,9 @@ struct vk_device_struct {
     vk_pipeline pipeline_div[2][2][2];
     vk_pipeline pipeline_div_norepeat[2][2][2];

+    // indexed by num_additional_fused_ops == num_adds - 1
+    vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];
+
     vk_pipeline pipeline_add_id_f32;

     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
@@ -800,6 +806,14 @@ struct vk_op_binary_push_constants {
     float param1; float param2; int32_t param3;
 };

+struct vk_op_multi_add_push_constants {
+    // shape for dst
+    uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23;
+
+    // strides for srcs+dst
+    uint32_t nb[8][4];
+};
+
 struct vk_op_add_id_push_constants {
     uint32_t ne0;
     uint32_t ne1;
@@ -3011,6 +3025,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_BINARY(div, _norepeat, {1})
 #undef CREATE_BINARY

+    if (device->multi_add) {
+        for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) {
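+            // The specialization constant (i+2) matches num_srcs (= num_additional_fused_ops + 2)
+            // used by ggml_vk_multi_add; all MAX_PARAMETER_COUNT descriptors are always bound.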
+            ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
+        }
+    }
+
     ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);

     ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -3409,6 +3429,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
     }
     device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;

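+    // Fused multi-add needs the whole push constant block to fit and can be disabled via GGML_VK_DISABLE_MULTI_ADD.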
+    device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
+                        device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&
+                        getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
+
     device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
                            (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);

@@ -6892,6 +6916,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     switch (op) {
     case GGML_OP_ADD:
         {
+            if (ctx->num_additional_fused_ops > 0) {
+                return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
+            }
             auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
             return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
         }
@@ -7739,6 +7766,107 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }

+static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
+    const ggml_tensor *first_node = cgraph->nodes[node_idx];
+    const ggml_tensor *dst = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
+
+    // Make a list of all the tensors used by the op.
+    // Last element of the list is the dest tensor.
+    const ggml_tensor *tensors[MAX_PARAMETER_COUNT];
+    uint32_t num_srcs = ctx->num_additional_fused_ops + 2;
+    uint32_t num_tensors = num_srcs + 1;
+    GGML_ASSERT(num_tensors <= MAX_PARAMETER_COUNT);
+
+    tensors[0] = first_node->src[0];
+    tensors[1] = first_node->src[1];
+    for (int32_t i = 0; i < ctx->num_additional_fused_ops; ++i) {
+        // check whether the previous result is src[0] or src[1]
+        if (cgraph->nodes[node_idx + i] == cgraph->nodes[node_idx + i + 1]->src[0]) {
+            tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[1];
+        } else {
+            tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[0];
+        }
+    }
+    tensors[num_srcs] = dst;
+
+    vk_op_multi_add_push_constants pc;
+    pc.ne20 = (uint32_t)dst->ne[0];
+    pc.ne21 = (uint32_t)dst->ne[1];
+    pc.ne22 = (uint32_t)dst->ne[2];
+    pc.ne23 = (uint32_t)dst->ne[3];
+
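+    // ggml strides (nb) are in bytes; the shader indexes f32 elements, so convert to element strides here.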
+    for (uint32_t i = 0; i < num_tensors; ++i) {
+        const ggml_tensor *t = tensors[i];
+        pc.nb[i][0] = (uint32_t)t->nb[0] / sizeof(float);
+        pc.nb[i][1] = (uint32_t)t->nb[1] / sizeof(float);
+        pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float);
+        pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float);
+    }
+
+    vk_pipeline pipeline = ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
+
+    if (pipeline == nullptr) {
+        std::cerr << "ggml_vulkan: Error: Missing multi_add pipeline" << std::endl;
+        GGML_ABORT("fatal error");
+    }
+
+    if (dryrun) {
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+        return;
+    }
+
+    ggml_backend_vk_buffer_context * buf_ctx[MAX_PARAMETER_COUNT];
+    vk_buffer buf[MAX_PARAMETER_COUNT];
+    size_t offset[MAX_PARAMETER_COUNT];
+    bool uma[MAX_PARAMETER_COUNT];
+
+    for (uint32_t i = 0; i < num_tensors; ++i) {
+        buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
+        buf[i] = nullptr;
+        offset[i] = 0;
+        uma[i] = false;
+
+        if (ctx->device->uma) {
+            ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]);
+            uma[i] = buf[i] != nullptr;
+        }
+        if (!uma[i]) {
+            buf[i] = buf_ctx[i]->dev_buffer;
+            offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;
+        }
+        GGML_ASSERT(buf[i] != nullptr);
+    }
+    // If any remaining descriptors are unused, just point them at src[0]
+    for (uint32_t i = num_tensors; i < MAX_PARAMETER_COUNT; ++i) {
+        buf[i] = buf[0];
+        offset[i] = 0;
+    }
+
+    std::array<uint32_t, 3> elements;
+
+    uint32_t ne = ggml_nelements(dst);
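+    // Spread the flattened element count over up to three dispatch dimensions (262144 == 512 * 512).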
+    if (ne > 262144) {
+        elements = { 512, 512, CEIL_DIV(ne, 262144) };
+    } else if (ne > 512) {
+        elements = { 512, CEIL_DIV(ne, 512), 1 };
+    } else {
+        elements = { ne, 1, 1 };
+    }
+
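+    // The pipeline binds all MAX_PARAMETER_COUNT (8) descriptors; unused slots alias buf[0] (see above).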
+    ggml_vk_sync_buffers(subctx);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+        {
+            vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[1], offset[1], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[2], offset[2], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[3], offset[3], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[4], offset[4], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE },
+        }, pc, elements);
+}
+
 static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
@@ -9692,8 +9820,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr

         break;
     case GGML_OP_ADD:
-        ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
-
+        if (ctx->num_additional_fused_ops) {
+            ggml_vk_multi_add(ctx, compute_ctx, cgraph, node_idx, dryrun);
+        } else {
+            ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
+        }
         break;
     case GGML_OP_SUB:
         ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -10570,6 +10701,58 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
     return true;
 }

+static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
+
+    if (!ctx->device->multi_add) {
+        return 0;
+    }
+
+    const ggml_tensor *first_node = cgraph->nodes[node_idx];
+    if (first_node->op != GGML_OP_ADD) {
+        return 0;
+    }
+
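+    // Count the run of consecutive GGML_OP_ADD nodes starting at node_idx, capped at MAX_FUSED_ADDS.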
+    int32_t num_adds = 1;
+    while (node_idx + num_adds < cgraph->n_nodes &&
+           cgraph->nodes[node_idx + num_adds]->op == GGML_OP_ADD &&
+           num_adds < MAX_FUSED_ADDS) {
+        num_adds++;
+    }
+
+    // The shader currently requires same shapes (but different strides are allowed),
+    // everything f32, and no misalignment
+    for (int32_t i = 0; i < num_adds; ++i) {
+        const ggml_tensor *next_node = cgraph->nodes[node_idx + i];
+        if (!ggml_are_same_shape(first_node, next_node->src[0]) ||
+            !ggml_are_same_shape(first_node, next_node->src[1]) ||
+            next_node->type != GGML_TYPE_F32 ||
+            next_node->src[0]->type != GGML_TYPE_F32 ||
+            next_node->src[1]->type != GGML_TYPE_F32 ||
+            get_misalign_bytes(ctx, next_node) ||
+            get_misalign_bytes(ctx, next_node->src[0]) ||
+            get_misalign_bytes(ctx, next_node->src[1])) {
+            num_adds = i;
+        }
+    }
+
+    // Verify we can fuse these
+    ggml_op adds[MAX_FUSED_ADDS];
+    for (int32_t i = 0; i < num_adds; ++i) {
+        adds[i] = GGML_OP_ADD;
+    }
+
+    // decrease num_adds if they can't all be fused
+    while (num_adds > 1 && !ggml_can_fuse(cgraph, node_idx, adds, num_adds)) {
+        num_adds--;
+    }
+
+    // a single add is not "fused", so just return zero
+    if (num_adds == 1) {
+        return 0;
+    }
+    return num_adds;
+}
+
 static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -10583,8 +10766,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg

     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-            ctx->num_additional_fused_ops = 1;
+        if (!ctx->device->disable_fusion) {
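+            // ggml_vk_fuse_multi_add returns the total number of fused adds (0 if none);
+            // num_additional_fused_ops counts only the extra nodes beyond the first.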
+            uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
+            if (num_adds) {
+                ctx->num_additional_fused_ops = num_adds - 1;
+            } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+                ctx->num_additional_fused_ops = 1;
+            }
         }
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
         if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
@@ -10659,8 +10847,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }

-        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-            ctx->num_additional_fused_ops = 1;
+        if (!ctx->device->disable_fusion) {
+            uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
+            if (num_adds) {
+                ctx->num_additional_fused_ops = num_adds - 1;
+            } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+                ctx->num_additional_fused_ops = 1;
+            }
         }

         // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)