@@ -323,6 +323,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
     }
 
     base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
+    if (f16acc) {
+        base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
+    }
 
     if (coopmat) {
         base_dict["COOPMAT"] = "1";
@@ -437,8 +440,12 @@ void process_shaders() {
 
     // flash attention
     for (const auto & f16acc : {false, true}) {
-        std::string acctype = f16acc ? "float16_t" : "float";
-        std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
+        std::map<std::string, std::string> fa_base_dict = base_dict;
+        fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
+        fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
+        if (f16acc) {
+            fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
+        }
 
         for (const auto & tname : type_names) {
             if (tname == "f32") {
@@ -449,30 +456,30 @@ void process_shaders() {
 #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
-                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
+                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
             } else {
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
-                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
+                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
             }
 #endif
 #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
-                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
+                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
             } else if (tname == "q4_0" || tname == "q8_0") {
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
-                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
+                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
             }
 #endif
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
-                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
+                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
             } else if (tname == "q4_0" || tname == "q8_0") {
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
-                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
+                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
             }
         }
     }
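
The change replaces the loose `acctype`/`acctypev4` strings with a per-pass copy of `base_dict`, so that `ACC_TYPE`, `ACC_TYPEV4`, and (for f16 accumulation) `ACC_TYPE_MAX` are set once per `f16acc` setting instead of being repeated in every `merge_maps` call. Below is a minimal standalone sketch of that copy-and-override pattern, not the repository's implementation: `merge_maps` is a hypothetical stand-in that just combines two define maps, the base defines are illustrative, and the idea that each entry ends up as a `-DKEY=VALUE` shader-compiler argument (which would explain the escaped quotes around `float16_t(65504.0)`, the largest finite fp16 value) is an inference rather than something shown in this diff.

```cpp
// Sketch only: illustrates the fa_base_dict copy-and-override pattern.
// Names mirror the diff; the real generator's helpers may behave differently.
#include <cstdio>
#include <map>
#include <string>

using Defines = std::map<std::string, std::string>;

// Hypothetical stand-in for the generator's merge_maps helper: combines two
// define maps (the key sets do not overlap in the calls shown in the diff).
static Defines merge_maps(const Defines & a, const Defines & b) {
    Defines result = a;
    result.insert(b.begin(), b.end());
    return result;
}

int main() {
    Defines base_dict = {{"FLOAT_TYPE", "float"}};  // illustrative base defines

    for (bool f16acc : {false, true}) {
        // Per-pass copy: the accumulator-type defines are scoped to this
        // f16acc setting and never leak back into base_dict.
        Defines fa_base_dict = base_dict;
        fa_base_dict["ACC_TYPE"]   = f16acc ? "float16_t" : "float";
        fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
        if (f16acc) {
            // Largest finite fp16 value; quoted so the parentheses survive if
            // the value is forwarded as a macro define on a command line.
            fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
        }

        // Per-shader defines layered on top, as in the string_to_spv calls.
        Defines defines = merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}});
        for (const auto & kv : defines) {
            std::printf("-D%s=%s ", kv.first.c_str(), kv.second.c_str());
        }
        std::printf("\n");
    }
    return 0;
}
```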