@@ -52,12 +52,11 @@ typedef half (*vec_dot_KQ_f16_t)(
 typedef float (*vec_dot_KQ_f32_t)(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds);
 
-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -93,12 +92,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     return sum;
 }
 
-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -138,12 +136,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     return sum;
 }
 
-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -186,12 +183,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     return sum;
 }
 
-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -238,12 +234,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     return sum;
 }
 
-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -272,12 +267,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     return sum;
 }
 
-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const half2 * K_h2 = (const half2 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
@@ -480,25 +474,25 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
     return x[i];
 }
 
-template <int D>
+template <int D, int warp_size = WARP_SIZE>
 constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
-           type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
-           type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
-           type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
-           type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
-           type_K == GGML_TYPE_F16  ? vec_dot_fattn_vec_KQ_f16<half, D> :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D, warp_size> :
+           type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D, warp_size> :
+           type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D, warp_size> :
+           type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D, warp_size> :
+           type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D, warp_size> :
+           type_K == GGML_TYPE_F16  ? vec_dot_fattn_vec_KQ_f16<half, D, warp_size> :
            nullptr;
 }
 
-template <int D>
+template <int D, int warp_size = WARP_SIZE>
 constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
-           type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
-           type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
-           type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
-           type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
-           type_K == GGML_TYPE_F16  ? vec_dot_fattn_vec_KQ_f16<float, D> :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D, warp_size> :
+           type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D, warp_size> :
+           type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D, warp_size> :
+           type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D, warp_size> :
+           type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D, warp_size> :
+           type_K == GGML_TYPE_F16  ? vec_dot_fattn_vec_KQ_f16<float, D, warp_size> :
            nullptr;
 }
 
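Side note, not part of the diff itself: with warp_size now a template parameter that defaults to WARP_SIZE, a kernel built for a non-standard warp width has to pass it explicitly when selecting the dot-product routine. A minimal device-side sketch, assuming D and type_K are compile-time template parameters of the surrounding kernel:

    // Sketch only: pick the f16-accumulator K*Q dot product for this kernel,
    // forwarding the compile-time warp width instead of querying it in the callee.
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D, warp_size>(type_K);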
@@ -681,7 +675,8 @@ static void on_no_fattn_vec_case(const int D) {
 template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-        const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
+        const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
+        const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
@@ -704,8 +699,6 @@ void launch_fattn(
 
     GGML_ASSERT(Q->ne[3] == 1);
 
-    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -805,7 +798,6 @@ void launch_fattn(
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     GGML_ASSERT(block_dim.x % warp_size == 0);
-    GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int) warp_size);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
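Side note, not part of the diff itself: the device-specific warp size that launch_fattn previously looked up internally must now be supplied by the caller (it falls back to WARP_SIZE when omitted). A host-side sketch of the assumed call pattern, reusing the ggml_cuda_info() lookup removed above; the template arguments are placeholders standing in for whatever the calling kernel-selection code already uses:

    // Sketch only: forward the device's physical warp size into launch_fattn.
    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
    launch_fattn<D, ncols1, ncols2, parallel_blocks, KQ_stride>(
        ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V, warp_size);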