@@ -323,6 +323,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
323
323
}
324
324
325
325
base_dict[" ACC_TYPE" ] = f16acc ? " float16_t" : " float" ;
326
+ if (f16acc) {
327
+ base_dict[" ACC_TYPE_MAX" ] = " \" float16_t(65504.0)\" " ;
328
+ }
326
329
327
330
if (coopmat) {
328
331
base_dict[" COOPMAT" ] = " 1" ;
@@ -437,8 +440,12 @@ void process_shaders() {
437
440
438
441
// flash attention
439
442
for (const auto & f16acc : {false , true }) {
440
- std::string acctype = f16acc ? " float16_t" : " float" ;
441
- std::string acctypev4 = f16acc ? " f16vec4" : " vec4" ;
443
+ std::map<std::string, std::string> fa_base_dict = base_dict;
444
+ fa_base_dict[" ACC_TYPE" ] = f16acc ? " float16_t" : " float" ;
445
+ fa_base_dict[" ACC_TYPEV4" ] = f16acc ? " f16vec4" : " vec4" ;
446
+ if (f16acc) {
447
+ fa_base_dict[" ACC_TYPE_MAX" ] = " \" float16_t(65504.0)\" " ;
448
+ }
442
449
443
450
for (const auto & tname : type_names) {
444
451
if (tname == " f32" ) {
@@ -449,30 +456,30 @@ void process_shaders() {
449
456
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
450
457
if (tname == " f16" ) {
451
458
string_to_spv (" flash_attn_f32_f16_" + tname, " flash_attn_cm2.comp" ,
452
- merge_maps (base_dict , {{" Q_TYPE" , " float" }, {" D_TYPE" , " float" }, { " ACC_TYPE " , acctype }}), true , false , true , f16acc);
459
+ merge_maps (fa_base_dict , {{" Q_TYPE" , " float" }, {" D_TYPE" , " float" }}), true , false , true , f16acc);
453
460
} else {
454
461
std::string data_a_key = " DATA_A_" + to_uppercase (tname);
455
462
string_to_spv (" flash_attn_f32_f16_" + tname, " flash_attn_cm2.comp" ,
456
- merge_maps (base_dict , {{data_a_key, " 1" }, {" Q_TYPE" , " float" }, {" D_TYPE" , " float" }, { " ACC_TYPE " , acctype }, {" DEQUANTFUNC" , " dequantFunc" +to_uppercase (tname) }, {" BLOCK_SIZE" , " QUANT_K_" +to_uppercase (tname) }}), true , false , true , f16acc);
463
+ merge_maps (fa_base_dict , {{data_a_key, " 1" }, {" Q_TYPE" , " float" }, {" D_TYPE" , " float" }, {" DEQUANTFUNC" , " dequantFunc" +to_uppercase (tname) }, {" BLOCK_SIZE" , " QUANT_K_" +to_uppercase (tname) }}), true , false , true , f16acc);
457
464
}
458
465
#endif
459
466
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
460
467
if (tname == " f16" ) {
461
468
string_to_spv (" flash_attn_f32_f16_" + tname, " flash_attn_cm1.comp" ,
462
- merge_maps (base_dict , {{" Q_TYPE" , " float" }, {" D_TYPE" , " float" }, { " ACC_TYPE " , acctype}, { " ACC_TYPEV4 " , acctypev4 }, {" COOPMAT" , " 1" }}), true , true , false , f16acc);
469
+ merge_maps (fa_base_dict , {{" Q_TYPE" , " float" }, {" D_TYPE" , " float" }, {" COOPMAT" , " 1" }}), true , true , false , f16acc);
463
470
} else if (tname == " q4_0" || tname == " q8_0" ) {
464
471
std::string data_a_key = " DATA_A_" + to_uppercase (tname);
465
472
string_to_spv (" flash_attn_f32_f16_" + tname, " flash_attn_cm1.comp" ,
466
- merge_maps (base_dict , {{data_a_key, " 1" }, {" Q_TYPE" , " float" }, {" D_TYPE" , " float" }, { " ACC_TYPE " , acctype}, { " ACC_TYPEV4 " , acctypev4 }, {" BLOCK_SIZE" , " QUANT_K_" +to_uppercase (tname)}, {" COOPMAT" , " 1" }}), true , true , false , f16acc);
473
+ merge_maps (fa_base_dict , {{data_a_key, " 1" }, {" Q_TYPE" , " float" }, {" D_TYPE" , " float" }, {" BLOCK_SIZE" , " QUANT_K_" +to_uppercase (tname)}, {" COOPMAT" , " 1" }}), true , true , false , f16acc);
467
474
}
468
475
#endif
469
476
if (tname == " f16" ) {
470
477
string_to_spv (" flash_attn_f32_f16_" + tname, " flash_attn.comp" ,
471
- merge_maps (base_dict , {{" Q_TYPE" , " float" }, {" D_TYPE" , " float" }, { " ACC_TYPE " , acctype }}), true , false , false , f16acc);
478
+ merge_maps (fa_base_dict , {{" Q_TYPE" , " float" }, {" D_TYPE" , " float" }}), true , false , false , f16acc);
472
479
} else if (tname == " q4_0" || tname == " q8_0" ) {
473
480
std::string data_a_key = " DATA_A_" + to_uppercase (tname);
474
481
string_to_spv (" flash_attn_f32_f16_" + tname, " flash_attn.comp" ,
475
- merge_maps (base_dict , {{data_a_key, " 1" }, {" Q_TYPE" , " float" }, {" D_TYPE" , " float" }, { " ACC_TYPE " , acctype }, {" BLOCK_SIZE" , " QUANT_K_" +to_uppercase (tname) }}), true , false , false , f16acc);
482
+ merge_maps (fa_base_dict , {{data_a_key, " 1" }, {" Q_TYPE" , " float" }, {" D_TYPE" , " float" }, {" BLOCK_SIZE" , " QUANT_K_" +to_uppercase (tname) }}), true , false , false , f16acc);
476
483
}
477
484
}
478
485
}
0 commit comments