@@ -57,35 +57,36 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib    = k_KQ / QI8_1;
         const int iqs4  = k_KQ % QI4_0;
         const int shift = k_KQ & (QI8_1/2);

         const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];

         const int sumi = ggml_cuda_dp4a(v, u, 0);

 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;

-            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size];
             sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;

-            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
         }
     }

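Note: every vec_dot_fattn_vec_KQ_* hunk in this diff makes the same two changes: a constexpr warp_size obtained from ggml_cuda_get_physical_warp_size() replaces the fixed WARP_SIZE macro as the loop stride and as the divisor for the Q_q8/Q_ds indices, so the kernels also index correctly on AMD GPUs with 64-wide wavefronts. Because the value is constexpr in device code, the #pragma unroll loops still unroll fully. For orientation only, a minimal sketch of what such a helper can look like follows; the name physical_warp_size_sketch and the reliance on the hip-clang __AMDGCN_WAVEFRONT_SIZE__ macro are assumptions, not the actual ggml implementation:

// Hypothetical sketch, not the real ggml_cuda_get_physical_warp_size().
// Assumption: hip-clang defines __AMDGCN_WAVEFRONT_SIZE__ when compiling AMD device code.
static constexpr __device__ int physical_warp_size_sketch() {
#if defined(__HIP_DEVICE_COMPILE__) && defined(__AMDGCN_WAVEFRONT_SIZE__)
    return __AMDGCN_WAVEFRONT_SIZE__; // 64 on GCN/CDNA, 32 on RDNA in wave32 mode
#else
    return 32;                        // NVIDIA warps are always 32 lanes wide
#endif
}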
@@ -97,37 +98,38 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib    = k_KQ / QI8_1;
         const int iqs4  = k_KQ % QI4_1;
         const int shift = k_KQ & (QI8_1/2);

         const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];

         const int sumi = ggml_cuda_dp4a(v, u, 0);

 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;

-            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size];
             const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
             sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;

-            const float sumid4d8   = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
-            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+            const float sumid4d8   = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;

             sum += (T) (sumid4d8 + m4s8scaled);
         }
@@ -141,12 +143,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib = k_KQ / QI8_1;
@@ -161,22 +164,22 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         v |= (vh << 18) & 0x00100000; // 2 -> 20
         v |= (vh << 25) & 0x10000000; // 3 -> 28

-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];

         const int sumi = ggml_cuda_dp4a(v, u, 0);

 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;

-            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size];
             sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;

-            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
         }
     }

@@ -188,12 +191,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib = k_KQ / QI8_1;
@@ -208,24 +212,24 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         v |= (vh << 18) & 0x00100000; // 2 -> 20
         v |= (vh << 25) & 0x10000000; // 3 -> 28

-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];

         const int sumi = ggml_cuda_dp4a(v, u, 0);

 #ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;

-            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size];
             const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
             sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;

-            const float sumid5d8   = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
-            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+            const float sumid5d8   = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;

             sum += (T) (sumid5d8 + m5s8scaled);
         }
@@ -239,12 +243,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const int ib = k_KQ / QI8_0;
@@ -255,13 +260,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
         T Q_d;
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
-            Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]);
+            Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
         } else {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
-            Q_d = Q_ds[k_KQ_0/WARP_SIZE].x;
+            Q_d = Q_ds[k_KQ_0/warp_size].x;
         }

-        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d);
+        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
     }

     return sum;
@@ -272,6 +277,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const half2 * K_h2 = (const half2 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);

@@ -282,11 +288,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
         half2 sum2 = make_half2(0.0f, 0.0f);

 #pragma unroll
-        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
             const int k_KQ = k_KQ_0 + threadIdx.x;

             const half2 K_ik = K_h2[k_KQ];
-            sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
+            sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
         }

         return __low2half(sum2) + __high2half(sum2);
@@ -298,12 +304,12 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     float sum = 0.0f;

 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;

         const half2 K_ik = K_h2[k_KQ];
-        sum += __low2float(K_ik)  * Q_f2[k_KQ_0/WARP_SIZE].x;
-        sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y;
+        sum += __low2float(K_ik)  * Q_f2[k_KQ_0/warp_size].x;
+        sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
     }

     return sum;
@@ -698,6 +704,8 @@ void launch_fattn(

     GGML_ASSERT(Q->ne[3] == 1);

+    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
+
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -750,7 +758,7 @@ void launch_fattn(
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];

-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 block_dim(warp_size, nwarps, 1);
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
@@ -796,6 +804,8 @@ void launch_fattn(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

+    GGML_ASSERT(block_dim.x % warp_size == 0);
+    GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
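On the host side, launch_fattn cannot use a constexpr value, so it reads the per-device warp size that ggml caches in ggml_cuda_info().devices[ctx.device].warp_size, uses it as the block width, and the two new asserts check that the block width is a warp multiple and that on AMD a block holds at most 4 warps. For orientation only, here is a standalone sketch of where such a per-device value can come from, using the plain CUDA runtime attribute query (illustrative; ggml fills its device info during backend initialization, not like this):

// Standalone sketch: query the physical warp/wavefront width of the current device.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int device = 0;
    int warp_size = 0;
    cudaGetDevice(&device);
    // cudaDevAttrWarpSize reports the warp width (32 on NVIDIA GPUs);
    // the HIP counterpart is hipDeviceAttributeWarpSize (32 or 64 on AMD).
    cudaDeviceGetAttribute(&warp_size, cudaDevAttrWarpSize, device);
    std::printf("device %d physical warp size: %d\n", device, warp_size);
    return 0;
}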