Skip to content

Commit 303f861

Browse files
authored
vulkan: Multi-pass softmax for large number of cols (#17892)
When the number of cols is large, split each row across multiple workgroups. There are three phases that communicate partial results through temp buffers: (1) compute max partials (2) take max of partials, compute sum(exp(x-max)) partials (3) sum partials, compute scaled result
1 parent 3c6391e commit 303f861

File tree

7 files changed

+331
-2
lines changed

7 files changed

+331
-2
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,11 @@ struct vk_device_struct {
722722
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
723723
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
724724
vk_pipeline pipeline_soft_max_back_f32;
725+
726+
vk_pipeline pipeline_soft_max_large1_f32, pipeline_soft_max_large1_f32_f16;
727+
vk_pipeline pipeline_soft_max_large2_f32, pipeline_soft_max_large2_f32_f16;
728+
vk_pipeline pipeline_soft_max_large3_f32, pipeline_soft_max_large3_f32_f16;
729+
725730
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
726731
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
727732
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
@@ -3998,6 +4003,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
39984003
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
39994004
ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);
40004005

4006+
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32, "soft_max_large1_f32", soft_max_large1_f32_len, soft_max_large1_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
4007+
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32, "soft_max_large2_f32", soft_max_large2_f32_len, soft_max_large2_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
4008+
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32, "soft_max_large3_f32", soft_max_large3_f32_len, soft_max_large3_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
4009+
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32_f16, "soft_max_large1_f32_f16", soft_max_large1_f32_f16_len, soft_max_large1_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
4010+
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32_f16, "soft_max_large2_f32_f16", soft_max_large2_f32_f16_len, soft_max_large2_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
4011+
ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32_f16, "soft_max_large3_f32_f16", soft_max_large3_f32_f16_len, soft_max_large3_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
4012+
40014013
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
40024014
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
40034015
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -10117,7 +10129,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
1011710129
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1011810130
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1011910131

10120-
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, {
10132+
vk_op_soft_max_push_constants pc {
1012110133
ncols,
1012210134
src1 != nullptr ? nrows_y : (uint32_t)0,
1012310135
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
@@ -10128,7 +10140,55 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
1012810140
n_head_log2,
1012910141
nrows_x,
1013010142
src2 != nullptr
10131-
});
10143+
};
10144+
10145+
if (ncols <= 16384) {
10146+
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, std::move(pc));
10147+
} else {
10148+
10149+
vk_subbuffer buf_a = ggml_vk_tensor_subbuffer(ctx, src0);
10150+
vk_subbuffer buf_b = src1 ? ggml_vk_tensor_subbuffer(ctx, src1) : buf_a;
10151+
vk_subbuffer buf_c = src2 ? ggml_vk_tensor_subbuffer(ctx, src2) : buf_a;
10152+
vk_subbuffer buf_d = ggml_vk_tensor_subbuffer(ctx, dst);
10153+
10154+
uint32_t elems_per_wg = 128 * 4;
10155+
uint32_t num_wgs = CEIL_DIV(ncols, elems_per_wg);
10156+
size_t tmp_size = num_wgs * nrows_x * sizeof(float);
10157+
10158+
if (ctx->prealloc_size_x < tmp_size) {
10159+
ctx->prealloc_size_x = tmp_size;
10160+
ggml_vk_preallocate_buffers(ctx, subctx);
10161+
}
10162+
if (ctx->prealloc_size_y < tmp_size) {
10163+
ctx->prealloc_size_y = tmp_size;
10164+
ggml_vk_preallocate_buffers(ctx, subctx);
10165+
}
10166+
if (ctx->prealloc_x_need_sync || ctx->prealloc_y_need_sync) {
10167+
ggml_vk_sync_buffers(ctx, subctx);
10168+
}
10169+
10170+
vk_subbuffer buf_x = { ctx->prealloc_x, 0, tmp_size };
10171+
vk_subbuffer buf_y = { ctx->prealloc_y, 0, tmp_size };
10172+
10173+
std::array<uint32_t, 3> elements = { num_wgs, nrows_x, 1 };
10174+
10175+
vk_pipeline pipeline1 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large1_f32_f16 : ctx->device->pipeline_soft_max_large1_f32;
10176+
vk_pipeline pipeline2 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large2_f32_f16 : ctx->device->pipeline_soft_max_large2_f32;
10177+
vk_pipeline pipeline3 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large3_f32_f16 : ctx->device->pipeline_soft_max_large3_f32;
10178+
10179+
ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
10180+
ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
10181+
ggml_pipeline_request_descriptor_sets(ctx, pipeline3, 1);
10182+
10183+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
10184+
ggml_vk_sync_buffers(ctx, subctx);
10185+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
10186+
ggml_vk_sync_buffers(ctx, subctx);
10187+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline3, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
10188+
10189+
ctx->prealloc_x_need_sync = true;
10190+
ctx->prealloc_y_need_sync = true;
10191+
}
1013210192
}
1013310193

