@@ -408,7 +408,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const int stride_K,
         const int stride_V,
         const int stride_mask,
-        const int jt,
         half2 * const __restrict__ tile_Q,
         half2 * const __restrict__ tile_K,
         half2 * const __restrict__ tile_V,
@@ -455,7 +454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         cp_async_wait_all();
         __syncthreads();
         flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
-            (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
+            (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
     } else {
         constexpr bool use_cp_async = nstages == 1;
         if (ncols2 > 1 || mask_h2) {
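
Why the casts: `k_VKQ_0` and the strides are 32-bit `int`s, so a product like `k_VKQ_0*stride_V` is evaluated in 32-bit arithmetic and can wrap for long contexts before it is ever widened for the pointer addition. Casting one operand to `int64_t` promotes the whole product. A minimal standalone sketch of the failure mode (the magnitudes are hypothetical, not taken from the kernel):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical magnitudes: a deep KV-cache position times a large
    // per-token element stride easily exceeds INT32_MAX (~2.1e9).
    const int k_VKQ_0  = 70000;   // position in the KV cache
    const int stride_V = 40000;   // elements between consecutive V rows

    // int * int is evaluated in 32 bits (signed overflow is UB; the unsigned
    // multiply here reproduces the typical wraparound: a negative offset,
    // hence an out-of-bounds pointer).
    const int32_t wrapped = (int32_t)((uint32_t)k_VKQ_0 * (uint32_t)stride_V);

    // Casting one operand first promotes the multiplication to 64 bits,
    // as in the patched V_h2 + int64_t(k_VKQ_0)*stride_V.
    const int64_t correct = int64_t(k_VKQ_0)*stride_V;

    printf("32-bit product: %d\n64-bit product: %lld\n", wrapped, (long long) correct);
    return 0;
}
```
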
@@ -471,7 +470,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     if (nstages <= 1) {
         constexpr bool use_cp_async = nstages == 1;
         flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-            (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
+            (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
         if (use_cp_async) {
             cp_async_wait_all();
         }
@@ -715,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
         }
         flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-            (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
+            (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
         }
     }

@@ -732,7 +731,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     if (nstages <= 1 && i0_start < reusable_cutoff) {
         constexpr bool use_cp_async = nstages == 1;
         flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
-            (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
+            (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
         if (use_cp_async) {
             cp_async_wait_all();
         }
@@ -771,8 +770,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
     GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
-    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
-    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+    GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K);
     GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
     GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
     GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
@@ -920,21 +918,21 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                 (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
         }
         flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-            (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
+            (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
     }

     // Iterate over ne11 == previous tokens:
     for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
         constexpr bool last_iter = false;
         flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     }
     { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
         constexpr bool last_iter = true;
         flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
     }

     // With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1214,33 +1212,13 @@ static __global__ void flash_attn_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)

     // Skip unused kernel variants for faster compilation:
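
A note on the widened signature: only the outermost byte strides (`nb13`, `nb23`, `nb33`) become `int64_t`, since for very large K/V/mask tensors the offset of the last sequence can exceed `INT32_MAX` while the inner strides still fit in 32 bits. When composing an offset, each 32-bit product still needs its own promotion; an illustrative helper (hypothetical, not part of the patch):

```cpp
#include <cstdint>

// Offset of element (i1, i2, i3) given mixed-width strides. The int64_t
// casts on the 32-bit products keep every term in 64-bit arithmetic;
// i3*nb33 is already safe because nb33 itself is int64_t.
static inline int64_t offset_3d(int i1, int i2, int i3,
                                int32_t nb31, int32_t nb32, int64_t nb33) {
    return int64_t(i1)*nb31 + int64_t(i2)*nb32 + i3*nb33;
}
```
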
@@ -1359,8 +1337,7 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
-    GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb22); GGML_UNUSED(nb23);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 }