@@ -785,6 +785,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const half2  * const __restrict__ K_h2,
         const half2  * const __restrict__ V_h2,
         const half2  * const __restrict__ mask_h2,
+        const float  * const __restrict__ sinks_f,
         float2       * const __restrict__ dstk,
         float2       * const __restrict__ dstk_fixup,
         const float scale,
@@ -957,6 +958,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
     }

+    // If attention sinks are used, potentially re-scale if KQ_max is small.
+    // Also add the sink as a value to KQ_rowsum; this happens after the synchronization of KQ_rowsum,
+    //     so it is done unconditionally for every thread.
+    if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
+        float KQ_max_scale[cols_per_thread];
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+            static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
+            const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
+            const float sink = sinks_f[jc % ncols2];
+
+            const float KQ_max_new = fmaxf(KQ_max[col], sink);
+            const float KQ_max_diff = KQ_max[col] - KQ_max_new;
+            KQ_max_scale[col] = expf(KQ_max_diff);
+            KQ_max[col] = KQ_max_new;
+
+            *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
+
+            const float KQ_max_add = expf(sink - KQ_max_new);
+            KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
+        }
+
+        if (ntiles == 1) {
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+#pragma unroll
+            for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
+#pragma unroll
+                for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+                    VKQ_C[i].x[l] *= KQ_max_scale_h2;
+                }
+            }
+        } else {
+#pragma unroll
+            for (int col = 0; col < cols_per_thread; ++col) {
+                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+                for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
+#pragma unroll
+                    for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
+                        VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+                    }
+                }
+            }
+        }
+    }
+
     // Combine VKQ accumulator values if np > 1.
     // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
     // So also write VKQ accumulators to shared memory in column-major format if np == 1.
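The hunk above is the core of the change. An attention sink is an extra per-head logit that enters the softmax denominator but has no value vector attached, so folding it in after the K/V loop is a standard online-softmax update: raise the running maximum if the sink exceeds it, rescale the existing rowsum and VKQ accumulators by exp(old_max - new_max), and add exp(sink - new_max) to the rowsum. The uint32_t line flushes the scale factor to zero when the exponent falls below SOFTMAX_FTZ_THRESHOLD by multiplying the float's bit pattern by 0 or 1. A minimal scalar sketch of the same update, with hypothetical names (SoftmaxState, add_sink) invented for illustration:

    #include <math.h>

    struct SoftmaxState {
        float max;    // running maximum of the logits seen so far
        float rowsum; // sum of exp(logit - max) over those logits
        float acc;    // one lane of the running weighted V sum (VKQ accumulator)
    };

    // Fold a sink logit into the state: it joins the softmax denominator but
    // contributes no value, so it only rescales what was already accumulated.
    static void add_sink(struct SoftmaxState *s, float sink) {
        const float max_new = fmaxf(s->max, sink);
        const float scale   = expf(s->max - max_new); // <= 1; rescales old terms
        s->rowsum = scale*s->rowsum + expf(sink - max_new);
        s->acc   *= scale; // the sink adds nothing to the numerator
        s->max    = max_new;
    }

In the kernel this state lives per softmax row in registers; the ntiles == 1 and ntiles == 2 branches differ only in how the per-column scale is broadcast into the half2 VKQ accumulator tiles.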
@@ -1271,18 +1318,21 @@ static __global__ void flash_attn_ext_f16(

     while (kbc < kbc_stop && kb0_stop == iter_k) {
         const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-        const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
-        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
+        const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.

-        const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
-        const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
+        const int head0 = zt * ncols2;
+
+        const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*head0);
+        const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
         const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
             (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-        float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
+        float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);

-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
+        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+        const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;

-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;

         const int kb0_start_kernel = kb0_start * kb_niter;
         int       kb0_stop_kernel  = kb0_stop  * kb_niter;
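For orientation, the rename from head to zt clarifies the decomposition of the flat work index kbc: each sequence owns ne02/ncols2 head groups (each covering ncols2 consecutive heads), each group owns iter_j j-tiles, and each j-tile owns iter_k k-tiles; head0 = zt*ncols2 recovers the first real head of the group, which is also the offset of that group's sink values. A hedged host-side sketch of the decomposition (TilePos and decompose_kbc are invented for illustration; the arithmetic mirrors the hunk above):

    struct TilePos { int sequence, zt, jt, k0, head0; };

    // Invert kbc = ((sequence*(ne02/ncols2) + zt)*iter_j + jt)*iter_k + k0.
    static TilePos decompose_kbc(int kbc, int iter_k, int iter_j, int ne02, int ncols2) {
        TilePos p;
        p.sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
        kbc       -= p.sequence * iter_k*iter_j*(ne02/ncols2);
        p.zt       = kbc / (iter_k*iter_j); // head in units of ncols2
        kbc       -= p.zt * iter_k*iter_j;
        p.jt       = kbc / iter_k;          // j index of the current tile
        p.k0       = kbc - p.jt*iter_k;     // k tile within that (zt, jt) slice
        p.head0    = p.zt * ncols2;         // first actual head of the group
        return p;
    }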
@@ -1295,12 +1345,12 @@ static __global__ void flash_attn_ext_f16(
         if (kb0_start == 0) {
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         }
@@ -1316,18 +1366,21 @@ static __global__ void flash_attn_ext_f16(
     }

     const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-    const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
-    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
+    const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+
+    const int head0 = zt * ncols2;

-    const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
-    const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
+    const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*head0);
+    const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
     const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
         (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-    float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
+    float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);

-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+    const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;

-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;

     const int kb0_start_kernel = kb0_start * kb_niter;
     int       kb0_stop_kernel  = kb0_stop  * kb_niter;
@@ -1339,7 +1392,7 @@ static __global__ void flash_attn_ext_f16(
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
     flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-        (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+        (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
         ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);