1013410194
static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#version 450
2+
3+
#include "soft_max_large_common.glsl"
4+
5+
void main() {
6+
const uint tid = gl_LocalInvocationID.x;
7+
const uint rowx = gl_WorkGroupID.y;
8+
const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
9+
10+
const uint32_t i03 = rowx / (p.ne01 * p.ne02);
11+
const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
12+
const uint32_t i01 = rowx % p.ne01;
13+
14+
uint rowy_start = 0;
15+
if (p.KY > 0) {
16+
rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
17+
}
18+
19+
if (rowx >= p.nrows_x) {
20+
return;
21+
}
22+
23+
float slope = get_slope(rowx);
24+
25+
// Find max
26+
FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
27+
28+
[[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
29+
const uint col = col0 + tid;
30+
31+
FLOAT_TYPE a = FLOAT_TYPE(0);
32+
if (col < p.KX) {
33+
a = data_a[rowx * p.KX + col];
34+
}
35+
36+
FLOAT_TYPE b = FLOAT_TYPE(0);
37+
if (p.KY > 0 && col < p.KX) {
38+
b = data_b[rowy_start + col];
39+
}
40+
41+
FLOAT_TYPE v = a * p.scale + slope * b;
42+
43+
if (col < p.KX) {
44+
max_val = max(max_val, v);
45+
}
46+
}
47+
48+
// reduce across the workgroup
49+
vals[tid] = max_val;
50+
barrier();
51+
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
52+
if (tid < s) {
53+
vals[tid] = max(vals[tid], vals[tid + s]);
54+
}
55+
barrier();
56+
}
57+
58+
if (tid == 0) {
59+
max_val = vals[0];
60+
data_m[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = max_val;
61+
}
62+
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
#version 450
2+
3+
#include "soft_max_large_common.glsl"
4+
5+
void main() {
6+
const uint tid = gl_LocalInvocationID.x;
7+
const uint rowx = gl_WorkGroupID.y;
8+
const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
9+
10+
const uint32_t i03 = rowx / (p.ne01 * p.ne02);
11+
const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
12+
const uint32_t i01 = rowx % p.ne01;
13+
14+
uint rowy_start = 0;
15+
if (p.KY > 0) {
16+
rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
17+
}
18+
19+
if (rowx >= p.nrows_x) {
20+
return;
21+
}
22+
23+
float slope = get_slope(rowx);
24+
25+
// Find max
26+
FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
27+
28+
[[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {
29+
if (i + tid < gl_NumWorkGroups.x) {
30+
max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);
31+
}
32+
}
33+
34+
// reduce across the workgroup
35+
vals[tid] = max_val;
36+
barrier();
37+
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
38+
if (tid < s) {
39+
vals[tid] = max(max_val, vals[tid + s]);
40+
}
41+
barrier();
42+
}
43+
44+
max_val = vals[0];
45+
barrier();
46+
47+
FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
48+
49+
// Compute sum{exp(x - max)}
50+
[[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
51+
const uint col = col0 + tid;
52+
53+
if (col >= p.KX) {
54+
break;
55+
}
56+
57+
// compute exp(a*scale+b*slope), add it to sum
58+
const uint i = rowx * p.KX + col;
59+
FLOAT_TYPE val;
60+
val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);
61+
sum += val;
62+
data_d[i] = D_TYPE(val);
63+
}
64+
65+
// reduce across the workgroup
66+
vals[tid] = sum;
67+
barrier();
68+
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
69+
if (tid < s) {
70+
vals[tid] += vals[tid + s];
71+
}
72+
barrier();
73+
}
74+
75+
if (tid == 0) {
76+
sum = vals[0];
77+
data_s[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = sum;
78+
}
79+
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#version 450
2+
3+
#include "soft_max_large_common.glsl"
4+
5+
shared FLOAT_TYPE sumsh[BLOCK_SIZE];
6+
7+
void main() {
8+
const uint tid = gl_LocalInvocationID.x;
9+
const uint rowx = gl_WorkGroupID.y;
10+
const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
11+
12+
const uint32_t i03 = rowx / (p.ne01 * p.ne02);
13+
const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
14+
const uint32_t i01 = rowx % p.ne01;
15+
16+
uint rowy_start = 0;
17+
if (p.KY > 0) {
18+
rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
19+
}
20+
21+
if (rowx >= p.nrows_x) {
22+
return;
23+
}
24+
25+
FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
26+
FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
27+
28+
[[unroll]] for (uint i = 0; i < gl_NumWorkGroups.x; i += BLOCK_SIZE) {
29+
if (i + tid < gl_NumWorkGroups.x) {
30+
max_val = max(max_val, data_m[rowx * gl_NumWorkGroups.x + i + tid]);
31+
sum += data_s[rowx * gl_NumWorkGroups.x + i + tid];
32+
}
33+
}
34+
35+
// reduce across the workgroup
36+
vals[tid] = max_val;
37+
sumsh[tid] = sum;
38+
barrier();
39+
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
40+
if (tid < s) {
41+
vals[tid] = max(max_val, vals[tid + s]);
42+
sumsh[tid] += sumsh[tid + s];
43+
}
44+
barrier();
45+
}
46+
47+
max_val = vals[0];
48+
sum = sumsh[0];
49+
50+
if (p.has_sinks != 0) {
51+
sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));
52+
}
53+
54+
FLOAT_TYPE rcpdivisor = 1.0/sum;
55+
56+
[[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
57+
const uint col = col0 + tid;
58+
59+
if (col >= p.KX) {
60+
continue;
61+
}
62+
63+
data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor);
64+
}
65+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#extension GL_EXT_control_flow_attributes : enable
2+
3+
layout (push_constant) uniform parameter
4+
{
5+
uint KX;
6+
uint KY;
7+
uint ne00;
8+
uint ne01;
9+
uint ne02;
10+
uint ne12;
11+
uint ne13;
12+
uint nb11;
13+
uint nb12;
14+
uint nb13;
15+
float scale;
16+
float max_bias;
17+
float m0;
18+
float m1;
19+
uint n_head_log2;
20+
uint nrows_x;
21+
uint has_sinks;
22+
} p;
23+
24+
#include "types.glsl"
25+
26+
layout(constant_id = 0) const uint BLOCK_SIZE = 128;
27+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
28+
layout(constant_id = 1) const uint num_iters = 4;
29+
30+
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
31+
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
32+
layout (binding = 2) readonly buffer Z {float data_c[];};
33+
layout (binding = 3) buffer D {D_TYPE data_d[];};
34+
layout (binding = 4) buffer M {float data_m[];};
35+
layout (binding = 5) buffer S {float data_s[];};
36+
37+
shared FLOAT_TYPE vals[BLOCK_SIZE];
38+
39+
float get_slope(uint rowx) {
40+
float slope = 1.0f;
41+
42+
// ALiBi
43+
if (p.max_bias > 0.0f) {
44+
const uint h = (rowx / p.ne01) % p.ne02; // head index
45+
46+
const float base = h < p.n_head_log2 ? p.m0 : p.m1;
47+
const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
48+
49+
slope = pow(base, exp);
50+
}
51+
52+
return slope;
53+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,13 @@ void process_shaders() {
899899
string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
900900
string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
901901

902+
string_to_spv("soft_max_large1_f32", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
903+
string_to_spv("soft_max_large2_f32", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
904+
string_to_spv("soft_max_large3_f32", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
905+
string_to_spv("soft_max_large1_f32_f16", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
906+
string_to_spv("soft_max_large2_f32_f16", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
907+
string_to_spv("soft_max_large3_f32_f16", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
908+
902909
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
903910
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
904911
string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});

tests/test-backend-ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7652,6 +7652,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
76527652
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
76537653
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
76547654

7655+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
7656+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
7657+
76557658
for (float max_bias : {0.0f, 8.0f}) {
76567659
for (float scale : {1.0f, 0.1f}) {
76577660
for (int64_t ne0 : {16, 1024}) {

0 commit comments

Comments
 (0)