@@ -54,6 +54,7 @@ template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, b
5454inline __device__ void compute_attn_1rowblock (const Params &params, const int bidb, const int bidh, const int m_block) {
5555
5656 using Element = typename Kernel_traits::Element;
57+ using ElementMask = typename Kernel_traits::ElementMask;
5758 using ElementAccum = typename Kernel_traits::ElementAccum;
5859 using index_t = typename Kernel_traits::index_t ;
5960
@@ -169,7 +170,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
169170 make_coord (_, 0 )
170171 ); // (kBlockN, kHeadDim, nblocksN)
171172 Tensor mMask = make_tensor (
172- make_gmem_ptr (reinterpret_cast <const bool *>(params.mask_ptr ) + binfo.mask_offset (params.mask_batch_stride , params.mask_row_stride , bidb)),
173+ make_gmem_ptr (reinterpret_cast <ElementMask *>(params.mask_ptr ) + binfo.mask_offset (params.mask_batch_stride , params.mask_row_stride , bidb)),
173174 make_shape (params.h_mask , binfo.actual_seqlen_q , binfo.actual_seqlen_k ),
174175 make_stride (params.mask_head_stride , params.mask_row_stride , _1{})
175176 );
@@ -216,13 +217,17 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
216217 sV .data ().get (),
217218 typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}
218219 );
219- Tensor sMask = make_tensor (
220+ Tensor sMaskPlace = make_tensor (
220221 Has_mask ? sV .data () + size (sV ) : sV .data (),
221- typename Kernel_traits::SmemLayoutAtomPS{}
222+ typename Kernel_traits::SmemLayoutPS{}
223+ ); // For pointers alignment only
224+ Tensor sMask = make_tensor (
225+ make_smem_ptr (reinterpret_cast <ElementMask*>(sMaskPlace .data ().get ())),
226+ typename Kernel_traits::SmemLayoutPS{}
222227 );
223228 Tensor sBias = make_tensor (
224- Has_bias ? (Has_mask ? sMask .data () + size (sMask ) : sV .data () + size (sV )) : sV .data (),
225- typename Kernel_traits::SmemLayoutAtomPS {}
229+ Has_bias ? (Has_mask ? sMaskPlace .data () + size (sMaskPlace ) : sV .data () + size (sV )) : sV .data (),
230+ typename Kernel_traits::SmemLayoutPS {}
226231 );
227232
228233 // Global to Shared Memory operation
@@ -364,25 +369,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
364369 }
365370
366371 if constexpr (Has_mask) {
367- // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
368- // gmem_tiled_copy_Mask,
369- // tMaskgMask(_, _, _, n_block), tMasksMask,
370- // tMaskcMask, tMaskpMask,
371- // binfo.actual_seqlen_q - m_block * kBlockM
372- // );
373- // cute::cp_async_fence();
374- // FLASH_NAMESPACE::cp_async_wait<0>();
375- // // Do OR-reduce on the mask to see if any active threads
376-
377-
378- FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /* Clear_OOB_MN=*/ true , /* To_type=*/ Element>(
372+ FLASH_NAMESPACE::copy_MN<Is_even_MN, /* Clear_OOB_MN=*/ true >(
379373 gmem_tiled_copy_Mask,
380374 tMaskgMask (_, _, _, n_block), tMasksMask,
381- any_active,
382375 tMaskcMask, tMaskpMask,
383376 binfo.actual_seqlen_q - m_block * kBlockM
384377 );
385- // We don't need to syncthreads here because copy_mask is already or_syncthreads.
378+ __syncthreads ();
379+ // Do OR-reduce on the mask to see if any active threads for current iteration.
380+ FLASH_NAMESPACE::mask_or_reduce (
381+ tMasksMask,
382+ any_active,
383+ smem_thr_copy_Mask
384+ );
386385 }
387386
388387 // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
@@ -394,7 +393,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
394393 binfo.actual_seqlen_k - n_block * kBlockN
395394 );
396395 if constexpr (Has_bias) {
397- FLASH_NAMESPACE::copy_bias <Is_even_MN, /* Clear_OOB_MN=*/ true >(
396+ FLASH_NAMESPACE::copy_MN <Is_even_MN, /* Clear_OOB_MN=*/ true >(
398397 gmem_tiled_copy_Bias,
399398 tBiasgBias (_, _, _, n_block), tBiassBias,
400399 tBiascBias, tBiaspBias,
@@ -524,24 +523,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
524523
525524 if (n_block > n_block_min) {
526525 if constexpr (Has_mask) {
527- // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
528- // gmem_tiled_copy_Mask,
529- // tMaskgMask(_, _, _, n_block - 1), tMasksMask,
530- // tMaskcMask, tMaskpMask,
531- // binfo.actual_seqlen_q - m_block * kBlockM
532- // );
533- // cute::cp_async_fence();
534- // FLASH_NAMESPACE::cp_async_wait<0>();
535- // // Do OR-reduce on the mask to see if any active threads for next iteration.
536-
537- FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /* Clear_OOB_MN=*/ true , /* To_type=*/ Element>(
526+ FLASH_NAMESPACE::copy_MN<Is_even_MN, /* Clear_OOB_MN=*/ true >(
538527 gmem_tiled_copy_Mask,
539528 tMaskgMask (_, _, _, n_block - 1 ), tMasksMask,
540- any_active_next,
541529 tMaskcMask, tMaskpMask,
542530 binfo.actual_seqlen_q - m_block * kBlockM
543531 );
544- // We don't need to syncthreads here because copy_mask is already or_syncthreads.
532+ __syncthreads ();
533+ // Do OR-reduce on the mask to see if any active threads for next iteration.
534+ FLASH_NAMESPACE::mask_or_reduce (
535+ tMasksMask,
536+ any_active_next,
537+ smem_thr_copy_Mask
538+ );
545539 }
546540
547541 if (any_active_next) {
@@ -551,7 +545,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
551545 tKVcKV, tKVpKV
552546 );
553547 if constexpr (Has_bias) {
554- FLASH_NAMESPACE::copy_bias <Is_even_MN, /* Clear_OOB_MN=*/ true >(
548+ FLASH_NAMESPACE::copy_MN <Is_even_MN, /* Clear_OOB_MN=*/ true >(
555549 gmem_tiled_copy_Bias,
556550 tBiasgBias (_, _, _, n_block - 1 ), tBiassBias,
557551 tBiascBias, tBiaspBias,
@@ -684,24 +678,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
684678
685679 if (n_block > n_block_min) {
686680 if constexpr (Has_mask) {
687- // FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/false>(
688- // gmem_tiled_copy_Mask,
689- // tMaskgMask(_, _, _, n_block - 1), tMasksMask,
690- // tMaskcMask, tMaskpMask,
691- // binfo.actual_seqlen_q - m_block * kBlockM
692- // );
693- // cute::cp_async_fence();
694- // FLASH_NAMESPACE::cp_async_wait<0>();
695- // // Do OR-reduce on the mask to see if any active threads for next iteration.
696-
697- FLASH_NAMESPACE::copy_mask_with_or_reduce</* Is_even_MN=*/ true , /* Clear_OOB_MN=*/ false , /* To_type=*/ Element>(
681+ FLASH_NAMESPACE::copy_MN</* Is_even_MN=*/ true , /* Clear_OOB_MN=*/ false >(
698682 gmem_tiled_copy_Mask,
699683 tMaskgMask (_, _, _, n_block - 1 ), tMasksMask,
700- any_active_next,
701684 tMaskcMask, tMaskpMask,
702685 binfo.actual_seqlen_q - m_block * kBlockM
703686 );
704- // We don't need to syncthreads here because copy_mask is already or_syncthreads
687+ __syncthreads ();
688+ // Do OR-reduce on the mask to see if any active threads for next iteration.
689+ FLASH_NAMESPACE::mask_or_reduce (
690+ tMasksMask,
691+ any_active_next,
692+ smem_thr_copy_Mask
693+ );
705694 }
706695
707696 if (any_active_next) {
@@ -711,7 +700,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
711700 tKVcKV, tKVpKV
712701 );
713702 if constexpr (Has_bias) {
714- FLASH_NAMESPACE::copy_bias </* Is_even_MN=*/ true , /* Clear_OOB_MN=*/ false >(
703+ FLASH_NAMESPACE::copy_MN </* Is_even_MN=*/ true , /* Clear_OOB_MN=*/ false >(
715704 gmem_tiled_copy_Bias,
716705 tBiasgBias (_, _, _, n_block - 1 ), tBiassBias,
717706 tBiascBias, tBiaspBias,
@@ -834,6 +823,7 @@ template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, b
834823inline __device__ void compute_attn_1rowblock_splitkv (const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
835824
836825 using Element = typename Kernel_traits::Element;
826+ using ElementMask = typename Kernel_traits::ElementMask;
837827 using ElementAccum = typename Kernel_traits::ElementAccum;
838828 using index_t = typename Kernel_traits::index_t ;
839829
@@ -970,7 +960,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
970960 make_stride (params.v_row_stride , _1{})
971961 );
972962 Tensor gMask = make_tensor (
973- make_gmem_ptr (reinterpret_cast <const bool *>(params.mask_ptr ) + col_offset_mask),
963+ make_gmem_ptr (reinterpret_cast <ElementMask *>(params.mask_ptr ) + col_offset_mask),
974964 Shape<Int<kBlockM >, Int<kBlockN >>{},
975965 make_stride (params.mask_row_stride , _1{})
976966 );
@@ -1001,13 +991,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
1001991 sV .data ().get (),
1002992 typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}
1003993 );
1004- Tensor sMask = make_tensor (
994+ Tensor sMaskPlace = make_tensor (
1005995 Has_mask ? sV .data () + size (sV ) : sV .data (),
1006- typename Kernel_traits::SmemLayoutAtomPS{}
996+ typename Kernel_traits::SmemLayoutPS{}
997+ ); // For pointers alignment only
998+ Tensor sMask = make_tensor (
999+ make_smem_ptr (reinterpret_cast <ElementMask*>(sMaskPlace .data ().get ())),
1000+ typename Kernel_traits::SmemLayoutPS{}
10071001 );
10081002 Tensor sBias = make_tensor (
1009- Has_bias ? (Has_mask ? sMask .data () + size (sMask ) : sV .data () + size (sV )) : sV .data (),
1010- typename Kernel_traits::SmemLayoutAtomPS {}
1003+ Has_bias ? (Has_mask ? sMaskPlace .data () + size (sMaskPlace ) : sV .data () + size (sV )) : sV .data (),
1004+ typename Kernel_traits::SmemLayoutPS {}
10111005 );
10121006
10131007 // Global to Shared Memory operation
@@ -1115,24 +1109,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
11151109 );
11161110
11171111 if constexpr (Has_mask) {
1118- // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
1119- // gmem_tiled_copy_Mask,
1120- // tMaskgMask, tMasksMask,
1121- // tMaskcMask, tMaskpMask,
1122- // binfo.actual_seqlen_q - m_block * kBlockM
1123- // );
1124- // cute::cp_async_fence();
1125- // FLASH_NAMESPACE::cp_async_wait<0>();
1126- // // Do OR-reduce on the mask to see if any active threads
1127-
1128- FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /* Clear_OOB_MN=*/ true , /* To_type=*/ Element>(
1112+ FLASH_NAMESPACE::copy_MN<Is_even_MN, /* Clear_OOB_MN=*/ true >(
11291113 gmem_tiled_copy_Mask,
11301114 tMaskgMask, tMasksMask,
1131- any_active,
11321115 tMaskcMask, tMaskpMask,
11331116 binfo.actual_seqlen_q - m_block * kBlockM
11341117 );
1135- // We don't need to syncthreads here because copy_mask is already or_syncthreads.
1118+ __syncthreads ();
1119+ // Do OR-reduce on the mask to see if any active threads for current iteration.
1120+ FLASH_NAMESPACE::mask_or_reduce (
1121+ tMasksMask,
1122+ any_active,
1123+ smem_thr_copy_Mask
1124+ );
11361125 }
11371126
11381127 // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
@@ -1144,7 +1133,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
11441133 binfo.actual_seqlen_k - n_block * kBlockN
11451134 );
11461135 if constexpr (Has_bias) {
1147- FLASH_NAMESPACE::copy_bias <Is_even_MN, /* Clear_OOB_MN=*/ true >(
1136+ FLASH_NAMESPACE::copy_MN <Is_even_MN, /* Clear_OOB_MN=*/ true >(
11481137 gmem_tiled_copy_Bias,
11491138 tBiasgBias, tBiassBias,
11501139 tBiascBias, tBiaspBias,
@@ -1305,24 +1294,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
13051294 }
13061295
13071296 if constexpr (Has_mask) {
1308- // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
1309- // gmem_tiled_copy_Mask,
1310- // tMaskgMask, tMasksMask,
1311- // tMaskcMask, tMaskpMask,
1312- // binfo.actual_seqlen_q - m_block * kBlockM
1313- // );
1314- // cute::cp_async_fence();
1315- // FLASH_NAMESPACE::cp_async_wait<0>();
1316- // // Do OR-reduce on the mask to see if any active threads for next iteration.
1317-
1318- FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /* Clear_OOB_MN=*/ true , /* To_type=*/ Element>(
1297+ FLASH_NAMESPACE::copy_MN<Is_even_MN, /* Clear_OOB_MN=*/ true >(
13191298 gmem_tiled_copy_Mask,
13201299 tMaskgMask, tMasksMask,
1321- any_active_next,
13221300 tMaskcMask, tMaskpMask,
13231301 binfo.actual_seqlen_q - m_block * kBlockM
13241302 );
1325- // We don't need to syncthreads here because copy_mask is already or_syncthreads.
1303+ __syncthreads ();
1304+ // Do OR-reduce on the mask to see if any active threads for next iteration.
1305+ FLASH_NAMESPACE::mask_or_reduce (
1306+ tMasksMask,
1307+ any_active_next,
1308+ smem_thr_copy_Mask
1309+ );
13261310 }
13271311
13281312 if (any_active_next) {
@@ -1332,9 +1316,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
13321316 tKVcKV, tKVpKV
13331317 );
13341318 if constexpr (Has_bias) {
1335- FLASH_NAMESPACE::copy_bias <Is_even_MN, /* Clear_OOB_MN=*/ true >(
1319+ FLASH_NAMESPACE::copy_MN <Is_even_MN, /* Clear_OOB_MN=*/ true >(
13361320 gmem_tiled_copy_Bias,
1337- tBiasgBias, tBiassBias,
1321+ tBiasgBias, tBiassBias,
13381322 tBiascBias, tBiaspBias,
13391323 binfo.actual_seqlen_q - m_block * kBlockM
13401324 );
@@ -1492,24 +1476,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
14921476 }
14931477
14941478 if constexpr (Has_mask) {
1495- // FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/false>(
1496- // gmem_tiled_copy_Mask,
1497- // tMaskgMask, tMasksMask,
1498- // tMaskcMask, tMaskpMask,
1499- // binfo.actual_seqlen_q - m_block * kBlockM
1500- // );
1501- // cute::cp_async_fence();
1502- // FLASH_NAMESPACE::cp_async_wait<0>();
1503- // // Do OR-reduce on the mask to see if any active threads for next iteration.
1504-
1505- FLASH_NAMESPACE::copy_mask_with_or_reduce</* Is_even_MN=*/ true , /* Clear_OOB_MN=*/ false , /* To_type=*/ Element>(
1479+ FLASH_NAMESPACE::copy_MN</* Is_even_MN=*/ true , /* Clear_OOB_MN=*/ false >(
15061480 gmem_tiled_copy_Mask,
15071481 tMaskgMask, tMasksMask,
1508- any_active_next,
15091482 tMaskcMask, tMaskpMask,
15101483 binfo.actual_seqlen_q - m_block * kBlockM
15111484 );
1512- // We don't need to syncthreads here because copy_mask is already or_syncthreads.
1485+ __syncthreads ();
1486+ // Do OR-reduce on the mask to see if any active threads for next iteration.
1487+ FLASH_NAMESPACE::mask_or_reduce (
1488+ tMasksMask,
1489+ any_active_next,
1490+ smem_thr_copy_Mask
1491+ );
15131492 }
15141493
15151494 if (any_active_next) {
@@ -1519,9 +1498,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
15191498 tKVcKV, tKVpKV
15201499 );
15211500 if constexpr (Has_bias) {
1522- FLASH_NAMESPACE::copy_bias </* Is_even_MN=*/ true , /* Clear_OOB_MN=*/ false >(
1501+ FLASH_NAMESPACE::copy_MN </* Is_even_MN=*/ true , /* Clear_OOB_MN=*/ false >(
15231502 gmem_tiled_copy_Bias,
1524- tBiasgBias, tBiassBias,
1503+ tBiasgBias, tBiassBias,
15251504 tBiascBias, tBiaspBias,
15261505 binfo.actual_seqlen_q - m_block * kBlockM
15271506 );
0 commit comments