@@ -3546,6 +3546,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_at
// Explicit instantiations of kernel_flash_attn_ext for f16 inputs: both type slots use half4x4
// with dequantize_f16, and each instantiation is exported under the name given in host_name.
// The two trailing integer template arguments mirror the numbers in the host name
// (h192 -> 192, 192; hk192_hv128 -> 192, 128; hk576_hv512 -> 576, 512) -- presumably the K and V
// head sizes; confirm against the kernel_flash_attn_ext template declaration.
// The line marked '+' is the newly added hk576_hv512 variant.
// NOTE(review): the leading digits are fused old/new diff line numbers, and the host_name strings
// show padding spaces that the hunk-header context does not -- looks like an extraction artifact; verify.
35463546template [[host_name(" kernel_flash_attn_ext_f16_h192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 192 , 192 >;
35473547template [[host_name(" kernel_flash_attn_ext_f16_hk192_hv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 192 , 128 >;
35483548template [[host_name(" kernel_flash_attn_ext_f16_h256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 256 , 256 >;
3549+ template [[host_name(" kernel_flash_attn_ext_f16_hk576_hv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 576 , 512 >;
35493550
// bf16 instantiations of kernel_flash_attn_ext, compiled only when GGML_METAL_USE_BF16 is defined.
// Both type slots use bfloat4x4 with dequantize_bf16; the two trailing integers mirror the
// numbers in each host name (presumably K and V head sizes -- confirm against the template).
// The line marked '+' adds the hk576_hv512 variant inside the same #if guard.
35503551#if defined(GGML_METAL_USE_BF16)
35513552template [[host_name(" kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 64 , 64 >;
@@ -3556,6 +3557,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_at
35563557template [[host_name(" kernel_flash_attn_ext_bf16_h192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 192 , 192 >;
35573558template [[host_name(" kernel_flash_attn_ext_bf16_hk192_hv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 192 , 128 >;
35583559template [[host_name(" kernel_flash_attn_ext_bf16_h256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 256 , 256 >;
3560+ template [[host_name(" kernel_flash_attn_ext_bf16_hk576_hv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 576 , 512 >;
35593561#endif
35603562
// q4_0 instantiations of kernel_flash_attn_ext: both type slots use block_q4_0 with dequantize_q4_0.
// The integer after each block type is 2 here (1 in the half4x4/bfloat4x4 instantiations); its
// meaning is not visible in this chunk. The two trailing integers mirror the numbers in each
// host name (presumably K and V head sizes -- confirm against the template declaration).
// The line marked '+' adds the hk576_hv512 variant.
35613563template [[host_name(" kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 64 , 64 >;
@@ -3566,6 +3568,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_at
35663568template [[host_name(" kernel_flash_attn_ext_q4_0_h192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 192 , 192 >;
35673569template [[host_name(" kernel_flash_attn_ext_q4_0_hk192_hv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 192 , 128 >;
35683570template [[host_name(" kernel_flash_attn_ext_q4_0_h256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 256 , 256 >;
3571+ template [[host_name(" kernel_flash_attn_ext_q4_0_hk576_hv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 576 , 512 >;
35693572
// q4_1 instantiations of kernel_flash_attn_ext: both type slots use block_q4_1 with dequantize_q4_1.
// The two trailing integers mirror the numbers in each host name (presumably K and V head
// sizes -- confirm against the template declaration).
// The line marked '+' adds the hk576_hv512 variant.
35703573template [[host_name(" kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 64 , 64 >;
35713574template [[host_name(" kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 80 , 80 >;
@@ -3575,6 +3578,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_at
35753578template [[host_name(" kernel_flash_attn_ext_q4_1_h192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 192 , 192 >;
35763579template [[host_name(" kernel_flash_attn_ext_q4_1_hk192_hv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 192 , 128 >;
35773580template [[host_name(" kernel_flash_attn_ext_q4_1_h256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 256 , 256 >;
3581+ template [[host_name(" kernel_flash_attn_ext_q4_1_hk576_hv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 576 , 512 >;
35783582
// q5_0 instantiations of kernel_flash_attn_ext: both type slots use block_q5_0 with dequantize_q5_0.
// The two trailing integers mirror the numbers in each host name (presumably K and V head
// sizes -- confirm against the template declaration).
// The line marked '+' adds the hk576_hv512 variant.
35793583template [[host_name(" kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 64 , 64 >;
35803584template [[host_name(" kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 80 , 80 >;
@@ -3584,6 +3588,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_at
35843588template [[host_name(" kernel_flash_attn_ext_q5_0_h192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 192 , 192 >;
35853589template [[host_name(" kernel_flash_attn_ext_q5_0_hk192_hv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 192 , 128 >;
35863590template [[host_name(" kernel_flash_attn_ext_q5_0_h256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 256 , 256 >;
3591+ template [[host_name(" kernel_flash_attn_ext_q5_0_hk576_hv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 576 , 512 >;
35873592
// q5_1 instantiations of kernel_flash_attn_ext: both type slots use block_q5_1 with dequantize_q5_1.
// The two trailing integers mirror the numbers in each host name (presumably K and V head
// sizes -- confirm against the template declaration).
// The line marked '+' adds the hk576_hv512 variant.
35883593template [[host_name(" kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 64 , 64 >;
35893594template [[host_name(" kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 80 , 80 >;
@@ -3593,6 +3598,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_at
35933598template [[host_name(" kernel_flash_attn_ext_q5_1_h192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 192 , 192 >;
35943599template [[host_name(" kernel_flash_attn_ext_q5_1_hk192_hv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 192 , 128 >;
35953600template [[host_name(" kernel_flash_attn_ext_q5_1_h256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 256 , 256 >;
3601+ template [[host_name(" kernel_flash_attn_ext_q5_1_hk576_hv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 576 , 512 >;
35963602
// q8_0 instantiations of kernel_flash_attn_ext: both type slots use block_q8_0 with dequantize_q8_0.
// The two trailing integers mirror the numbers in each host name (presumably K and V head
// sizes -- confirm against the template declaration).
// The line marked '+' adds the hk576_hv512 variant.
35973603template [[host_name(" kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 64 , 64 >;
35983604template [[host_name(" kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 80 , 80 >;
@@ -3602,6 +3608,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_at
36023608template [[host_name(" kernel_flash_attn_ext_q8_0_h192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 192 , 192 >;
36033609template [[host_name(" kernel_flash_attn_ext_q8_0_hk192_hv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 192 , 128 >;
36043610template [[host_name(" kernel_flash_attn_ext_q8_0_h256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 256 , 256 >;
3611+ template [[host_name(" kernel_flash_attn_ext_q8_0_hk576_hv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 576 , 512 >;
36053612
// FA_TYPES is the leading template argument of the instantiations above; it is undefined here
// so the macro does not carry past this group.
36063613#undef FA_TYPES
36073614
@@ -4009,6 +4016,16 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_
// Explicit instantiations of kernel_flash_attn_ext_vec, each exported under its host_name.
// The first two lines are existing h256 instantiations (256, 256, final argument 4); the lines
// marked '+' add the hk576_hv512 set for f16, bf16 (only when GGML_METAL_USE_BF16 is defined),
// q4_0, q4_1, q5_0, q5_1 and q8_0.
// The f16/bf16 variants use half4/bfloat4 with the integer 1 and the *_t4 dequantizers; the
// quantized variants use the block_* types with the integer 8 and the matching *_t4 dequantizers.
// NOTE(review): the final template argument is 2 for every hk576_hv512 instantiation versus 4
// for the h256 ones; its meaning is not visible in this chunk -- confirm against the
// kernel_flash_attn_ext_vec template parameter list.
40094016template [[host_name(" kernel_flash_attn_ext_vec_q5_1_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8 , dequantize_q5_1_t4, block_q5_1, 8 , dequantize_q5_1_t4, 256 , 256 , 4 >;
40104017template [[host_name(" kernel_flash_attn_ext_vec_q8_0_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8 , dequantize_q8_0_t4, block_q8_0, 8 , dequantize_q8_0_t4, 256 , 256 , 4 >;
40114018
4019+ template [[host_name(" kernel_flash_attn_ext_vec_f16_hk576_hv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 576 , 512 , 2 >;
4020+ #if defined(GGML_METAL_USE_BF16)
4021+ template [[host_name(" kernel_flash_attn_ext_vec_bf16_hk576_hv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1 , dequantize_bf16_t4, bfloat4, 1 , dequantize_bf16_t4, 576 , 512 , 2 >;
4022+ #endif
4023+ template [[host_name(" kernel_flash_attn_ext_vec_q4_0_hk576_hv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8 , dequantize_q4_0_t4, block_q4_0, 8 , dequantize_q4_0_t4, 576 , 512 , 2 >;
4024+ template [[host_name(" kernel_flash_attn_ext_vec_q4_1_hk576_hv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8 , dequantize_q4_1_t4, block_q4_1, 8 , dequantize_q4_1_t4, 576 , 512 , 2 >;
4025+ template [[host_name(" kernel_flash_attn_ext_vec_q5_0_hk576_hv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8 , dequantize_q5_0_t4, block_q5_0, 8 , dequantize_q5_0_t4, 576 , 512 , 2 >;
4026+ template [[host_name(" kernel_flash_attn_ext_vec_q5_1_hk576_hv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8 , dequantize_q5_1_t4, block_q5_1, 8 , dequantize_q5_1_t4, 576 , 512 , 2 >;
4027+ template [[host_name(" kernel_flash_attn_ext_vec_q8_0_hk576_hv512" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8 , dequantize_q8_0_t4, block_q8_0, 8 , dequantize_q8_0_t4, 576 , 512 , 2 >;
4028+
// FA_TYPES is the leading template argument of the vec instantiations above; undefine it here.
40124029#undef FA_TYPES
40134030
40144031template <typename T>
0 commit comments