Commit 611f419

vulkan: optimize rms_norm, and allow the work to spread across multiple SMs (ggml-org#15281)
* vulkan: optimize rms_norm, and allow the work to spread across multiple SMs

  There are really two parts to this change:

  (1) Some optimizations similar to what we have in soft_max, to unroll with different numbers of iterations.

  (2) A fusion optimization where we detect an add followed by an rms_norm, and make the add shader atomically accumulate the values^2 into memory. The rms_norm shader can then just load that sum, which allows the rms_norm to be parallelized across multiple workgroups; the rms_norm itself becomes a simple per-element multiply.

  The fusion optimization is currently only applied when the rms_norm operates on a single vector, which previously always ran on a single SM. It could apply more broadly, but when there are other dimensions the work can already spread across SMs, and there would be some complexity to tracking multiple atomic sums.

* Change the add+rms_norm optimization to write out an array of partial sums rather than using atomic adds, to make it deterministic. The rms_norm shader fetches a subgroup's worth of partials in parallel and uses subgroupAdd to add them up.

* Complete the rebase against fused adds: the multi_add shader can also compute partial sums.

* Fix validation errors.

* Disable add_rms fusion for Intel due to a possible driver bug.

* Resolve against ggml-org#15489; sync after clearing the partial sums.
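For reference, rms_norm's only cross-element dependency is the sum of squares, which is exactly what the fused add can precompute. As a sketch of the math (the standard RMSNorm formula, with $\varepsilon$ corresponding to the shader's p.param1):

$$y_i = \frac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^2 + \varepsilon}}, \qquad x_i = a_i + b_i.$$

Each add workgroup $k$ can therefore emit a partial sum $S_k = \sum_{i \in k} x_i^2$ as a side effect, the rms_norm only has to fold $\sum_k S_k$, and the remaining per-element scaling parallelizes trivially across workgroups.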
Parent: b1afcab

File tree: 7 files changed (+380, -51 lines)

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 164 additions & 31 deletions
Large diffs are not rendered by default.
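The host-side diff is not rendered here, so the following is only a rough sketch of the fusion check described in the commit message; the helper name and exact condition are hypothetical, not the actual ggml-vulkan.cpp change:

#include "ggml.h"

// Hypothetical sketch: fuse ADD + RMS_NORM when the rms_norm consumes the
// add's output and operates on a single vector (one row), since multi-row
// cases already spread across SMs. The real commit additionally disables
// the fusion on Intel due to a possible driver bug.
static bool can_fuse_add_rms_norm(const ggml_tensor * add, const ggml_tensor * norm) {
    return add->op   == GGML_OP_ADD      &&
           norm->op  == GGML_OP_RMS_NORM &&
           norm->src[0] == add           &&
           ggml_nrows(add) == 1;
}

When such a pair is found, the add is dispatched as its ADD_RMS variant so that it also writes per-workgroup partial sums of the squared outputs, which the rms_norm_partials shader further below consumes.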
ggml/src/ggml-vulkan/vulkan-shaders/add.comp

Lines changed: 41 additions & 1 deletion

@@ -1,29 +1,69 @@
 #version 450

 #extension GL_EXT_shader_16bit_storage : require
+#if ADD_RMS
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+#endif

 #include "types.comp"
 #include "generic_binary_head.comp"

 const uint num_threads = 256;

+layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};
+
 layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;

+#if ADD_RMS
+// XXX TODO this could be sized based on number of subgroups, but that's not considered a constant
+shared FLOAT_TYPE sumsh[num_threads];
+#endif
+
 void main() {
     uint idx = get_idx();
+    uint orig_idx = idx;

     // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
     const uint num_iter = 2;

+    FLOAT_TYPE sum_sq = 0;
+
     [[unroll]] for (uint i = 0; i < num_iter; ++i) {
         if (idx >= p.ne) {
             continue;
         }
         uint i00, i01, i02, i03;
         get_indices(idx, i00, i01, i02, i03);

-        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
+        FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
+        sum_sq += sum*sum;
+
+        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);

         idx += num_threads;
     }
+
+#if ADD_RMS
+    if (p.param3 != 0) {
+        // reduce the sum within each subgroup, then across subgroups
+        const uint NumSubgroups = num_threads / gl_SubgroupSize;
+        sum_sq = subgroupAdd(sum_sq);
+        if (gl_SubgroupInvocationID == 0) {
+            sumsh[gl_SubgroupID] = sum_sq;
+        }
+        barrier();
+        [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
+            if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
+                sum_sq += sumsh[gl_SubgroupID + s];
+                sumsh[gl_SubgroupID] = sum_sq;
+            }
+            barrier();
+        }
+
+        if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
+            partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
+        }
+    }
+#endif
 }
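The two-level reduction in the #if ADD_RMS block above is compact; here is a small CPU model of what one workgroup does, assuming num_threads = 256 and a subgroup size of 32 (gl_SubgroupSize is hardware-dependent, commonly 32 or 64):

#include <cstdio>

// CPU model of add.comp's ADD_RMS epilogue: subgroupAdd() collapses each
// subgroup, lane 0 stores the per-subgroup sum to shared memory, and a
// log2-step tree over those sums produces the workgroup total.
int main() {
    const int num_threads = 256, subgroup_size = 32;        // subgroup size is an assumption
    const int num_subgroups = num_threads / subgroup_size;  // 8

    float sum_sq[num_threads];
    for (int t = 0; t < num_threads; ++t) sum_sq[t] = 1.0f; // pretend each thread saw sum*sum == 1

    // level 1: subgroupAdd(sum_sq) within each subgroup
    float sumsh[num_subgroups];
    for (int sg = 0; sg < num_subgroups; ++sg) {
        sumsh[sg] = 0.0f;
        for (int lane = 0; lane < subgroup_size; ++lane) {
            sumsh[sg] += sum_sq[sg * subgroup_size + lane];
        }
    }

    // level 2: tree reduction over per-subgroup sums (3 steps for 8 subgroups)
    for (int s = num_subgroups / 2; s > 0; s >>= 1) {
        for (int sg = 0; sg < s; ++sg) {
            sumsh[sg] += sumsh[sg + s];
        }
    }

    // exactly one value per workgroup lands in partial_sums[]
    std::printf("workgroup partial sum = %.1f (expected %d)\n", sumsh[0], num_threads);
    return 0;
}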

ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp

Lines changed: 41 additions & 1 deletion
@@ -3,6 +3,10 @@
 #extension GL_EXT_shader_16bit_storage : require
 #extension GL_EXT_nonuniform_qualifier : enable
 #extension GL_EXT_control_flow_attributes : require
+#if ADD_RMS
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+#endif

 #include "rte.comp"
 #include "types.comp"
@@ -14,12 +18,16 @@ layout (push_constant) uniform parameter2
     uint ne20; uint ne21; uint ne22; uint ne23;

     // strides for srcs+dst
-    uint nb[8][4];
+    uint nb[12][4];
+
+    uint rms_partials;
 } p;

 layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
 layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];

+layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[];
+
 layout(constant_id = 0) const uint num_srcs = 2;

 uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
@@ -42,14 +50,22 @@ const uint num_threads = 256;

 layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;

+#if ADD_RMS
+// XXX TODO this could be sized based on number of subgroups, but that's not considered a constant
+shared FLOAT_TYPE sumsh[num_threads];
+#endif
+
 void main() {
     uint idx = get_idx();
+    uint orig_idx = idx;

     uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;

     // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
     const uint num_iter = 2;

+    FLOAT_TYPE sum_sq = 0;
+
     [[unroll]] for (uint i = 0; i < num_iter; ++i) {
         if (idx >= ne) {
             continue;
@@ -61,8 +77,32 @@ void main() {
         [[unroll]] for (uint s = 0; s < num_srcs; ++s) {
             sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
         }
+        sum_sq += sum*sum;
         d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);

         idx += num_threads;
     }
+
+#if ADD_RMS
+    if (p.rms_partials != 0) {
+        // reduce the sum within each subgroup, then across subgroups
+        const uint NumSubgroups = num_threads / gl_SubgroupSize;
+        sum_sq = subgroupAdd(sum_sq);
+        if (gl_SubgroupInvocationID == 0) {
+            sumsh[gl_SubgroupID] = sum_sq;
+        }
+        barrier();
+        [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
+            if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
+                sum_sq += sumsh[gl_SubgroupID + s];
+                sumsh[gl_SubgroupID] = sum_sq;
+            }
+            barrier();
+        }
+
+        if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
+            partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
+        }
+    }
+#endif
 }

ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp

Lines changed: 49 additions & 11 deletions
@@ -10,9 +10,9 @@ layout (constant_id = 1) const bool do_multiply = false;

 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

-shared FLOAT_TYPE sum[BLOCK_SIZE];
+shared FLOAT_TYPE sumsh[BLOCK_SIZE];

-void main() {
+void rms_norm(uint num_iters) {
     const uint ncols = p.ne00;
     const uint nrows = gl_NumWorkGroups.x;
     const uint nchannels = gl_NumWorkGroups.y;
@@ -30,38 +30,76 @@ void main() {
     uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
     uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();

-    sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
+    FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp

-    [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
-        sum[tid] += xi * xi;
+    [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
+        FLOAT_TYPE xi = FLOAT_TYPE(0);
+        if (col < ncols) {
+            xi = FLOAT_TYPE(data_a[a_offset + col]);
+        }
+        sum += xi * xi;
     }

+    sumsh[tid] = sum;
     // sum up partial sums and write back result
     barrier();
     [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
         if (tid < s) {
-            sum[tid] += sum[tid + s];
+            sum += sumsh[tid + s];
+            sumsh[tid] = sum;
         }
         barrier();
     }
+    sum = sumsh[0];

-    const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
+    const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
     const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));

     if (do_multiply) {
         if (ncols > p.ne10) {
-            [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+            [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
+                if (col >= ncols) {
+                    continue;
+                }
                 data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
             }
         } else {
-            [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+            [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
+                if (col >= ncols) {
+                    continue;
+                }
                 data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
             }
         }
     } else {
-        [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+        [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
+            if (col >= ncols) {
+                continue;
+            }
             data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
         }
     }
 }
+
+void main() {
+    // instantiate the rms_norm function for several different
+    // dimensions, to allow loop unrolling
+    uint num_blocks = (p.ne00 + BLOCK_SIZE - 1) / BLOCK_SIZE;
+    if (num_blocks > 32) {
+        rms_norm(num_blocks);
+    } else if (num_blocks > 16) {
+        rms_norm(32);
+    } else if (num_blocks > 8) {
+        rms_norm(16);
+    } else if (num_blocks > 4) {
+        rms_norm(8);
+    } else if (num_blocks == 4) {
+        rms_norm(4);
+    } else if (num_blocks == 3) {
+        rms_norm(3);
+    } else if (num_blocks == 2) {
+        rms_norm(2);
+    } else if (num_blocks == 1) {
+        rms_norm(1);
+    }
+}
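The new main() above is a dispatch table: rms_norm() is called with a handful of constant iteration counts so that, after inlining, the [[unroll]] loops have compile-time trip counts. The same trick in C++ terms, for illustration only (the names are made up):

#include <cstddef>

// Route the hot loop through call sites whose trip count is a visible
// constant so the optimizer can fully unroll it after inlining; fall back
// to a runtime trip count for large inputs, as the shader does.
static inline float sum_sq(const float * x, size_t ncols, size_t num_iters, size_t stride) {
    float s = 0.0f;
    for (size_t idx = 0, col = 0; idx < num_iters; ++idx, col += stride) {
        if (col < ncols) s += x[col] * x[col]; // same guard as the shader's col < ncols check
    }
    return s;
}

float sum_sq_dispatch(const float * x, size_t ncols, size_t stride) {
    const size_t num_blocks = (ncols + stride - 1) / stride;
    if (num_blocks > 4)  return sum_sq(x, ncols, num_blocks, stride); // runtime trip count
    if (num_blocks == 4) return sum_sq(x, ncols, 4, stride);          // constant trip counts below
    if (num_blocks == 3) return sum_sq(x, ncols, 3, stride);
    if (num_blocks == 2) return sum_sq(x, ncols, 2, stride);
    return sum_sq(x, ncols, 1, stride);
}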
ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp (new file)

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+#version 450
+
+#include "generic_binary_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+
+#define BLOCK_SIZE 128
+
+layout (constant_id = 1) const bool do_multiply = false;
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 3, std430) readonly buffer PartialsBuf {float partial_sums[];};
+
+shared FLOAT_TYPE sumsh[BLOCK_SIZE];
+
+void main() {
+    const uint ncols = p.ne00;
+    const uint nrows = gl_NumWorkGroups.x;
+    const uint nchannels = gl_NumWorkGroups.y;
+
+    const uint row = 0;
+    const uint channel = gl_WorkGroupID.y;
+    const uint samp = gl_WorkGroupID.z;
+    // The work is split across multiple workgroups in the x dimension. Each invocation
+    // processes one element
+    const uint tid = gl_GlobalInvocationID.x;
+
+    const uint stride_row = p.nb01;
+    const uint stride_channel = p.nb02;
+    const uint stride_sample = p.nb03;
+
+    uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
+    uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
+    uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
+
+    FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
+
+    uint32_t num_partials = p.param3;
+    for (uint32_t i = gl_SubgroupInvocationID; i < num_partials; i += gl_SubgroupSize) {
+        sum += partial_sums[i];
+    }
+    sum = subgroupAdd(sum);
+
+    uint col = tid;
+    if (col >= ncols) {
+        return;
+    }
+
+    const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
+    const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
+
+    if (do_multiply) {
+        if (ncols > p.ne10) {
+            data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
+        } else {
+            data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
+        }
+    } else {
+        data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
+    }
+}
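A note on the partials count: each fused add workgroup covers num_iter * num_threads = 2 * 256 = 512 elements (the divisor in its partial_sums[orig_idx / (num_iter * num_threads)] store), so for a fused vector of $n$ elements this shader reads

$$\left\lceil n / 512 \right\rceil$$

partial sums, with the count passed in via p.param3. A 4096-element row, for example, yields 8 partials, which a single subgroup folds using the strided loop plus subgroupAdd above.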

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 8 additions & 4 deletions
@@ -503,6 +503,7 @@ void process_shaders() {
     string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

@@ -538,13 +539,15 @@ void process_shaders() {
         s += std::string(dst_f16 ? "_f16" : "_f32");
         return s;
     };
-    for (std::string op : {"add", "sub", "mul", "div"}) {
+    for (std::string op : {"add", "sub", "mul", "div", "add_rms", }) {
         for (auto src0_f16 : {false, true}) {
             for (auto src1_f16 : {false, true}) {
                 for (auto dst_f16 : {false, true}) {
                     for (auto rte : {false, true}) {
+                        auto source = op == "add_rms" ? std::string("add") : op;
                         auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
-                        string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
+                        auto add_rms = op == "add_rms" ? "1" : "0";
+                        string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS", add_rms}});
                     }
                 }
             }

@@ -687,7 +690,8 @@

     string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));

-    string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
+    string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS", "0"}});
+    string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS", "1"}});

     for (auto &c : compiles) {
         c.wait();

@@ -745,7 +749,7 @@ void write_output_files() {
     }

     std::string suffixes[2] = {"_f32", "_f16"};
-    for (const char *op : {"add", "sub", "mul", "div"}) {
+    for (const char *op : {"add", "sub", "mul", "div", "add_rms"}) {
        fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
        fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
        std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";