@@ -169,7 +169,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         make_coord(_, 0)
     );  // (kBlockN, kHeadDim, nblocksN)
     Tensor mMask = make_tensor(
-        make_gmem_ptr(reinterpret_cast<Element *>(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
+        make_gmem_ptr(reinterpret_cast<const bool *>(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
         make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
         make_stride(params.mask_head_stride, params.mask_row_stride, _1{})
     );
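The cast above now views the mask buffer as packed `bool` rather than `Element`, so every mask stride in `params` (batch, head, row) counts one-byte boolean entries. Below is a minimal sketch of that layout assumption, mirroring the `(h_k, seqlen_q, seqlen_k)` tensor built above; the helper name `mask_at` and its parameter list are illustrative and not part of the diff.

```cuda
// Layout assumption: mask_ptr addresses bool elements, so head/row strides count
// bool entries and the innermost (key) stride is 1, matching the make_stride above.
__device__ inline bool mask_at(const void *mask_ptr,
                               long long head_stride, long long row_stride,
                               int h, int q, int k) {
    const bool *m = reinterpret_cast<const bool *>(mask_ptr);
    return m[h * head_stride + q * row_stride + k];  // true = position participates
}
```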
@@ -344,15 +344,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     }
     // Reverse iteration over N blocks
     int n_block = n_block_max - 1;
-
-    FLASH_NAMESPACE::copy_MN<Is_even_MN>(
+
+    FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
         gmem_tiled_copy_MaskBias,
         tMaskgMask(_, _, _, n_block), tMasksMask,
         tMaskcMask,
         binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
     );
-    cute::cp_async_fence();
-    FLASH_NAMESPACE::cp_async_wait<0>();
+    // cute::cp_async_fence();
+    // FLASH_NAMESPACE::cp_async_wait<0>();
     __syncthreads();

     // Do OR-reduce on the mask to see if any active threads
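The call sites now pass two extra template arguments. Here is a minimal sketch of what those flags imply, assuming `copy_MN` performs a predicated gmem-to-smem copy of the (kBlockM, kBlockN) mask tile: with `Bool_to_Element=true` it reads the mask as packed `bool` and materializes it as `Element` in shared memory, and with `Clear_OOB_MN=true` it zero-fills out-of-bounds entries. The helper `copy_bool_mask_tile` and its thread-striding loop are illustrative assumptions, not the actual cute-based implementation. Note that a plain register copy like this issues no `cp.async` traffic, which is consistent with the fence/wait pair being commented out and only the `__syncthreads()` barrier being kept.

```cuda
// Illustrative sketch only; the real FLASH_NAMESPACE::copy_MN is a cute tiled copy.
template <bool Is_even_MN, bool Clear_OOB_MN, typename Element>
__device__ inline void copy_bool_mask_tile(
        const bool *gmem_mask, Element *smem_mask,
        int row_stride, int rows, int cols, int max_m, int max_n) {
    // Each thread strides over the (rows x cols) tile.
    for (int idx = threadIdx.x; idx < rows * cols; idx += blockDim.x) {
        const int m = idx / cols;
        const int n = idx % cols;
        if (Is_even_MN || (m < max_m && n < max_n)) {
            // Bool_to_Element: read the packed bool mask and widen it to Element
            // (1 = position participates, 0 = position is masked out).
            smem_mask[idx] = gmem_mask[m * row_stride + n] ? Element(1) : Element(0);
        } else if (Clear_OOB_MN) {
            // Zero out-of-bounds entries so the later OR-reduce and masking stay correct.
            smem_mask[idx] = Element(0);
        }
    }
}
```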
@@ -470,14 +470,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     }

     if (n_block > n_block_min) {
-        FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
+        FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
            gmem_tiled_copy_MaskBias,
            tMaskgMask(_, _, _, n_block - 1), tMasksMask,
            tMaskcMask,
            binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN
        );
-        cute::cp_async_fence();
-        FLASH_NAMESPACE::cp_async_wait<0>();
+        // cute::cp_async_fence();
+        // FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();

         // Do OR-reduce on the mask to see if any active threads for next iteration
@@ -593,14 +593,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     }

     if (n_block > n_block_min) {
-        FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
+        FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
            gmem_tiled_copy_MaskBias,
            tMaskgMask(_, _, _, n_block - 1), tMasksMask,
            tMaskcMask,
            binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN
        );
-        cute::cp_async_fence();
-        FLASH_NAMESPACE::cp_async_wait<0>();
+        // cute::cp_async_fence();
+        // FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();

         // Do OR-reduce on the mask to see if any active threads for next iteration
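Each copy site above is followed by the OR-reduce mentioned in the context comments, which decides whether the whole (kBlockM, kBlockN) tile can be skipped. The sketch below is a hypothetical version of such a check over the converted shared-memory tile; the real kernel reduces over per-thread register fragments, so the helper name, its arguments, and the use of `__syncthreads_or` are assumptions.

```cuda
// Hypothetical "any active" check over the smem mask tile (Element = converted mask).
template <typename Element>
__device__ inline bool any_active_in_tile(const Element *smem_mask, int tile_elems) {
    bool local_any = false;
    for (int idx = threadIdx.x; idx < tile_elems; idx += blockDim.x) {
        local_any |= (smem_mask[idx] != Element(0));  // per-thread partial OR
    }
    // Block-wide OR: __syncthreads_or returns non-zero if the predicate is non-zero
    // for any thread in the block, and also acts as a barrier.
    return __syncthreads_or(local_any) != 0;
}
```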
@@ -873,7 +873,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         make_stride(params.v_row_stride, _1{})
     );
     Tensor gMask = make_tensor(
-        make_gmem_ptr(reinterpret_cast<Element *>(params.mask_ptr) + col_offset_mask),
+        make_gmem_ptr(reinterpret_cast<const bool *>(params.mask_ptr) + col_offset_mask),
         Shape<Int<kBlockM>, Int<kBlockN>>{},
         make_stride(params.mask_row_stride, _1{})
     );
@@ -999,14 +999,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons

     int n_block = n_block_max - 1;

-    FLASH_NAMESPACE::copy_MN<Is_even_MN>(
+    FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
         gmem_tiled_copy_MaskBias,
         tMaskgMask, tMasksMask,
         tMaskcMask,
         binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
     );
-    cute::cp_async_fence();
-    FLASH_NAMESPACE::cp_async_wait<0>();
+    // cute::cp_async_fence();
+    // FLASH_NAMESPACE::cp_async_wait<0>();
     __syncthreads();

     // Do OR-reduce on the mask to see if any active threads for next iteration
@@ -1146,14 +1146,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur);
         tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur);
     }
-    FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
+    FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
         gmem_tiled_copy_MaskBias,
         tMaskgMask, tMasksMask,
         tMaskcMask,
         binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN
     );
-    cute::cp_async_fence();
-    FLASH_NAMESPACE::cp_async_wait<0>();
+    // cute::cp_async_fence();
+    // FLASH_NAMESPACE::cp_async_wait<0>();
     __syncthreads();

     // Do OR-reduce on the mask to see if any active threads for next iteration
@@ -1287,12 +1287,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur);
         tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur);
     }
-    FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
+    FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
         gmem_tiled_copy_MaskBias,
         tMaskgMask, tMasksMask,
         tMaskcMask,
         binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN
     );
+    // cute::cp_async_fence();
+    // FLASH_NAMESPACE::cp_async_wait<0>();
+    __syncthreads();

     // Do OR-reduce on the mask to see if any active threads for next iteration
     any_active_local_next = false;
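This last hunk also adds a plain `__syncthreads()` where the fence/wait pair is commented out, matching the other call sites. The sketch below contrasts the two synchronization disciplines, under the assumption that the Bool_to_Element path stages the tile through registers rather than `cp.async`; the helper is illustrative only and assumes the cute copy primitives are already in scope, as they are in this kernel header.

```cuda
// When a tile is staged into shared memory with cp.async, the fence/wait pair must
// precede the barrier; when it is staged with ordinary register copies, the barrier
// alone makes the tile visible block-wide.
template <bool UsesCpAsync>
__device__ inline void sync_after_smem_copy() {
    if constexpr (UsesCpAsync) {
        cute::cp_async_fence();    // commit the outstanding cp.async group
        cute::cp_async_wait<0>();  // wait until all committed groups have landed in smem
    }
    __syncthreads();               // make the shared-memory tile visible to every thread
}
```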