@@ -392,7 +392,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
392392 }
393393}
394394
395- template <int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
395+ template <int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles,
396+ bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
396397static __device__ __forceinline__ void flash_attn_ext_f16_iter (
397398 const float2 * const __restrict__ Q_f2,
398399 const half2 * const __restrict__ K_h2,
@@ -922,7 +923,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
922923 }
923924
924925 // Iterate over ne11 == previous tokens:
925- for (int kb0 = kb0_start; kb0 < kb0_stop-1 ; ++kb0) {
926+ int kb0 = kb0_start;
927+ for (; kb0 < kb0_stop-1 ; ++kb0) {
926928 constexpr bool last_iter = false ;
927929 flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
928930 (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
@@ -932,7 +934,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
932934 constexpr bool last_iter = true ;
933935 flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
934936 (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
935- ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop- 1 );
937+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0 );
936938 }
937939
938940 // With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1204,6 +1206,7 @@ static __global__ void flash_attn_ext_f16(
12041206 const char * __restrict__ K,
12051207 const char * __restrict__ V,
12061208 const char * __restrict__ mask,
1209+ const int * __restrict__ KV_max,
12071210 float * __restrict__ dst,
12081211 float2 * __restrict__ dst_meta,
12091212 const float scale,
@@ -1280,7 +1283,11 @@ static __global__ void flash_attn_ext_f16(
12801283 const float slope = ncols2 == 1 ? get_alibi_slope (max_bias, head, n_head_log2, m0, m1) : 1 .0f ;
12811284
12821285 const int kb0_start_kernel = kb0_start * kb_niter;
1283- const int kb0_stop_kernel = kb0_stop * kb_niter;
1286+ int kb0_stop_kernel = kb0_stop * kb_niter;
1287+
1288+ if (KV_max) {
1289+ kb0_stop_kernel = min (kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
1290+ }
12841291
12851292 constexpr bool is_fixup = false ; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
12861293 if (kb0_start == 0 ) {
@@ -1321,7 +1328,11 @@ static __global__ void flash_attn_ext_f16(
13211328 const float slope = ncols2 == 1 ? get_alibi_slope (max_bias, head, n_head_log2, m0, m1) : 1 .0f ;
13221329
13231330 const int kb0_start_kernel = kb0_start * kb_niter;
1324- const int kb0_stop_kernel = kb0_stop * kb_niter;
1331+ int kb0_stop_kernel = kb0_stop * kb_niter;
1332+
1333+ if (KV_max) {
1334+ kb0_stop_kernel = min (kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
1335+ }
13251336
13261337 constexpr bool is_fixup = true ; // Last index writes its data to fixup buffer to avoid data races with other blocks.
13271338 constexpr bool needs_fixup = false ;
0 commit comments