@@ -3128,14 +3128,15 @@ kernel void kernel_flash_attn_ext(
31283128 const int iq2 = tgpig[1 ];
31293129 const int iq1 = tgpig[0 ]*Q;
31303130
3131- const short DK4 = DK/4 ;
3132- const short DK8 = DK/8 ;
3133- const short DK16 = DK/16 ;
3134- const short DV4 = DV/4 ;
3135- const short DV8 = DV/8 ;
3136- const short DV16 = DV/16 ;
3137- const short NW = N_SIMDWIDTH;
3138- const short SH = (2 *C + Q); // shared memory per simdgroup (s_t == float)
3131+ constexpr short DK4 = DK/4 ;
3132+ constexpr short DK8 = DK/8 ;
3133+ constexpr short DK16 = DK/16 ;
3134+ constexpr short DV4 = DV/4 ;
3135+ constexpr short DV8 = DV/8 ;
3136+ constexpr short DV16 = DV/16 ;
3137+
3138+ constexpr short NW = N_SIMDWIDTH;
3139+ constexpr short SH = (2 *C + Q); // shared memory per simdgroup (s_t == float)
31393140
31403141 const short TS = nsg*SH; // shared memory size per query in (s_t == float)
31413142 const short T = DK + 2 *TS; // shared memory size per query in (half)
@@ -3641,11 +3642,11 @@ kernel void kernel_flash_attn_ext_vec(
36413642 const int iq2 = tgpig[1 ];
36423643 const int iq1 = tgpig[0 ];
36433644
3644- const short DK4 = DK/4 ;
3645- const short DV4 = DV/4 ;
3646- const short NW = N_SIMDWIDTH;
3647- const short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
3648- const short SH = 2 *C; // shared memory per simdgroup
3645+ constexpr short DK4 = DK/4 ;
3646+ constexpr short DV4 = DV/4 ;
3647+ constexpr short NW = N_SIMDWIDTH;
3648+ constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
3649+ constexpr short SH = 2 *C; // shared memory per simdgroup
36493650
36503651 const short T = DK + nsg*SH; // shared memory size per query in (half)
36513652
@@ -3956,7 +3957,7 @@ kernel void kernel_flash_attn_ext_vec(
39563957 half, half4, \
39573958 half4
39583959
3959- typedef decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 128 , 128 , 128 >) flash_attn_ext_vec_t;
3960+ typedef decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 128 , 128 , 4 >) flash_attn_ext_vec_t;
39603961
39613962template [[host_name(" kernel_flash_attn_ext_vec_f16_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 128 , 128 , 4 >;
39623963#if defined(GGML_METAL_USE_BF16)
0 commit comments