Skip to content

Commit aad3a4b

Browse files
committed
vulkan: enable the use of simpler softmax shaders
Even though the regular softmax shaders successfully pass test-backend-ops with Apple GPUs, running long inference tests has shown that the models end up derailing, with softmax OPs being the root cause. With this commit, we use simpler softmax shaders borrowed from the Kompute backend (which are basically reimplementations of the Metal shaders) on certain GPUs known to have problems with the regular ones. Signed-off-by: Sergio Lopez <[email protected]>
1 parent 7855aeb commit aad3a4b

File tree

3 files changed

+123
-9
lines changed

3 files changed

+123
-9
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,9 @@ struct vk_device_struct {
226226
vk_pipeline pipeline_simpler_mul_mat_q6_k;
227227
vk_pipeline pipeline_simpler_mul_mat_q8_0;
228228

229+
vk_pipeline pipeline_simpler_soft_max_f16;
230+
vk_pipeline pipeline_simpler_soft_max_f32;
231+
229232
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
230233
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
231234
vk_pipeline pipeline_acc_f32;
@@ -516,6 +519,18 @@ struct vk_op_soft_max_push_constants {
516519
uint32_t nrows_x;
517520
};
518521

522+
struct vk_op_simpler_soft_max_push_constants {
523+
int32_t ne00;
524+
int32_t ne01;
525+
int32_t ne02;
526+
float scale;
527+
float max_bias;
528+
float m0;
529+
float m1;
530+
uint32_t n_head_log2;
531+
int32_t mask;
532+
};
533+
519534
struct vk_op_argsort_push_constants {
520535
uint32_t ncols;
521536
uint32_t ncols_pad;
@@ -1983,6 +1998,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
19831998
ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q6_k, "simpler_mul_mat_q6_k", simpler_mul_mat_q6_k_len, simpler_mul_mat_q6_k_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {2, device->subgroup_size}, 1);
19841999
ggml_vk_create_pipeline(device, device->pipeline_simpler_mul_mat_q8_0, "simpler_mul_mat_q8_0", simpler_mul_mat_q8_0_len, simpler_mul_mat_q8_0_data, "main", 3, 18 * sizeof(uint32_t), {1, 1, 1}, {(device->subgroup_size * 2) / 8}, 1);
19852000

2001+
ggml_vk_create_pipeline(device, device->pipeline_simpler_soft_max_f16, "simpler_soft_max_f16", simpler_soft_max_f16_len, simpler_soft_max_f16_data, "main", 3, sizeof(vk_op_simpler_soft_max_push_constants), {1, 1, 1}, {}, 1);
2002+
ggml_vk_create_pipeline(device, device->pipeline_simpler_soft_max_f32, "simpler_soft_max_f32", simpler_soft_max_f32_len, simpler_soft_max_f32_data, "main", 3, sizeof(vk_op_simpler_soft_max_push_constants), {1, 1, 1}, {}, 1);
2003+
19862004
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
19872005
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
19882006
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@@ -5286,6 +5304,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
52865304
case GGML_OP_SOFT_MAX:
52875305
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
52885306

5307+
if (ctx->device->simpler_shaders) {
5308+
if (src1 && src1->type == GGML_TYPE_F16) {
5309+
return ctx->device->pipeline_simpler_soft_max_f16;
5310+
}
5311+
return ctx->device->pipeline_simpler_soft_max_f32;
5312+
}
5313+
52895314
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
52905315
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
52915316
}
@@ -5584,9 +5609,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
55845609
GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1))));
55855610

55865611
switch (op) {
5612+
case GGML_OP_SOFT_MAX:
5613+
if (ctx->device->simpler_shaders) {
5614+
elements = { (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3] };
5615+
break;
5616+
}
5617+
// fall-through
55875618
case GGML_OP_NORM:
55885619
case GGML_OP_RMS_NORM:
5589-
case GGML_OP_SOFT_MAX:
55905620
case GGML_OP_SUM_ROWS:
55915621
{
55925622
const uint32_t nr = ggml_nrows(src0);
@@ -6127,14 +6157,26 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
61276157
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
61286158
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
61296159

6130-
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
6131-
ncols,
6132-
src1 != nullptr ? nrows_y : (uint32_t)0,
6133-
scale, max_bias,
6134-
m0, m1,
6135-
n_head_log2,
6136-
nrows_x,
6137-
}, dryrun);
6160+
if (ctx->device->simpler_shaders) {
6161+
ggml_vk_op_f32<vk_op_simpler_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
6162+
(int32_t) src0->ne[0],
6163+
(int32_t) src0->ne[1],
6164+
(int32_t) src0->ne[2],
6165+
scale, max_bias,
6166+
m0, m1,
6167+
n_head_log2,
6168+
src1 == nullptr ? 0 : 1,
6169+
}, dryrun);
6170+
} else {
6171+
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
6172+
ncols,
6173+
src1 != nullptr ? nrows_y : (uint32_t)0,
6174+
scale, max_bias,
6175+
m0, m1,
6176+
n_head_log2,
6177+
nrows_x,
6178+
}, dryrun);
6179+
}
61386180
}
61396181

61406182
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4)
2+
3+
#version 450
4+
5+
#include "simpler_common.comp"
6+
7+
layout(local_size_x = 32) in;
8+
9+
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
10+
layout(binding = 1) buffer restrict readonly tensorInB { A_TYPE inB[]; };
11+
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
12+
13+
layout(push_constant) uniform PushConstants {
14+
int ne00;
15+
int ne01;
16+
int ne02;
17+
float scale;
18+
float max_bias;
19+
float m0;
20+
float m1;
21+
uint n_head_log2;
22+
int mask;
23+
} pcs;
24+
25+
void main() {
26+
if (gl_SubgroupInvocationID > 31)
27+
return;
28+
29+
const uint i03 = gl_WorkGroupID.z;
30+
const uint i02 = gl_WorkGroupID.y;
31+
const uint i01 = gl_WorkGroupID.x;
32+
33+
const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00;
34+
const uint psrc0 = extra_off;
35+
const uint pmask = i01*pcs.ne00;
36+
const uint pdst = extra_off;
37+
38+
float slope = 1.0f;
39+
40+
// ALiBi
41+
if (pcs.max_bias > 0.0f) {
42+
int64_t h = i02;
43+
44+
float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1;
45+
int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1;
46+
47+
slope = pow(base, float(exp));
48+
}
49+
50+
// parallel max
51+
float localMax = uintBitsToFloat(0xFF800000);
52+
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
53+
localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f));
54+
}
55+
float max_ = subgroupMax(localMax);
56+
57+
// parallel sum
58+
float localSum = 0.0f;
59+
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
60+
const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_);
61+
localSum += exp_psrc0;
62+
out_[pdst + i00] = exp_psrc0;
63+
}
64+
65+
const float sum = subgroupAdd(localSum);
66+
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
67+
out_[pdst + i00] /= sum;
68+
}
69+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,9 @@ void process_shaders() {
498498
string_to_spv("simpler_mul_mat_q6_k", "simpler_mul_mat_q6_k.comp", {});
499499
string_to_spv("simpler_mul_mat_q8_0", "simpler_mul_mat_q8_0.comp", {});
500500

501+
string_to_spv("simpler_soft_max_f16", "simpler_soft_max.comp", {{"A_TYPE", "float16_t"}});
502+
string_to_spv("simpler_soft_max_f32", "simpler_soft_max.comp", {{"A_TYPE", "float"}});
503+
501504
for (auto &c : compiles) {
502505
c.wait();
503506
}

0 commit comments

Comments
 (0)