@@ -52,12 +52,11 @@ typedef half (*vec_dot_KQ_f16_t)(
 typedef float (*vec_dot_KQ_f32_t)(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds);

-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -93,12 +92,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     return sum;
 }

-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -138,12 +136,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     return sum;
 }

-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -186,12 +183,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     return sum;
 }

-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -238,12 +234,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     return sum;
 }

-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -272,12 +267,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     return sum;
 }

-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const half2 * K_h2 = (const half2 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);

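The loop bodies of these helpers are elided by the diff; as a rough, hypothetical illustration of how warp_size participates in them (each lane of a warp walks the head dimension with a stride of warp_size), here is a minimal sketch with made-up names, simplified to plain f32 with no dequantization:

// Illustrative sketch only, not code from this commit: each lane accumulates a
// strided slice of the head dimension D; keeping warp_size a compile-time
// constant lets the loop fully unroll.
template <int D, int warp_size>
static __device__ __forceinline__ float lane_strided_dot(const float * K, const float * Q) {
    float sum = 0.0f;
#pragma unroll
    for (int k0 = 0; k0 < D; k0 += warp_size) {
        const int k = k0 + threadIdx.x; // one element per lane per iteration
        sum += K[k] * Q[k];
    }
    return sum; // per-lane partial sum; a warp-level reduction combines the lanes afterwards
}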
@@ -480,25 +474,25 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
     return x[i];
 }

-template <int D>
+template <int D, int warp_size = WARP_SIZE>
 constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D, warp_size> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D, warp_size> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D, warp_size> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D, warp_size> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D, warp_size> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D, warp_size> :
         nullptr;
 }

-template <int D>
+template <int D, int warp_size = WARP_SIZE>
 constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D, warp_size> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D, warp_size> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D, warp_size> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D, warp_size> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D, warp_size> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D, warp_size> :
         nullptr;
 }

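A possible call-site sketch (kernel and template names are hypothetical, not from this commit) showing how a kernel that is itself templated on the warp size can forward it to these selectors, so the matching vec-dot instantiation is resolved at compile time:

// Hypothetical usage sketch: D, type_K and warp_size are all compile-time
// constants here, so get_vec_dot_KQ_f16 collapses to a single function pointer.
template <int D, ggml_type type_K, int warp_size>
static __global__ void fattn_vec_example(
        const char * __restrict__ K, const void * __restrict__ Q_v,
        const int  * __restrict__ Q_q8, const void * __restrict__ Q_ds) {
    constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D, warp_size>(type_K);
    const half KQ = vec_dot_KQ(K, Q_v, Q_q8, Q_ds);
    // ... the rest of the flash-attention kernel (softmax, V accumulation) would follow here ...
    GGML_UNUSED(KQ);
}

Callers that do not care about the warp size can keep writing get_vec_dot_KQ_f16<D>(type_K) unchanged, since warp_size defaults to WARP_SIZE.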
@@ -681,7 +675,8 @@ static void on_no_fattn_vec_case(const int D) {
 template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
+    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
+    const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;

@@ -704,8 +699,6 @@ void launch_fattn(

     GGML_ASSERT(Q->ne[3] == 1);

-    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -805,7 +798,6 @@ void launch_fattn(
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

     GGML_ASSERT(block_dim.x % warp_size == 0);
-    GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
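For completeness, a hedged sketch of what this change looks like from the caller's side (the wrapper name is ours; the real launchers live elsewhere in the CUDA backend and are not part of this diff): the physical warp size is looked up once on the host and handed to launch_fattn through the new trailing parameter, instead of being queried inside launch_fattn itself as the removed line above did.

// Hypothetical wrapper, assuming the launch_fattn signature shown above:
// resolve the device's physical warp size and forward it explicitly.
template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
static void launch_fattn_with_device_warp_size(
        ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
        const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V) {
    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size; // 32 on NVIDIA, 32 or 64 on AMD
    launch_fattn<D, ncols1, ncols2, parallel_blocks, KQ_stride>(
        ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V, warp_size);
}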