@@ -3128,14 +3128,15 @@ kernel void kernel_flash_attn_ext(
31283128 const int iq2 = tgpig[1 ];
31293129 const int iq1 = tgpig[0 ]*Q;
31303130
3131- const short DK4 = DK/4 ;
3132- const short DK8 = DK/8 ;
3133- const short DK16 = DK/16 ;
3134- const short DV4 = DV/4 ;
3135- const short DV8 = DV/8 ;
3136- const short DV16 = DV/16 ;
3137- const short NW = N_SIMDWIDTH;
3138- const short SH = (2 *C + Q); // shared memory per simdgroup (s_t == float)
3131+ constexpr short DK4 = DK/4 ;
3132+ constexpr short DK8 = DK/8 ;
3133+ constexpr short DK16 = DK/16 ;
3134+ constexpr short DV4 = DV/4 ;
3135+ constexpr short DV8 = DV/8 ;
3136+ constexpr short DV16 = DV/16 ;
3137+
3138+ constexpr short NW = N_SIMDWIDTH;
3139+ constexpr short SH = (2 *C + Q); // shared memory per simdgroup (s_t == float)
31393140
31403141 const short TS = nsg*SH; // shared memory size per query in (s_t == float)
31413142 const short T = DK + 2 *TS; // shared memory size per query in (half)
@@ -3641,11 +3642,11 @@ kernel void kernel_flash_attn_ext_vec(
36413642 const int iq2 = tgpig[1 ];
36423643 const int iq1 = tgpig[0 ];
36433644
3644- const short DK4 = DK/4 ;
3645- const short DV4 = DV/4 ;
3646- const short NW = N_SIMDWIDTH;
3647- const short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
3648- const short SH = 2 *C; // shared memory per simdgroup
3645+ constexpr short DK4 = DK/4 ;
3646+ constexpr short DV4 = DV/4 ;
3647+ constexpr short NW = N_SIMDWIDTH;
3648+ constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
3649+ constexpr short SH = 2 *C; // shared memory per simdgroup
36493650
36503651 const short T = DK + nsg*SH; // shared memory size per query in (half)
36513652
0 commit comments