@@ -3356,8 +3356,8 @@ kernel void kernel_flash_attn_ext_vec(
33563356 const short D4 = D/4 ;
33573357 const short D16 = D/16 ;
33583358 const short NW = N_SIMDWIDTH;
3359- const short NL = NW/4 ;
3360- const short SH = 2 *C; // shared memory per simdgroup
3359+ const short NL = NW/4 ; // note: NL can be adjusted to support D%64 == 0 and D%32 == 0
3360+ const short SH = 2 *C; // shared memory per simdgroup
33613361
33623362 const short T = D + nsg*SH; // shared memory size per query (in units of half)
33633363
@@ -3448,7 +3448,7 @@ kernel void kernel_flash_attn_ext_vec(
34483448
34493449 // Q*K^T
34503450 {
3451- // each simdgroup processes 1 query and 4 keys
3451+ // each simdgroup processes 1 query and 4 (NW/NL) keys
34523452 for (short cc = 0 ; cc < C/4 ; ++cc) {
34533453 qk_t mqka[4 ] = { 0.0 , 0.0 , 0.0 , 0.0 };
34543454
@@ -3646,7 +3646,7 @@ kernel void kernel_flash_attn_ext_vec(
36463646 half, half4, half4x4, \
36473647 half4x4
36483648
3649- typedef decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 64 >) flash_attn_ext_vec_t;
3649+ typedef decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 128 >) flash_attn_ext_vec_t;
36503650
36513651template [[host_name(" kernel_flash_attn_ext_vec_f16_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 128 >;
36523652#if defined(GGML_METAL_USE_BF16)
0 commit comments