
Commit 9534461

ikawrakow (Iwan Kawrakow) authored
Vulkan: fused rms norm (#577)
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent db8dee5 commit 9534461

4 files changed: 96 additions & 6 deletions

ggml/src/ggml-vulkan.cpp

Lines changed: 40 additions & 0 deletions
@@ -431,6 +431,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
+    vk_pipeline pipeline_fused_rms_norm_f32;
     vk_pipeline pipeline_rms_norm_back_f32;
 
     // [src/dst 0=fp32,1=fp16]
@@ -2653,6 +2654,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_fused_rms_norm_f32, "fused_rms_norm_f32", fused_rms_norm_f32_len, fused_rms_norm_f32_data, "main", 3, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -6381,6 +6383,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rms_norm_f32;
         }
         return nullptr;
+    case GGML_OP_FUSED_RMS_NORM:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_fused_rms_norm_f32;
+        }
+        return nullptr;
     case GGML_OP_RMS_NORM_BACK:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_rms_norm_back_f32;
@@ -6521,6 +6528,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_ROPE:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_FUSED_RMS_NORM:
     case GGML_OP_IM2COL:
         return true;
     default:
@@ -6751,6 +6759,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
         break;
 
+    case GGML_OP_FUSED_RMS_NORM:
+        elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        break;
+
     case GGML_OP_SUM:
         // We use GGML_OP_SUM_ROWS with 1 row.
         elements = { 1, 1, 1 };
@@ -7173,6 +7185,24 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
     }, dryrun);
 }
 
+static void ggml_vk_fused_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+    GGML_ASSERT(src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);
+    GGML_ASSERT(src1->ne[0] == src0->ne[0]);
+
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_FUSED_RMS_NORM, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], 1u, 1u, 1u, (uint32_t)src1->nb[0] / src1_type_size, 0u, 0u, 0u,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        0,
+        op_params[0], 0.0f, 0,
+    }, dryrun);
+}
+
 static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
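For reference, GGML_OP_FUSED_RMS_NORM combines the RMS normalization of each row of src0 with an element-wise multiply by the weight row src1. A minimal CPU sketch of the per-row semantics follows; the helper name and the double accumulator are illustrative choices, not code from this commit (the shader below accumulates in FLOAT_TYPE):

#include <cmath>
#include <cstddef>

// Sketch: what the fused op computes for one row x of length ncols,
// given the weight row w and the epsilon from the op parameters.
static void fused_rms_norm_row_ref(const float * x, const float * w, float * y,
                                   std::size_t ncols, float eps) {
    double sumsq = 0.0;
    for (std::size_t i = 0; i < ncols; ++i) {
        sumsq += (double)x[i] * (double)x[i];          // sum of squares over the row
    }
    const float scale = 1.0f / std::sqrt((float)(sumsq / (double)ncols) + eps);
    for (std::size_t i = 0; i < ncols; ++i) {
        y[i] = scale * x[i] * w[i];                    // normalize and apply weight in one pass
    }
}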
@@ -8386,6 +8416,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_FUSED_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
@@ -8444,6 +8475,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_FUSED_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_UNARY:
     case GGML_OP_DIAG_MASK_INF:
@@ -8550,6 +8582,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RMS_NORM:
         ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_FUSED_RMS_NORM:
+        ggml_vk_fused_rms_norm(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_RMS_NORM_BACK:
         ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -8703,6 +8739,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_FUSED_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
@@ -9625,6 +9662,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
     case GGML_OP_PERMUTE:
     case GGML_OP_TRANSPOSE:
     case GGML_OP_RMS_NORM:
+    case GGML_OP_FUSED_RMS_NORM:
         return true;
     case GGML_OP_NORM:
     case GGML_OP_GROUP_NORM:
@@ -10064,6 +10102,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);
     } else if (tensor->op == GGML_OP_RMS_NORM) {
         tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
+    } else if (tensor->op == GGML_OP_FUSED_RMS_NORM) {
+        tensor_clone = ggml_fused_rms_norm(ggml_ctx, src_clone[0], src_clone[1], *(float *)tensor->op_params);
     } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
         const float eps = ((float *) tensor->op_params)[0];
         tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);

ggml/src/vulkan-shaders/fused_rms_norm.comp

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+#version 450
+
+#include "generic_binary_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+shared FLOAT_TYPE sum[BLOCK_SIZE];
+
+void main() {
+    const uint ncols = p.ne00;
+    const uint nrows = gl_NumWorkGroups.x;
+    const uint nchannels = gl_NumWorkGroups.y;
+
+    const uint row = gl_WorkGroupID.x;
+    const uint channel = gl_WorkGroupID.y;
+    const uint samp = gl_WorkGroupID.z;
+    const uint tid = gl_LocalInvocationID.x;
+
+    const uint stride_row_a = p.nb01;
+    const uint stride_channel_a = p.nb02;
+    const uint stride_sample_a = p.nb03;
+
+    uint32_t a_offset = samp*stride_sample_a + channel*stride_channel_a + row*stride_row_a;
+    uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
+
+    FLOAT_TYPE sumf = FLOAT_TYPE(0.0f);
+
+    [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
+        sumf += xi * xi;
+    }
+
+    sum[tid] = sumf;
+
+    // sum up partial sums and write back result
+    barrier();
+    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            sum[tid] += sum[tid + s];
+        }
+        barrier();
+    }
+
+    const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
+    const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
+
+    [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+        data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[col]));
+    }
+}
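
Restated as a formula (an editor's summary of the shader above, not text from the commit): for a row x of length n, weight vector w, and epsilon p.param1, each workgroup computes

\[
y_i = \frac{x_i \, w_i}{\sqrt{\tfrac{1}{n}\sum_{j=1}^{n} x_j^{2} + \varepsilon}}, \qquad i = 1, \dots, n
\]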

ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
@@ -498,6 +498,7 @@ void process_shaders() {
     string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("fused_rms_norm_f32", "fused_rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 
     string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

src/llama.cpp

Lines changed: 1 addition & 6 deletions
@@ -9597,12 +9597,7 @@ static struct ggml_tensor * llm_build_norm(
         const llm_build_cb & cb,
         int il, float scale_eps = 1) {
 
-#ifdef GGML_USE_VULKAN
-    constexpr bool use_fused_rms_norm = false;
-#else
-    constexpr bool use_fused_rms_norm = true;
-#endif
-    if (use_fused_rms_norm && type == LLM_NORM_RMS && mw) {
+    if (type == LLM_NORM_RMS && mw) {
         cur = ggml_fused_rms_norm(ctx, cur, mw, scale_eps * hparams.f_norm_rms_eps);
         if (mb) {
             cb(cur, "fused_norm", il);
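
For context, a minimal sketch of the unfused two-node graph that this branch of llm_build_norm otherwise builds; the helper name is hypothetical, while ggml_rms_norm and ggml_mul are the standard ggml ops it would use:

#include "ggml.h"

// Sketch: RMS-normalize cur, then multiply by the (broadcast) norm weights mw.
static struct ggml_tensor * build_rms_norm_unfused(struct ggml_context * ctx,
                                                   struct ggml_tensor  * cur,
                                                   struct ggml_tensor  * mw,
                                                   float eps) {
    cur = ggml_rms_norm(ctx, cur, eps); // per-row scaling by 1/sqrt(mean(x^2) + eps)
    cur = ggml_mul(ctx, cur, mw);       // element-wise multiply by the norm weights
    return cur;
}

// With this commit, the Vulkan backend can instead run the single fused node:
//     cur = ggml_fused_rms_norm(ctx, cur, mw, scale_eps * hparams.f_norm_rms_eps);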
