@@ -785,6 +785,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
785785 const half2 * const __restrict__ K_h2,
786786 const half2 * const __restrict__ V_h2,
787787 const half2 * const __restrict__ mask_h2,
788+ const float * const __restrict__ sinks_f,
788789 float2 * const __restrict__ dstk,
789790 float2 * const __restrict__ dstk_fixup,
790791 const float scale,
@@ -957,6 +958,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
957958 }
958959 }
959960
961+ // If attention sinks are used, potentially re-scale if KQ_max is small.
 962+ // Also add the sink as a value to KQ_rowsum, this is done after synchronization of KQ_rowsum
963+ // so it's being done unconditionally for every thread.
964+ if (!is_fixup && (np == 1 || threadIdx .y % np == 0 ) && sinks_f) {
965+ float KQ_max_scale[cols_per_thread];
966+ #pragma unroll
967+ for (int col = 0 ; col < cols_per_thread; ++col) {
968+ static_assert (ntiles == 1 || ntiles == 2 , " ntiles > 2 not implemented" );
969+ const int jc = ntiles == 1 ? 2 *tile_C_VKQ::get_j (col/2 ) + col % 2 : tile_C_VKQ_16::get_i (col);
970+ const float sink = sinks_f[jc % ncols2];
971+
972+ const float KQ_max_new = fmaxf (KQ_max[col], sink);
973+ const float KQ_max_diff = KQ_max[col] - KQ_max_new;
974+ KQ_max_scale[col] = expf (KQ_max_diff);
975+ KQ_max[col] = KQ_max_new;
976+
977+ *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
978+
979+ const float KQ_max_add = expf (sink - KQ_max_new);
980+ KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
981+ }
982+
983+ if (ntiles == 1 ) {
984+ const half2 KQ_max_scale_h2 = make_half2 (KQ_max_scale[0 ], KQ_max_scale[1 ]);
985+ #pragma unroll
986+ for (int i = 0 ; i < DV/tile_C_VKQ::I; ++i) {
987+ #pragma unroll
988+ for (int l = 0 ; l < tile_C_VKQ::ne; ++l) {
989+ VKQ_C[i].x [l] *= KQ_max_scale_h2;
990+ }
991+ }
992+ } else {
993+ #pragma unroll
994+ for (int col = 0 ; col < cols_per_thread; ++col) {
995+ const half2 KQ_max_scale_h2 = make_half2 (KQ_max_scale[col], KQ_max_scale[col]);
996+ #pragma unroll
997+ for (int i = 0 ; i < DV/tile_C_VKQ_16::J; ++i) {
998+ #pragma unroll
999+ for (int l0 = 0 ; l0 < tile_C_VKQ_16::ne; l0 += 2 ) {
1000+ VKQ_C_16[i*ntiles/2 + col/2 ].x [l0 + col % 2 ] *= KQ_max_scale_h2;
1001+ }
1002+ }
1003+ }
1004+ }
1005+ }
1006+
9601007 // Combine VKQ accumulator values if np > 1.
9611008 // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
9621009 // So also write VKQ accumulators to shared memory in column-major format if np == 1.
@@ -1271,18 +1318,21 @@ static __global__ void flash_attn_ext_f16(
12711318
12721319 while (kbc < kbc_stop && kb0_stop == iter_k) {
12731320 const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1274- const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
1275- const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head ) / iter_k; // j index of current tile.
1321+ const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
1322+ const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt ) / iter_k; // j index of current tile.
12761323
1277- const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
1278- const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
1324+ const int head0 = zt * ncols2;
1325+
1326+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
1327+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
12791328 const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
12801329 (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
1281- float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2 ) * (DV/2 );
1330+ float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0 ) * (DV/2 );
12821331
1283- const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2 ) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
1332+ const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2 ) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
1333+ const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr ;
12841334
1285- const float slope = ncols2 == 1 ? get_alibi_slope (max_bias, head , n_head_log2, m0, m1) : 1 .0f ;
1335+ const float slope = ncols2 == 1 ? get_alibi_slope (max_bias, head0 , n_head_log2, m0, m1) : 1 .0f ;
12861336
12871337 const int kb0_start_kernel = kb0_start * kb_niter;
12881338 int kb0_stop_kernel = kb0_stop * kb_niter;
@@ -1295,12 +1345,12 @@ static __global__ void flash_attn_ext_f16(
12951345 if (kb0_start == 0 ) {
12961346 constexpr bool needs_fixup = false ; // CUDA block is working on an entire tile.
12971347 flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
1298- (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
1348+ (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
12991349 ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
13001350 } else {
13011351 constexpr bool needs_fixup = true ; // CUDA block is working on the beginning of a tile.
13021352 flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
1303- (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
1353+ (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
13041354 ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
13051355 }
13061356
@@ -1316,18 +1366,21 @@ static __global__ void flash_attn_ext_f16(
13161366 }
13171367
13181368 const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1319- const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
1320- const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
1369+ const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
1370+ const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
1371+
1372+ const int head0 = zt * ncols2;
13211373
1322- const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2) );
1323- const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
1374+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0 );
1375+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
13241376 const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
13251377 (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
1326- float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2 ) * (DV/2 );
1378+ float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0 ) * (DV/2 );
13271379
1328- const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2 ) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
1380+ const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2 ) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
1381+ const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr ;
13291382
1330- const float slope = ncols2 == 1 ? get_alibi_slope (max_bias, head , n_head_log2, m0, m1) : 1 .0f ;
1383+ const float slope = ncols2 == 1 ? get_alibi_slope (max_bias, head0 , n_head_log2, m0, m1) : 1 .0f ;
13311384
13321385 const int kb0_start_kernel = kb0_start * kb_niter;
13331386 int kb0_stop_kernel = kb0_stop * kb_niter;
@@ -1339,7 +1392,7 @@ static __global__ void flash_attn_ext_f16(
13391392 constexpr bool is_fixup = true ; // Last index writes its data to fixup buffer to avoid data races with other blocks.
13401393 constexpr bool needs_fixup = false ;
13411394 flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
1342- (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
1395+ (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
13431396 ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
13441397#else
13451398 GGML_UNUSED (Q); GGML_UNUSED (K); GGML_UNUSED (V); GGML_UNUSED (mask); GGML_UNUSED (sinks);
0 commit comments