Commit 94e82c7

vulkan: clamp matmul and FA results to the max finite value (ggml-org#15652)
* vulkan: clamp matmul and FA results to the max finite value
* only clamp for fp16
1 parent 4d74393 commit 94e82c7
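
For context: 65504.0 is the largest finite value in IEEE-754 half precision, so any accumulated result beyond it becomes +/-Inf when stored through an fp16 type, and an Inf then poisons later arithmetic (Inf - Inf and Inf * 0 both yield NaN). A minimal C++ sketch of the clamping idea behind this change; the names here are illustrative, not from the commit:

#include <algorithm>
#include <cstdio>

// Largest finite IEEE-754 binary16 (fp16) value; anything larger rounds
// to +/-Inf on conversion to fp16.
constexpr float FP16_MAX = 65504.0f;

// Clamp an fp32 partial result into the finite fp16 range before it is
// written out through an fp16 destination type.
inline float clamp_to_fp16_finite(float acc) {
    return std::clamp(acc, -FP16_MAX, FP16_MAX);
}

int main() {
    std::printf("%g\n", clamp_to_fp16_finite(1.0e5f));   // prints 65504
    std::printf("%g\n", clamp_to_fp16_finite(-7.0e4f));  // prints -65504
}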

7 files changed: +58 additions, -8 deletions

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 3 additions & 0 deletions
@@ -334,6 +334,9 @@ void main() {
     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             Of[r][d] *= Lfrcp[r];
+#if defined(ACC_TYPE_MAX)
+            Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
+#endif
         }
     }
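
Note that Of[r][d] is a four-component vector here and GLSL's clamp applies componentwise, so -vec4(ACC_TYPE_MAX) just broadcasts the scalar bound to all four lanes. A rough C++ analog of that broadcast clamp, with std::array standing in for vec4 (illustrative only):

#include <algorithm>
#include <array>

using vec4 = std::array<float, 4>;
constexpr float ACC_TYPE_MAX = 65504.0f;  // fp16 max, matching the shader define

// Componentwise clamp, mirroring GLSL clamp(v, -vec4(M), vec4(M)):
// each accumulator lane is clamped independently.
vec4 clamp4(vec4 v) {
    for (float& x : v) x = std::clamp(x, -ACC_TYPE_MAX, ACC_TYPE_MAX);
    return v;
}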

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

Lines changed: 3 additions & 0 deletions
@@ -373,6 +373,9 @@ void main() {
     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             Of[r][d] *= ACC_TYPE(Lfrcp[r]);
+#if defined(ACC_TYPE_MAX)
+            Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+#endif
         }
     }

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 4 additions & 0 deletions
@@ -283,6 +283,10 @@ void main() {
 
         O = Ldiag*O;
 
+#if defined(ACC_TYPE_MAX)
+        [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
         uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
 
         coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp

Lines changed: 4 additions & 0 deletions
@@ -111,6 +111,10 @@ void main() {
             }
         }
         O *= L;
+
+        const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
+        O = clamp(O, -FLT_MAX, FLT_MAX);
+
         data_d[iq3 * D * N + D * n + d] = O;
     }
 }
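
GLSL has no built-in FLT_MAX constant, so the shader rebuilds it from its IEEE-754 bit pattern: 0x7F7FFFFF encodes the largest finite fp32 value (about 3.4028235e38), and clamping against it maps any +/-Inf produced by the split-k reduction back to a finite float. A C++ check of that bit-pattern identity:

#include <cassert>
#include <cfloat>
#include <cstdint>
#include <cstring>

int main() {
    // 0x7F7FFFFF = sign 0, largest non-Inf exponent (0xFE), all-ones
    // mantissa: the bit pattern of FLT_MAX.
    const uint32_t bits = 0x7F7FFFFFu;
    float f;
    std::memcpy(&f, &bits, sizeof f);  // C++ spelling of uintBitsToFloat
    assert(f == FLT_MAX);
}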

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 14 additions & 0 deletions
@@ -891,6 +891,20 @@ void main() {
         barrier();
     }
 
+#if defined(ACC_TYPE_MAX)
+#ifdef COOPMAT
+    [[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) {
+        [[unroll]] for (uint i = 0; i < sums[j].length(); ++i) {
+            sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+        }
+    }
+#else
+    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
+        sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+    }
+#endif
+#endif
+
     const uint dr = ir * BM + warp_r * WM;
     const uint dc = ic * BN + warp_c * WN;
 
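
In the non-coopmat branch the accumulator is a flat register array of WMITER*TM*WNITER*TN partial sums, so one loop covers it; the coopmat branch instead walks the elements of each cooperative-matrix fragment through .length(). A C++ analog of the flat branch, with made-up tile sizes for illustration:

#include <algorithm>
#include <array>

constexpr float ACC_TYPE_MAX = 65504.0f;               // fp16 max, as the generator defines it
constexpr int WMITER = 2, TM = 4, WNITER = 2, TN = 2;  // illustrative tile shape

// Mirror of the #else branch above: clamp every partial sum a thread
// holds in registers before the fp16 store.
void clamp_sums(std::array<float, WMITER * TM * WNITER * TN>& sums) {
    for (float& s : sums) {
        s = std::clamp(s, -ACC_TYPE_MAX, ACC_TYPE_MAX);
    }
}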

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 15 additions & 0 deletions
@@ -349,6 +349,10 @@ void main() {
         sum = coopMatMulAdd(mat_a, mat_b, sum);
         block_k += BK;
     }
+#if defined(ACC_TYPE_MAX)
+    [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
     coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
 
     coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
@@ -388,6 +392,10 @@ void main() {
         sum = coopMatMulAdd(mat_a, mat_b, sum);
         block_k += BK;
     }
+#if defined(ACC_TYPE_MAX)
+    [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
     coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
 
     coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
@@ -428,6 +436,10 @@ void main() {
         sum = coopMatMulAdd(mat_a, mat_b, sum);
         block_k += BK;
     }
+#if defined(ACC_TYPE_MAX)
+    [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
     coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
 
     coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
@@ -485,6 +497,9 @@ void main() {
             sum = coopMatMulAdd(mat_a, mat_b, sum);
         }
     }
+#if defined(ACC_TYPE_MAX)
+    [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
 
     // Convert from ACC_TYPE to D_TYPE
     coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 15 additions & 8 deletions
@@ -323,6 +323,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
     }
 
     base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
+    if (f16acc) {
+        base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
+    }
 
     if (coopmat) {
         base_dict["COOPMAT"] = "1";
@@ -437,8 +440,12 @@ void process_shaders() {
 
     // flash attention
     for (const auto& f16acc : {false, true}) {
-        std::string acctype = f16acc ? "float16_t" : "float";
-        std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
+        std::map<std::string, std::string> fa_base_dict = base_dict;
+        fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
+        fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
+        if (f16acc) {
+            fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
+        }
 
         for (const auto& tname : type_names) {
             if (tname == "f32") {
@@ -449,30 +456,30 @@ void process_shaders() {
 #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
-                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
+                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
             } else {
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
-                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
+                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
             }
 #endif
 #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
-                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
+                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
             } else if (tname == "q4_0" || tname == "q8_0") {
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
-                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
+                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
             }
 #endif
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
-                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
+                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
             } else if (tname == "q4_0" || tname == "q8_0") {
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
-                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
+                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
             }
         }
     }
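
On the generator side, ACC_TYPE_MAX is only defined for fp16 accumulation, and its value carries escaped quotes so that it reaches the shader preprocessor as the single token float16_t(65504.0). Presumably the entries of these maps end up as -D macro definitions on the shader compiler command line; a hypothetical sketch of that step, not the actual string_to_spv code:

#include <map>
#include <string>

// Hypothetical helper: turn a definitions map like base_dict into
// glslc-style -DNAME=VALUE flags. With the quoting above, ACC_TYPE_MAX
// expands in the shader to float16_t(65504.0).
std::string make_define_flags(const std::map<std::string, std::string>& defs) {
    std::string flags;
    for (const auto& [name, value] : defs) {
        flags += " -D" + name + "=" + value;
    }
    return flags;
}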
