Skip to content

Commit bae9a58

Browse files
committed
vulkan: enable the use of simpler softmax shaders
Even though the regular softmax shaders successfully pass test-backend-ops with Apple GPUs, running long inference tests has shown the models end derailing with softmax OPs being the root cause. With this commit, we use simpler softmax shaders borrowed from the Kompute backend (which are basically reimplementations of the Metal shaders) on certain GPUs know to have problem with the regular ones. Signed-off-by: Sergio Lopez <[email protected]>
1 parent 8b47fce commit bae9a58

File tree

3 files changed

+123
-9
lines changed

3 files changed

+123
-9
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,9 @@ struct vk_device_struct {
228228
vk_pipeline pipeline_simpler_mul_mat_q6_k;
229229
vk_pipeline pipeline_simpler_mul_mat_q8_0;
230230

231+
vk_pipeline pipeline_simpler_soft_max_f16;
232+
vk_pipeline pipeline_simpler_soft_max_f32;
233+
231234
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
232235
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
233236
vk_pipeline pipeline_acc_f32;
@@ -518,6 +521,18 @@ struct vk_op_soft_max_push_constants {
518521
uint32_t nrows_x;
519522
};
520523

524+
struct vk_op_simpler_soft_max_push_constants {
525+
int32_t ne00;
526+
int32_t ne01;
527+
int32_t ne02;
528+
float scale;
529+
float max_bias;
530+
float m0;
531+
float m1;
532+
uint32_t n_head_log2;
533+
int32_t mask;
534+
};
535+
521536
struct vk_op_argsort_push_constants {
522537
uint32_t ncols;
523538
uint32_t ncols_pad;
@@ -2088,6 +2103,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
20882103
ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q6_k, "simpler_mul_mat_q6_k", simpler_mul_mat_q6_k_len, simpler_mul_mat_q6_k_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {2, device->subgroup_size}, 1);
20892104
ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q8_0, "simpler_mul_mat_q8_0", simpler_mul_mat_q8_0_len, simpler_mul_mat_q8_0_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {(device->subgroup_size * 2) / 8}, 1);
20902105

2106+
ggml_vk_create_pipeline(device, device->pipeline_simpler_soft_max_f16, "simpler_soft_max_f16", simpler_soft_max_f16_len, simpler_soft_max_f16_data, "main", 3, sizeof(vk_op_simpler_soft_max_push_constants), {1, 1, 1}, {}, 1);
2107+
ggml_vk_create_pipeline(device, device->pipeline_simpler_soft_max_f32, "simpler_soft_max_f32", simpler_soft_max_f32_len, simpler_soft_max_f32_data, "main", 3, sizeof(vk_op_simpler_soft_max_push_constants), {1, 1, 1}, {}, 1);
2108+
20912109
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
20922110
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
20932111
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@@ -5440,6 +5458,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
54405458
case GGML_OP_SOFT_MAX:
54415459
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
54425460

5461+
if (ctx->device->simpler_shaders) {
5462+
if (src1 && src1->type == GGML_TYPE_F16) {
5463+
return ctx->device->pipeline_simpler_soft_max_f16;
5464+
}
5465+
return ctx->device->pipeline_simpler_soft_max_f32;
5466+
}
5467+
54435468
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
54445469
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
54455470
}
@@ -5738,9 +5763,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
57385763
GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1))));
57395764

57405765
switch (op) {
5766+
case GGML_OP_SOFT_MAX:
5767+
if (ctx->device->simpler_shaders) {
5768+
elements = { (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3] };
5769+
break;
5770+
}
5771+
// fall-through
57415772
case GGML_OP_NORM:
57425773
case GGML_OP_RMS_NORM:
5743-
case GGML_OP_SOFT_MAX:
57445774
case GGML_OP_SUM_ROWS:
57455775
{
57465776
const uint32_t nr = ggml_nrows(src0);
@@ -6281,14 +6311,26 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
62816311
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
62826312
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
62836313

6284-
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
6285-
ncols,
6286-
src1 != nullptr ? nrows_y : (uint32_t)0,
6287-
scale, max_bias,
6288-
m0, m1,
6289-
n_head_log2,
6290-
nrows_x,
6291-
}, dryrun);
6314+
if (ctx->device->simpler_shaders) {
6315+
ggml_vk_op_f32<vk_op_simpler_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
6316+
(int32_t) src0->ne[0],
6317+
(int32_t) src0->ne[1],
6318+
(int32_t) src0->ne[2],
6319+
scale, max_bias,
6320+
m0, m1,
6321+
n_head_log2,
6322+
src1 == nullptr ? 0 : 1,
6323+
}, dryrun);
6324+
} else {
6325+
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
6326+
ncols,
6327+
src1 != nullptr ? nrows_y : (uint32_t)0,
6328+
scale, max_bias,
6329+
m0, m1,
6330+
n_head_log2,
6331+
nrows_x,
6332+
}, dryrun);
6333+
}
62926334
}
62936335

62946336
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4)
2+
3+
#version 450
4+
5+
#include "simpler_common.comp"
6+
7+
layout(local_size_x = 32) in;
8+
9+
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
10+
layout(binding = 1) buffer restrict readonly tensorInB { A_TYPE inB[]; };
11+
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
12+
13+
layout(push_constant) uniform PushConstants {
14+
int ne00;
15+
int ne01;
16+
int ne02;
17+
float scale;
18+
float max_bias;
19+
float m0;
20+
float m1;
21+
uint n_head_log2;
22+
int mask;
23+
} pcs;
24+
25+
void main() {
26+
if (gl_SubgroupInvocationID > 31)
27+
return;
28+
29+
const uint i03 = gl_WorkGroupID.z;
30+
const uint i02 = gl_WorkGroupID.y;
31+
const uint i01 = gl_WorkGroupID.x;
32+
33+
const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00;
34+
const uint psrc0 = extra_off;
35+
const uint pmask = i01*pcs.ne00;
36+
const uint pdst = extra_off;
37+
38+
float slope = 1.0f;
39+
40+
// ALiBi
41+
if (pcs.max_bias > 0.0f) {
42+
int64_t h = i02;
43+
44+
float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1;
45+
int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1;
46+
47+
slope = pow(base, float(exp));
48+
}
49+
50+
// parallel max
51+
float localMax = uintBitsToFloat(0xFF800000);
52+
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
53+
localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f));
54+
}
55+
float max_ = subgroupMax(localMax);
56+
57+
// parallel sum
58+
float localSum = 0.0f;
59+
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
60+
const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_);
61+
localSum += exp_psrc0;
62+
out_[pdst + i00] = exp_psrc0;
63+
}
64+
65+
const float sum = subgroupAdd(localSum);
66+
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
67+
out_[pdst + i00] /= sum;
68+
}
69+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,9 @@ void process_shaders() {
504504
string_to_spv("simpler_mul_mat_q6_k", "simpler_mul_mat_q6_k.comp", {});
505505
string_to_spv("simpler_mul_mat_q8_0", "simpler_mul_mat_q8_0.comp", {});
506506

507+
string_to_spv("simpler_soft_max_f16", "simpler_soft_max.comp", {{"A_TYPE", "float16_t"}});
508+
string_to_spv("simpler_soft_max_f32", "simpler_soft_max.comp", {{"A_TYPE", "float"}});
509+
507510
for (auto &c : compiles) {
508511
c.wait();
509512
}

0 commit comments

Comments
 (0)