@@ -425,6 +425,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const half2  * const __restrict__ K_h2,
         const half2  * const __restrict__ V_h2,
         const half2  * const __restrict__ mask_h2,
+        const float  * const __restrict__ sinks_f,
         float2       * const __restrict__ dstk,
         float2       * const __restrict__ dstk_fixup,
         const float scale,
@@ -584,6 +585,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
     }
 
+    // If attention sinks are used, potentially re-scale if KQ_max is small.
+    // Also add the sink as a value to KQ_rowsum, this is done after synchronization of KQ_rowsum
+    //     so it's being done unconditionally for every thread.
+    if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
+        float KQ_max_scale[cols_per_thread];
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+            static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
+            const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
+            const float sink = sinks_f[jc % ncols2];
+
+            const float KQ_max_new = fmaxf(KQ_max[col], sink);
+            const float KQ_max_diff = KQ_max[col] - KQ_max_new;
+            KQ_max_scale[col] = expf(KQ_max_diff);
+            KQ_max[col] = KQ_max_new;
+
+            *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
+
+            const float KQ_max_add = expf(sink - KQ_max_new);
+            KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
+        }
+
+        if (ntiles == 1) {
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+#pragma unroll
+            for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
+#pragma unroll
+                for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+                    VKQ_C[i].x[l] *= KQ_max_scale_h2;
+                }
+            }
+        } else {
+#pragma unroll
+            for (int col = 0; col < cols_per_thread; ++col) {
+                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+                for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
+#pragma unroll
+                    for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
+                        VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+                    }
+                }
+            }
+        }
+    }
+
     // Write VKQ accumulators to shared memory in column-major format.
     // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
     // Also for np > 1 the combination is done via these values in shared memory.
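
For reference, the added block folds the per-head attention sink into the running online softmax: the sink acts like one extra logit column that contributes to the softmax denominator but has no associated V rows, so the running maximum is raised if the sink exceeds it, the existing KQ_rowsum and VKQ accumulators are rescaled by exp(old_max - new_max), and exp(sink - new_max) is added to the row sum. The bitwise uint32_t multiply flushes the scale to zero without a branch once KQ_max_diff falls below SOFTMAX_FTZ_THRESHOLD. Below is a minimal scalar sketch of that update, with illustrative variable names (m, s, vkq), not the kernel's tile layout:

```cpp
// Scalar sketch of folding an attention sink into a running online softmax.
#include <cmath>
#include <cstdio>

int main() {
    // Illustrative running state after some KQ columns have been processed.
    float m   = 2.0f; // running maximum of the logits (KQ_max)
    float s   = 5.0f; // running sum of exp(logit - m)  (KQ_rowsum)
    float vkq = 3.0f; // running weighted V accumulator for one output element

    const float sink = 4.0f; // per-head sink logit, acts like one extra column

    // Same update as in the kernel: raise the maximum if needed, rescale the
    // old sum and accumulator, and add the sink's own softmax contribution.
    const float m_new = fmaxf(m, sink);
    const float scale = expf(m - m_new);   // KQ_max_scale
    s   = scale*s + expf(sink - m_new);    // KQ_rowsum gains exp(sink - m_new)
    vkq = scale*vkq;                       // VKQ is only rescaled: the sink has no V
    m   = m_new;

    printf("m=%g s=%g vkq=%g\n", m, s, vkq);
    return 0;
}
```
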
@@ -889,15 +936,21 @@ static __global__ void flash_attn_mma_ext_f16(
     int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
     while (kbc < kbc_stop && kb0_stop == iter_k) {
         const int channel = kbc / (iter_k*iter_j);
-        const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+        const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+
+        const int head0 = zt * ncols2;
 
-        const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
-        const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
-        const half2  * V_h2    = (const half2  *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
+        const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02* head0);
+        const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
-        const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2  *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
-        float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * D/2);
+        const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
+            (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+        float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
+
+        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+        const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
 
         const int kb0_start_kernel = kb0_start * kb_niter;
         const int kb0_stop_kernel  = kb0_stop  * kb_niter;
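
The per-tile indexing above peels the flat work-item counter kbc apart: whole sequences first, then groups of ncols2 heads (zt, giving head0 = zt*ncols2), then the j tile (jt), with the K chunk as the innermost index. A small host-side sketch of that arithmetic, using made-up sizes and deriving sequence by the division that the surrounding code implies (an assumption, since that line is outside this hunk):

```cpp
// Sketch of the flat work-item index decomposition used in the kernel (illustration only).
#include <cstdio>

int main() {
    // Assumed values; in the kernel these come from tensor shapes and the launch config.
    const int iter_k = 4;   // K-direction chunks per (sequence, head group, j tile)
    const int iter_j = 3;   // j tiles per head group
    const int ne02   = 8;   // number of Q heads
    const int ncols2 = 2;   // heads processed together per tile

    for (int kbc = 0; kbc < 2*iter_k*iter_j*(ne02/ncols2); ++kbc) {
        // Assumption: sequence is the outermost (largest-stride) component.
        const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
        // Same arithmetic as in the kernel: head group index, then j tile index.
        const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k;
        const int head0 = zt*ncols2;        // first head covered by this tile
        const int k_chunk = kbc % iter_k;   // remaining K-direction index within the tile

        printf("kbc=%3d -> sequence=%d head0=%d jt=%d k_chunk=%d\n",
               kbc, sequence, head0, jt, k_chunk);
    }
    return 0;
}
```
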
@@ -906,12 +959,12 @@ static __global__ void flash_attn_mma_ext_f16(
         if (kb0_start == 0) {
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
             flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
             flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         }
 
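
This dispatch is the tile-combination logic that sinks_f is threaded through: a block that reaches the end of a tile's K range writes the tile's output itself, either for the entire tile (needs_fixup == false when kb0_start == 0) or to be combined with fixup data from the block that produced the rest (needs_fixup == true), while a trailing partial tile is written to the fixup buffer instead (is_fixup == true, in the last hunk below). A rough host-side illustration of that split, assuming kb0_start is the offset of kbc within its tile as the surrounding code suggests, and using made-up values for iter_k and the block's [kbc, kbc_stop) range:

```cpp
// Illustration (not kernel code) of how one block's contiguous work-item range
// [kbc, kbc_stop) maps onto tiles of iter_k chunks, and which path each tile takes.
#include <algorithm>
#include <cstdio>

int main() {
    const int iter_k   = 8;   // assumed: K chunks per tile
    int       kbc      = 10;  // assumed: first work item owned by this block
    const int kbc_stop = 29;  // assumed: one past the last work item owned by this block

    int kb0_start = kbc % iter_k;
    int kb0_stop  = std::min(iter_k, kb0_start + kbc_stop - kbc);
    while (kbc < kbc_stop && kb0_stop == iter_k) {
        // This block reaches the end of the tile, so it writes the tile's output.
        printf("tile at kbc=%2d: kb0=[%d,%d) %s\n", kbc - kb0_start, kb0_start, kb0_stop,
               kb0_start == 0 ? "entire tile (needs_fixup=false)" : "partial tile (needs_fixup=true)");
        kbc      += iter_k - kb0_start; // advance to the next tile boundary
        kb0_start = 0;
        kb0_stop  = std::min(iter_k, kbc_stop - kbc);
    }
    if (kbc < kbc_stop) {
        // Only part of the last tile: written to the fixup buffer (is_fixup=true)
        // so another block can combine it without data races.
        printf("tile at kbc=%2d: kb0=[%d,%d) partial tile (is_fixup=true)\n",
               kbc - kb0_start, kb0_start, kb0_stop);
    }
    return 0;
}
```
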
@@ -927,23 +980,29 @@ static __global__ void flash_attn_mma_ext_f16(
     }
 
     const int channel = kbc / (iter_k*iter_j);
-    const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+    const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
 
-    const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
-    const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
-    const half2  * V_h2    = (const half2  *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
+    const int head0 = zt * ncols2;
+
+    const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02* head0);
+    const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
-    const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2  *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
-    float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * D/2);
+    const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
+        (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+    float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
+
+    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+    const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
 
     const int kb0_start_kernel = kb0_start * kb_niter;
     const int kb0_stop_kernel  = kb0_stop  * kb_niter;
 
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
     flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
-        (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+        (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
          ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);