@@ -431,6 +431,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
+    vk_pipeline pipeline_fused_rms_norm_f32;
     vk_pipeline pipeline_rms_norm_back_f32;
 
     // [src/dst 0=fp32,1=fp16]
@@ -2653,6 +2654,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_fused_rms_norm_f32, "fused_rms_norm_f32", fused_rms_norm_f32_len, fused_rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -6381,6 +6383,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rms_norm_f32;
         }
         return nullptr;
+    case GGML_OP_FUSED_RMS_NORM:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_fused_rms_norm_f32;
+        }
+        return nullptr;
     case GGML_OP_RMS_NORM_BACK:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_rms_norm_back_f32;
@@ -6521,6 +6528,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_ROPE:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_FUSED_RMS_NORM:
     case GGML_OP_IM2COL:
         return true;
     default:
@@ -6751,6 +6759,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
         break;
 
+    case GGML_OP_FUSED_RMS_NORM:
+        elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        break;
+
     case GGML_OP_SUM:
         // We use GGML_OP_SUM_ROWS with 1 row.
         elements = { 1, 1, 1 };
@@ -7173,6 +7185,24 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
     }, dryrun);
 }
 
+static void ggml_vk_fused_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+    GGML_ASSERT(src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);
+    GGML_ASSERT(src1->ne[0] == src0->ne[0]);
+
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_FUSED_RMS_NORM, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], 1u, 1u, 1u, (uint32_t)src1->nb[0] / src1_type_size, 0u, 0u, 0u,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        0,
+        op_params[0], 0.0f, 0,
+    }, dryrun);
+}
+
 static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
@@ -8386,6 +8416,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_FUSED_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
@@ -8444,6 +8475,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_FUSED_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_UNARY:
     case GGML_OP_DIAG_MASK_INF:
@@ -8550,6 +8582,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RMS_NORM:
         ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_FUSED_RMS_NORM:
+        ggml_vk_fused_rms_norm(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_RMS_NORM_BACK:
         ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -8703,6 +8739,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_FUSED_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
@@ -9625,6 +9662,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_FUSED_RMS_NORM:
             return true;
         case GGML_OP_NORM:
         case GGML_OP_GROUP_NORM:
@@ -10064,6 +10102,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);
     } else if (tensor->op == GGML_OP_RMS_NORM) {
         tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
+    } else if (tensor->op == GGML_OP_FUSED_RMS_NORM) {
+        tensor_clone = ggml_fused_rms_norm(ggml_ctx, src_clone[0], src_clone[1], *(float *)tensor->op_params);
     } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
         const float eps = ((float *) tensor->op_params)[0];
         tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
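
Note (not part of the patch): GGML_OP_FUSED_RMS_NORM appears to fuse GGML_OP_RMS_NORM with the subsequent GGML_OP_MUL by the weight row src1, which is what the pipeline above dispatches. A minimal CPU sketch of that assumed semantics for contiguous f32 tensors follows; the function name and signature are illustrative only, not code from the repository.

#include <cmath>
#include <cstdint>

// dst = rms_norm(src0, eps) * src1, computed row by row.
// src0: [ne0, nrows] activations, src1: [ne0] per-channel weights, all contiguous f32.
static void fused_rms_norm_ref(float * dst, const float * src0, const float * src1,
                               int64_t ne0, int64_t nrows, float eps) {
    for (int64_t r = 0; r < nrows; ++r) {
        const float * x = src0 + r*ne0;
        float sumsq = 0.0f;
        for (int64_t i = 0; i < ne0; ++i) {
            sumsq += x[i]*x[i];                    // sum of squares over the row
        }
        const float scale = 1.0f/std::sqrt(sumsq/ne0 + eps);
        for (int64_t i = 0; i < ne0; ++i) {
            dst[r*ne0 + i] = scale*x[i]*src1[i];   // normalize, then apply the weight in one pass
        }
    }
}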