@@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
521523template <int  D, int  ncols1, int  ncols2> //  D == head size
522524__launch_bounds__ (D, 1 )
523525static __global__ void flash_attn_stream_k_fixup(
524-         float  * __restrict__  dst, const  float2  * __restrict__  dst_fixup, const  int  ne01, const  int  ne02, const  int  ne11) {
526+         float  * __restrict__  dst, const  float2  * __restrict__  dst_fixup, const  int  ne01, const  int  ne02, const  int  ne03,  const   int   ne11) {
525527    constexpr  int  ncols = ncols1*ncols2;
526528
527529    const  int  bidx0 = blockIdx .x ;
@@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup(
     const int iter_k = ne11 / FATTN_KQ_STRIDE;
     const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0      = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
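The new kbc0/kbc0_stop bounds fold the sequence dimension ne03 into the flattened stream-K work range, so each block still receives a contiguous, near-equal slice of all iter_k*iter_j*(ne02/ncols2)*ne03 work units. A minimal host-side sketch of that split, not part of the patch and with all sizes invented for illustration (grid_x stands in for gridDim.x):

    #include <stdio.h>

    int main(void) {
        const int iter_k = 4, iter_j = 3, ne02 = 8, ncols2 = 2, ne03 = 2;
        const int grid_x = 5;
        const int total  = iter_k*iter_j*(ne02/ncols2)*ne03; // 96 work units

        for (int bidx = 0; bidx < grid_x; ++bidx) {
            const int kbc      = (bidx + 0)*total / grid_x;
            const int kbc_stop = (bidx + 1)*total / grid_x;
            // Slices are contiguous, non-overlapping, and jointly cover [0, total).
            printf("block %d gets [%d, %d)\n", bidx, kbc, kbc_stop);
        }
        return 0;
    }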
@@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup(
         return;
     }
 
-    const int channel = kbc0 / (iter_k*iter_j);
-    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+    const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
+    const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
 
     if (jt*ncols1 + j >= ne01) {
         return;
     }
 
-    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
 
     // Load the partial result that needs a fixup:
     float dst_val = 0.0f;
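The sequence/head/jt computation above is a mixed-radix decode of kbc0 with digit sizes ne03, ne02/ncols2, iter_j, and iter_k. A small standalone check, not part of the patch, that recomposing the digits reproduces the flat index (sizes invented):

    #include <assert.h>

    int main(void) {
        const int iter_k = 4, iter_j = 3, ne02 = 8, ncols2 = 2, ne03 = 2;
        const int per_head = iter_k*iter_j;
        const int per_seq  = per_head*(ne02/ncols2);

        for (int kbc0 = 0; kbc0 < per_seq*ne03; ++kbc0) {
            const int sequence = kbc0 / per_seq;
            const int head     = (kbc0 - per_seq*sequence) / per_head;
            const int jt       = (kbc0 - per_seq*sequence - per_head*head) / iter_k;
            const int k        = kbc0 % iter_k;
            // Recomposing the digits reproduces the flat index.
            assert(sequence*per_seq + head*per_head + jt*iter_k + k == kbc0);
        }
        return 0;
    }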
@@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while (true) {
-        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+        const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
@@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results(
         const float2 * __restrict__ VKQ_meta,
         float * __restrict__ dst,
         const int parallel_blocks) {
-    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
-    VKQ_meta  += parallel_blocks   * gridDim.z*blockIdx.x;
-    dst       +=                 D * gridDim.z*blockIdx.x;
+    // Dimension 0: threadIdx.x
+    // Dimension 1: blockIdx.x
+    // Dimension 2: blockIdx.y
+    // Dimension 3: blockIdx.z
+    // Memory layout is permuted with [0, 2, 1, 3]
+
+    const int ne01 = gridDim.x;
+    const int ne02 = gridDim.y;
+
+    const int col      = blockIdx.x;
+    const int head     = blockIdx.y;
+    const int sequence = blockIdx.z;
+
+    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
+
+    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
+    VKQ_meta  += j_dst_unrolled * parallel_blocks;
+    dst       += j_dst_unrolled *                 D;
 
     const int tid = threadIdx.x;
     __builtin_assume(tid < D);
 
     extern __shared__ float2 meta[];
     for (int i = tid; i < 2*parallel_blocks; i += D) {
-        ((float *) meta)[i] = ((const float *) VKQ_meta)[blockIdx.z*(2*parallel_blocks) + i];
+        ((float *) meta)[i] = ((const float *) VKQ_meta)[i];
    }
 
     __syncthreads();
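j_dst_unrolled flattens (sequence, col, head) with head moving fastest, col next, and sequence slowest, which is what the [0, 2, 1, 3] permutation comment describes once threadIdx.x is taken as the innermost dimension. A standalone sketch of the resulting strides, not part of the patch (sizes invented):

    #include <assert.h>

    static int j_dst_unrolled(int sequence, int col, int head, int ne01, int ne02) {
        return (sequence*ne01 + col)*ne02 + head;
    }

    int main(void) {
        const int ne01 = 3, ne02 = 4;
        const int base = j_dst_unrolled(0, 0, 0, ne01, ne02);
        assert(j_dst_unrolled(0, 0, 1, ne01, ne02) - base == 1);         // head: stride 1
        assert(j_dst_unrolled(0, 1, 0, ne01, ne02) - base == ne02);      // col: stride ne02
        assert(j_dst_unrolled(1, 0, 0, ne01, ne02) - base == ne01*ne02); // sequence: stride ne01*ne02
        return 0;
    }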
@@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results(
         const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint32_t *) &KQ_max_scale) &= ftz_mask;
 
-        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
     }
 
-    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
+    dst[tid] = VKQ_numerator / VKQ_denominator;
 }
 
 [[noreturn]]
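For context, the loop above performs the usual log-sum-exp merge of partial softmax results: each slot l carries a running max in meta[l].x, a denominator in meta[l].y, and a partial numerator in VKQ_parts, and with the pointers now pre-offset by j_dst_unrolled the per-row indexing collapses to l*D + tid. A scalar reference of that merge, not part of the patch and with hypothetical names:

    #include <math.h>

    // Hypothetical scalar analogue of the kernel's per-element merge.
    static float combine(const float *m, const float *d, const float *n, int parts) {
        float m_max = -INFINITY;
        for (int l = 0; l < parts; ++l) m_max = fmaxf(m_max, m[l]);

        float num = 0.0f, den = 0.0f;
        for (int l = 0; l < parts; ++l) {
            const float scale = expf(m[l] - m_max); // KQ_max_scale in the kernel
            num += scale * n[l];
            den += scale * d[l];
        }
        return num / den;
    }

    int main(void) {
        const float m[2] = {0.0f, 1.0f}, d[2] = {2.0f, 3.0f}, n[2] = {4.0f, 5.0f};
        return combine(m, d, n, 2) > 0.0f ? 0 : 1;
    }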
@@ -705,8 +723,6 @@ void launch_fattn(
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
-    GGML_ASSERT(Q->ne[3] == 1);
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -853,8 +869,8 @@ void launch_fattn(
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
-        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
         nb11, nb12, nb13,
         nb21, nb22, nb23,
@@ -869,11 +885,11 @@ void launch_fattn(
 
             flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);
-        const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
+        const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
         flash_attn_combine_results<DV>