@@ -3249,7 +3249,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]]  kernel flash_attn_ext_
32493249template  [[host_name(" kernel_flash_attn_ext_f16_h128"  )]]  kernel flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES, half4x4,    1 , dequantize_f16,  half4x4,    1 , dequantize_f16,  128 >;
32503250template  [[host_name(" kernel_flash_attn_ext_f16_h256"  )]]  kernel flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES, half4x4,    1 , dequantize_f16,  half4x4,    1 , dequantize_f16,  256 >;
32513251
3252- #if  ! defined(GGML_METAL_NO_BFLOAT )
3252+ #if  defined(GGML_METAL_USE_BF16 )
32533253template  [[host_name(" kernel_flash_attn_ext_bf16_h64"   )]] kernel flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 64 >;
32543254template  [[host_name(" kernel_flash_attn_ext_bf16_h80"   )]] kernel flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 80 >;
32553255template  [[host_name(" kernel_flash_attn_ext_bf16_h96"   )]] kernel flash_attn_ext_t  kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1 , dequantize_bf16, bfloat4x4,  1 , dequantize_bf16, 96 >;
@@ -3634,7 +3634,7 @@ kernel void kernel_flash_attn_ext_vec(
36343634typedef  decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 64 >) flash_attn_ext_vec_t;
36353635
36363636template  [[host_name(" kernel_flash_attn_ext_vec_f16_h128"  )]]  kernel flash_attn_ext_vec_t  kernel_flash_attn_ext_vec<FA_TYPES, half4x4,    1 , dequantize_f16,  half4x4,     1 , dequantize_f16,  128 >;
3637- #if  ! defined(GGML_METAL_NO_BFLOAT )
3637+ #if  defined(GGML_METAL_USE_BF16 )
36383638template  [[host_name(" kernel_flash_attn_ext_vec_bf16_h128"  )]] kernel flash_attn_ext_vec_t  kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4,  1 , dequantize_bf16, bfloat4x4,   1 , dequantize_bf16, 128 >;
36393639#endif 
36403640template  [[host_name(" kernel_flash_attn_ext_vec_q4_0_h128"  )]] kernel flash_attn_ext_vec_t  kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0,  2 , dequantize_q4_0, 128 >;
@@ -3644,7 +3644,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_
36443644template  [[host_name(" kernel_flash_attn_ext_vec_q8_0_h128"  )]] kernel flash_attn_ext_vec_t  kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0,  2 , dequantize_q8_0, 128 >;
36453645
36463646template  [[host_name(" kernel_flash_attn_ext_vec_f16_h256"  )]]  kernel flash_attn_ext_vec_t  kernel_flash_attn_ext_vec<FA_TYPES, half4x4,    1 , dequantize_f16,  half4x4,     1 , dequantize_f16,  256 >;
3647- #if  ! defined(GGML_METAL_NO_BFLOAT )
3647+ #if  defined(GGML_METAL_USE_BF16 )
36483648template  [[host_name(" kernel_flash_attn_ext_vec_bf16_h256"  )]] kernel flash_attn_ext_vec_t  kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4,  1 , dequantize_bf16, bfloat4x4,   1 , dequantize_bf16, 256 >;
36493649#endif 
36503650template  [[host_name(" kernel_flash_attn_ext_vec_q4_0_h256"  )]] kernel flash_attn_ext_vec_t  kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0,  2 , dequantize_q4_0, 256 >;
0 commit comments