@@ -425,6 +425,7 @@ struct vk_device_struct {
425425 vk_pipeline pipeline_norm_f32;
426426 vk_pipeline pipeline_group_norm_f32;
427427 vk_pipeline pipeline_rms_norm_f32;
428+ vk_pipeline pipeline_rms_norm_mul_f32;
428429 vk_pipeline pipeline_rms_norm_back_f32;
429430 vk_pipeline pipeline_l2_norm_f32;
430431
@@ -978,6 +979,10 @@ struct ggml_backend_vk_context {
978979
979980 vk_command_pool compute_cmd_pool;
980981 vk_command_pool transfer_cmd_pool;
982+
983+ // number of additional consecutive nodes that are being fused with the
984+ // node currently being processed
985+ int num_additional_fused_ops {};
981986};
982987
983988static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
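For context, the new counter's life cycle, as wired up in the later hunks of this diff: it is set just before a node is built whenever the following node can be folded into it, read while building to pick the fused pipeline and operands, and then used to skip the node that was absorbed. A condensed sketch of that loop (unchanged arguments elided with /* ... */):

    for (int i = 0; i < cgraph->n_nodes; i++) {
        if (ggml_can_fuse_rms_norm_mul(ctx, cgraph, i)) {
            ctx->num_additional_fused_ops = 1;      // nodes[i+1] (the MUL) is folded into nodes[i]
        }
        ggml_vk_build_graph(ctx, cgraph, cgraph->nodes[i], i, /* ... */);
        i += ctx->num_additional_fused_ops;         // don't build the absorbed node again
        ctx->num_additional_fused_ops = 0;
    }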
@@ -2655,7 +2660,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
26552660
26562661 ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
26572662 ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2658- ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
2663+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
2664+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
26592665 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
26602666 ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
26612667
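Both pipelines are created from the same rms_norm_f32 SPIR-V; they differ only in the trailing specialization constants ({0, 0} vs {0, 1}), which presumably act as a "do the multiply" switch inside the shader. A plain C++ model of what the two variants compute per row, a sketch of the math rather than the actual GLSL:

    #include <cmath>

    // Illustrative only: per-row RMS norm, optionally fused with an elementwise
    // multiply by the norm weights 'w' (the src1 of the following MUL node).
    static void rms_norm_row(const float * x, const float * w, float * out,
                             int n, float eps, bool do_multiply) {
        double sum = 0.0;
        for (int i = 0; i < n; ++i) {
            sum += (double) x[i] * x[i];
        }
        const float scale = 1.0f / sqrtf((float) (sum / n) + eps);
        for (int i = 0; i < n; ++i) {
            out[i] = do_multiply ? scale * x[i] * w[i]  // rms_norm_mul pipeline
                                 : scale * x[i];        // plain rms_norm pipeline
        }
    }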
@@ -6418,7 +6424,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
64186424 return nullptr;
64196425 case GGML_OP_RMS_NORM:
64206426 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6421- return ctx->device->pipeline_rms_norm_f32;
6427+ return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
64226428 }
64236429 return nullptr;
64246430 case GGML_OP_RMS_NORM_BACK:
@@ -7518,18 +7524,19 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
75187524 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
75197525}
75207526
7521- static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7527+ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
75227528 float * op_params = (float *)dst->op_params;
75237529 const uint32_t src0_type_size = ggml_type_size(src0->type);
7530+ const uint32_t src1_type_size = ggml_type_size(src1->type);
75247531 const uint32_t dst_type_size = ggml_type_size(dst->type);
75257532
7526- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
7533+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
75277534 (uint32_t)ggml_nelements(src0),
7528- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7529- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7535+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7536+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
7537+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
75307538 0,
7531- op_params[0], 0.0f,
7532- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7539+ op_params[0], 0.0f, 0,
75337540 }, dryrun);
75347541}
75357542
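The initializer list above follows vk_op_binary_push_constants (defined earlier in this file): the element count, then ne/nb for src0, src1 and dst with the byte strides pre-divided by the element size, a misalignment offset (0 here), and the scalar parameters, with op_params[0], the RMS-norm epsilon, in the first float slot. A rough annotated view of that ordering; the member names below are illustrative, the authoritative definition lives in ggml-vulkan.cpp:

    #include <cstdint>

    // Sketch of the push-constant layout being filled in above (names approximate).
    struct rms_norm_push_constants_sketch {
        uint32_t ne;                                              // total elements of src0
        uint32_t ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03;  // src0 shape / strides
        uint32_t ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13;  // src1 shape / strides (the MUL operand)
        uint32_t ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23;  // dst shape / strides
        uint32_t misalign_offsets;                                // 0 here
        float    param1;                                          // op_params[0] == eps
        float    param2;                                          // unused (0.0f)
        int32_t  param3;                                          // unused (0)
    };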
@@ -8724,7 +8731,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* t
87248731
87258732// Returns true if node has enqueued work into the queue, false otherwise
87268733// If submit is true, all operations queued so far are submitted to Vulkan to overlap cmdlist creation and GPU execution.
8727- static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
8734+ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
87288735 if (ggml_is_empty(node) || !node->buffer) {
87298736 return false;
87308737 }
@@ -8962,8 +8969,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
89628969
89638970 break;
89648971 case GGML_OP_RMS_NORM:
8965- ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
8966-
8972+ if (ctx->num_additional_fused_ops > 0) {
8973+ // fused rms_norm + mul
8974+ ggml_tensor *mul = cgraph->nodes[node_idx + 1];
8975+ ggml_vk_rms_norm(ctx, compute_ctx, src0, mul->src[1], mul, dryrun);
8976+ } else {
8977+ ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
8978+ }
89678979 break;
89688980 case GGML_OP_RMS_NORM_BACK:
89698981 ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9685,6 +9697,34 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
96859697 return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
96869698}
96879699
9700+ // Returns true if nodes [i, i+1] are fusable RMS_NORM + MUL.
9701+ static bool ggml_can_fuse_rms_norm_mul(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int i) {
9702+ ggml_tensor *norm = cgraph->nodes[i];
9703+
9704+ if (norm->op != GGML_OP_RMS_NORM || norm->use_count != 1) {
9705+ return false;
9706+ }
9707+ // if norm is a view, some other node might be using the intermediate result
9708+ // via the view source.
9709+ if (norm->view_src) {
9710+ return false;
9711+ }
9712+
9713+ if (i + 1 >= cgraph->n_nodes) {
9714+ return false;
9715+ }
9716+ ggml_tensor *mul = cgraph->nodes[i + 1];
9717+ if (mul->op != GGML_OP_MUL || mul->src[0] != norm) {
9718+ return false;
9719+ }
9720+
9721+ // Since norm is the first operand of mul, it must be the same shape
9722+ GGML_ASSERT(ggml_are_same_shape(mul, norm));
9723+
9724+ // XXX TODO: Do we need a way to indicate that the user doesn't need the intermediate result?
9725+ return true;
9726+ }
9727+
96889728static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
96899729 VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
96909730 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
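For reference, the pattern the new ggml_can_fuse_rms_norm_mul helper above is looking for is how RMS-norm layers are usually expressed in ggml-based model code. A minimal, hypothetical builder (not code from this diff) that produces exactly an RMS_NORM node followed by a MUL whose src[0] is the norm:

    #include "ggml.h"

    // Hypothetical illustration of the fusable graph shape:
    //   nodes[i]   = RMS_NORM(x)
    //   nodes[i+1] = MUL(nodes[i], w)   -- w are the per-channel norm weights
    static struct ggml_tensor * build_rms_norm_layer(struct ggml_context * ctx,
                                                     struct ggml_tensor * x,
                                                     struct ggml_tensor * w,
                                                     float eps) {
        struct ggml_tensor * norm = ggml_rms_norm(ctx, x, eps);  // single consumer, not a view
        return ggml_mul(ctx, norm, w);                           // norm ends up as src[0] of the MUL
    }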
@@ -9698,10 +9738,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
96989738
96999739 uint64_t total_mat_mul_bytes = 0;
97009740 for (int i = 0; i < cgraph->n_nodes; i++) {
9701- ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
9741+ if (ggml_can_fuse_rms_norm_mul(ctx, cgraph, i)) {
9742+ ctx->num_additional_fused_ops = 1;
9743+ }
9744+ ggml_vk_build_graph(ctx, cgraph, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
97029745 if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
97039746 total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
97049747 }
9748+ i += ctx->num_additional_fused_ops;
9749+ ctx->num_additional_fused_ops = 0;
97059750 }
97069751 if (ctx->device->need_compiles) {
97079752 ggml_vk_load_shaders(ctx->device);
@@ -9763,14 +9808,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
97639808 mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
97649809 }
97659810
9811+ if (ggml_can_fuse_rms_norm_mul(ctx, cgraph, i)) {
9812+ ctx->num_additional_fused_ops = 1;
9813+ }
9814+
97669815 // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
97679816 bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
97689817 bool submit = (submitted_nodes >= nodes_per_submit) ||
97699818 (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
9770- (i == last_node) ||
9819+ (i + ctx->num_additional_fused_ops == last_node) ||
97719820 (almost_ready && !ctx->almost_ready_fence_pending);
97729821
9773- bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
9822+ bool enqueued = ggml_vk_build_graph(ctx, cgraph, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
97749823
97759824 if (vk_perf_logger_enabled) {
97769825 if (ctx->compute_ctx.expired()) {
@@ -9780,7 +9829,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
97809829 } else {
97819830 compute_ctx = ctx->compute_ctx.lock();
97829831 }
9783- compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1);
9832+ // If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
9833+ for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
9834+ compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
9835+ }
97849836 }
97859837
97869838 if (enqueued) {
@@ -9802,6 +9854,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
98029854 }
98039855 submit_count++;
98049856 }
9857+ i += ctx->num_additional_fused_ops;
9858+ ctx->num_additional_fused_ops = 0;
98059859 }
98069860
98079861 if (vk_perf_logger_enabled) {