@@ -57,35 +57,36 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib    = k_KQ /  QI8_1;
         const int iqs4  = k_KQ %  QI4_0;
         const int shift = k_KQ & (QI8_1/2);
 
         const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
-            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size];
             sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
         }
     }
 
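For context, `ggml_cuda_get_physical_warp_size()` is a compile-time helper, which is what lets these kernels keep `warp_size` as a `constexpr` and leave `#pragma unroll` effective. A minimal sketch of its likely shape (the real definition lives in ggml-cuda/common.cuh; the exact preprocessor guards here are an assumption):

```cpp
// Sketch, not the verbatim ggml definition: the physical wavefront is
// 64 lanes on AMD GCN/CDNA under HIP and 32 lanes on NVIDIA (and on
// RDNA in wave32 mode), and the value is known at compile time.
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && (defined(__GFX8__) || defined(__GFX9__))
    return 64;   // assumed guard for GCN; CDNA may be gated differently
#else
    return 32;
#endif
}
```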
@@ -97,37 +98,38 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib    = k_KQ /  QI8_1;
         const int iqs4  = k_KQ %  QI4_1;
         const int shift = k_KQ & (QI8_1/2);
 
         const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
-            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size];
             const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
             sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            const float sumid4d8   =  __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
-            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+            const float sumid4d8   =  __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
 
             sum += (T) (sumid4d8 + m4s8scaled);
         }
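A note on the `half2` trick in this q4_1 hunk (the q5_1 hunk below uses the identical pattern): `dm` packs the block scale $d_4$ and minimum $m_4$ into one `half2`, so a single packed multiply against Q's per-block pair $(d_8, s_8)$ yields both terms of the dot product at once. Written out, matching the float fallback path:

$$\mathrm{sum} \mathrel{+}= d_4 d_8 \cdot \mathrm{sumi} + \frac{m_4 s_8}{QI8\_1}$$

where $d_8$ is `Q_ds.x` and $s_8$ is `Q_ds.y`.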
@@ -141,12 +143,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib    = k_KQ /  QI8_1;
@@ -161,22 +164,22 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         v |= (vh << 18) & 0x00100000; // 2 -> 20
         v |= (vh << 25) & 0x10000000; // 3 -> 28
 
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
-            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size];
             sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
         }
     }
 
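An aside on the shift-and-mask cascade visible in the q5_0 (and q5_1) context lines: `v` packs four 4-bit quants into bits 0-3, 8-11, 16-19, and 24-27 of one `int`, and the matching fifth bits have to be scattered from consecutive low bits of `vh` into bit 4 of each byte. A self-contained illustration (hypothetical helper name; shifts and masks as in the kernel):

```cpp
// Hypothetical standalone helper mirroring the bit scatter in the
// q5_0/q5_1 dot products: bit i of vh (i = 0..3) is moved to bit
// position 8*i + 4 of v, i.e. the 5th bit of each packed byte.
static __device__ __forceinline__ int scatter_fifth_bits(int v, int vh) {
    v |= (vh <<  4) & 0x00000010; // 0 ->  4
    v |= (vh << 11) & 0x00001000; // 1 -> 12
    v |= (vh << 18) & 0x00100000; // 2 -> 20
    v |= (vh << 25) & 0x10000000; // 3 -> 28
    return v;
}
```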
@@ -188,12 +191,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib    = k_KQ /  QI8_1;
@@ -208,24 +212,24 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         v |= (vh << 18) & 0x00100000; // 2 -> 20
         v |= (vh << 25) & 0x10000000; // 3 -> 28
 
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
-            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size];
             const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
             sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            const float sumid5d8   =  __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
-            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+            const float sumid5d8   =  __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
 
             sum += (T) (sumid5d8 + m5s8scaled);
         }
@@ -239,12 +243,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
    T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib  = k_KQ / QI8_0;
@@ -255,13 +260,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
         T Q_d;
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
-            Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]);
+            Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
         } else {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
-            Q_d = Q_ds[k_KQ_0/WARP_SIZE].x;
+            Q_d = Q_ds[k_KQ_0/warp_size].x;
         }
 
-        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d);
+        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
     }
 
     return sum;
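The q8_0 path is simpler because q8_0 blocks carry only a scale and no minimum, so the per-iteration contribution delegated to `vec_dot_q8_0_q8_1_impl` reduces to:

$$\mathrm{sum} \mathrel{+}= d_{8,K} \cdot d_Q \cdot \mathrm{dp4a}(v, u)$$

with $d_{8,K}$ being `K_q8_0[ib].d` and $d_Q$ the low half (or `.x`) of `Q_ds`, exactly as selected in the branch above.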
@@ -272,6 +277,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const half2 * K_h2 = (const half2 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
 
@@ -282,11 +288,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
         half2 sum2 = make_half2(0.0f, 0.0f);
 
 #pragma unroll
-        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
             const int k_KQ = k_KQ_0 + threadIdx.x;
 
             const half2 K_ik = K_h2[k_KQ];
-            sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
+            sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
         }
 
         return __low2half(sum2) + __high2half(sum2);
@@ -298,12 +304,12 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     float sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const half2 K_ik = K_h2[k_KQ];
-        sum +=  __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x;
-        sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y;
+        sum +=  __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x;
+        sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
    }
 
     return sum;
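The f16 path shows the recurring accumulation pattern in its purest form: keep a packed `half2` accumulator through the loop and do a single horizontal add at the end, rather than converting per iteration. A minimal sketch of that pattern in isolation (hypothetical helper; the real kernel adds the warp-strided indexing seen above):

```cpp
#include <cuda_fp16.h>

// Hypothetical illustration of the half2 dot-product pattern: packed
// multiply-accumulate over pairs, one horizontal reduction at the end.
// Requires an fp16-capable GPU (the FP16_AVAILABLE path in the kernel).
static __device__ half dot_half2(const half2 * a, const half2 * b, int n_pairs) {
    half2 acc = make_half2(0.0f, 0.0f);
    for (int i = 0; i < n_pairs; ++i) {
        acc += a[i] * b[i];                    // two fp16 FMAs per step
    }
    return __low2half(acc) + __high2half(acc); // horizontal add
}
```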
@@ -698,6 +704,8 @@ void launch_fattn(
 
     GGML_ASSERT(Q->ne[3] == 1);
 
+    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
+
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -750,7 +758,7 @@ void launch_fattn(
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 block_dim(warp_size, nwarps, 1);
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
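Note the host/device split: on the host, the device-side `constexpr` cannot be used, since one binary may drive GPUs with different warp widths, so `launch_fattn` reads the width from ggml's cached device info instead. In raw CUDA terms this is roughly equivalent to the following sketch (`make_fattn_block_dim` is a hypothetical stand-in; `device_id` and `nwarps` come from the surrounding code):

```cpp
#include <cuda_runtime.h>

// Host-side sketch: query the warp width at runtime instead of baking
// in WARP_SIZE. ggml caches the same value in
// ggml_cuda_info().devices[...].warp_size.
static dim3 make_fattn_block_dim(int device_id, int nwarps) {
    int warp_size = 32;                        // NVIDIA default
    cudaDeviceProp prop;
    if (cudaGetDeviceProperties(&prop, device_id) == cudaSuccess) {
        warp_size = prop.warpSize;             // 64 on AMD GCN/CDNA under HIP
    }
    return dim3(warp_size, nwarps, 1);         // one warp per block row
}
```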
@@ -796,6 +804,8 @@ void launch_fattn(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    GGML_ASSERT(block_dim.x % warp_size == 0);
+    GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
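Reading the two new assertions before the launch: the first checks that the block's x dimension is a whole number of warps; the second applies only on AMD (the `!GGML_CUDA_CC_IS_AMD(cc) ||` short-circuit makes it vacuous elsewhere) and caps a block at four wavefronts, so with `warp_size = 64` that is at most 4 × 64 = 256 threads per block.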