@@ -272,9 +272,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // Global to Shared Memory operation
     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
-    typename Kernel_traits::GmemTiledCopyMaskBias gmem_tiled_copy_MaskBias;
-    auto gmem_thr_copy_Mask = gmem_tiled_copy_MaskBias.get_thread_slice(tidx);
-    auto gmem_thr_copy_Bias = gmem_tiled_copy_MaskBias.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask;
+    typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias;
+    auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx);
+    auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx);
     using GmemTiledCopydO = std::conditional_t<Is_first, typename Kernel_traits::GmemTiledCopydO, typename Kernel_traits::GmemTiledCopyQKV>;
     GmemTiledCopydO gmem_tiled_copy_dO;
     auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
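This hunk splits the shared `GmemTiledCopyMaskBias` into separate `GmemTiledCopyMask` and `GmemTiledCopyBias` tiled copies, so the mask and bias tiles no longer have to share one copy configuration. Below is a minimal sketch of how two such traits could be declared with CuTe; the copy atom, layouts, and element types (`ElementMask` in particular) are illustrative assumptions, not the repository's actual `Kernel_traits` definitions.

```cpp
#include <cute/tensor.hpp>
using namespace cute;

// Hypothetical sketch only: separate gmem tiled-copy types for mask and bias.
// GmemLayoutAtom, Element, and ElementMask stand in for whatever Kernel_traits defines.
template <class Element, class ElementMask, class GmemLayoutAtom>
struct MaskBiasCopyTraitsSketch {
    // Bias shares the compute element type, so it can keep a wide vectorized copy.
    using GmemTiledCopyBias = decltype(make_tiled_copy(
        Copy_Atom<UniversalCopy<Element>, Element>{},
        GmemLayoutAtom{},
        Layout<Shape<_1, _8>>{}));   // 8 values per thread per copy (assumed width)

    // The mask may use a different (e.g. narrower) element type, hence its own tiled copy.
    using GmemTiledCopyMask = decltype(make_tiled_copy(
        Copy_Atom<UniversalCopy<ElementMask>, ElementMask>{},
        GmemLayoutAtom{},
        Layout<Shape<_1, _8>>{}));
};
```

Keeping the two types separate lets each copy pick a vector width that matches its element size instead of being forced onto one shared configuration.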
@@ -417,9 +418,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
     }
 
+    // Allocate predicate tensors for N
+    Tensor tMaskpMask = make_tensor<bool>(make_shape(size<2>(tMasksMask)));
+    Tensor tBiaspBias = make_tensor<bool>(make_shape(size<2>(tBiassBias)));
+
+    // Set predicates for n bounds
+    if (!Is_even_MN) {
+        #pragma unroll
+        for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); }
+        #pragma unroll
+        for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); }
+    }
+
 
     // Prologue
 
+    bool any_active = true;       // to be updated later for current iteration
+    bool any_active_next = true;  // to be updated later for next iteration
+
     // We'll advance gdQ, gdQaccum and gdBias before the 1st read/write.
     tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride;
     tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;
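The new `tMaskpMask` / `tBiaspBias` tensors precompute, once per column block, whether each N coordinate lies inside `actual_seqlen_k`, so the copy routines only need to check the M bound per row. A rough sketch of how such a column predicate typically gates a CuTe copy is shown below; the function name and exact signature are illustrative and do not reproduce the library's `copy_MN`/`copy_bias` implementation.

```cpp
#include <cute/tensor.hpp>
using namespace cute;

// Hypothetical helper: copy a (CPY, M, N)-partitioned tile, skipping rows past
// max_M and columns whose precomputed predicate is false. Cleared tiles are
// zero-filled so downstream math reads well-defined data.
template <bool Clear_OOB_MN, class TiledCopy, class TensorS, class TensorD,
          class TensorC, class TensorP>
__device__ void copy_with_col_predicate(TiledCopy tiled_copy,
                                        TensorS const &S, TensorD &D,
                                        TensorC const &identity, TensorP const &col_pred,
                                        int max_M) {
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (get<0>(identity(0, m, 0)) < max_M) {             // row (M) bound check
            #pragma unroll
            for (int n = 0; n < size<2>(S); ++n) {
                if (col_pred(n)) {
                    cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
                } else if (Clear_OOB_MN) {
                    cute::clear(D(_, m, n));                  // zero out-of-bounds column
                }
            }
        } else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _));                          // zero out-of-bounds row
        }
    }
}
```

Computing the N predicate once per block avoids re-evaluating the `actual_seqlen_k` comparison on every row of every copy.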
@@ -554,24 +570,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK);
     // // if (cute::thread(1, 0)) { print(tKrK); }
 
-    FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
-        gmem_tiled_copy_MaskBias,
-        tMaskgMask, tMasksMask,
-        tMaskcMask,
-        binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
-    );
+    // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
+    //     gmem_tiled_copy_Mask,
+    //     tMaskgMask, tMasksMask,
+    //     tMaskcMask, tMaskpMask,
+    //     binfo.actual_seqlen_q - m_block * kBlockM
+    // );
     // cute::cp_async_fence();
     // FLASH_NAMESPACE::cp_async_wait<0>();
-    __syncthreads();
+    // // Do OR-reduce on the mask to see if any active threads
 
-    // Do OR-reduce on the mask to see if any active threads
-    Tensor tSsMask_copy_view = smem_thr_copy_PdS.retile_S(tSsMask);
-    bool any_active_local = false;
-    bool any_active_local_next = false;  // to be updated later for next iteration
-    #pragma unroll
-    for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); }
-    bool any_active = __syncthreads_or(any_active_local);
-    bool any_active_next = false;  // to be updated later for next iteration
+    FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
+        gmem_tiled_copy_Mask,
+        tMaskgMask, tMasksMask,
+        any_active,
+        tMaskcMask, tMaskpMask,
+        binfo.actual_seqlen_q - m_block * kBlockM
+    );
+    // We don't need to syncthreads here because copy_mask is already or_syncthreads.
 
     FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
         gmem_tiled_copy_QKV,
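The per-thread OR over `tSsMask` followed by `__syncthreads_or` is replaced by `copy_mask_with_or_reduce`, which fuses the global-to-shared mask copy with the block-wide "is anything active?" reduction, which is why the standalone `__syncthreads()` disappears. The sketch below only illustrates the idea; the real helper lives in the `FLASH_NAMESPACE` copy utilities and its signature may differ.

```cpp
// Hypothetical sketch of a fused mask copy + block-wide OR reduction.
// gMask holds the bool mask in global memory; sMask is the shared-memory tile in
// the compute element type (To_type). The tiled_copy argument is unused in this
// simplified scalar version and is kept only to mirror the call site.
template <bool Is_even_MN, bool Clear_OOB_MN, class To_type,
          class TiledCopy, class TensorG, class TensorS, class TensorC, class TensorP>
__device__ void copy_mask_with_or_reduce(TiledCopy /*tiled_copy*/,
                                         TensorG const &gMask, TensorS &sMask,
                                         bool &any_active,
                                         TensorC const &cMask, TensorP const &pMask,
                                         int max_M) {
    bool local_any = false;
    #pragma unroll
    for (int m = 0; m < size<1>(gMask); ++m) {
        bool row_ok = Is_even_MN || get<0>(cMask(0, m, 0)) < max_M;
        #pragma unroll
        for (int n = 0; n < size<2>(gMask); ++n) {
            bool col_ok = Is_even_MN || pMask(n);
            if (row_ok && col_ok) {
                #pragma unroll
                for (int i = 0; i < size<0>(gMask); ++i) {
                    bool v = gMask(i, m, n);
                    sMask(i, m, n) = v ? To_type(1) : To_type(0);  // widen bool -> Element
                    local_any |= v;
                }
            } else if (Clear_OOB_MN) {
                cute::clear(sMask(cute::_, m, n));
            }
        }
    }
    // __syncthreads_or is both a barrier and a block-wide OR, so the call site
    // does not need a separate __syncthreads() afterwards.
    any_active = __syncthreads_or(local_any) != 0;
}
```

Because the reduction reads the mask values as they stream through registers, the old pattern of first staging the mask in shared memory and then re-reading it through `smem_thr_copy_PdS` is no longer needed.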
@@ -581,12 +597,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     );
 
     if (any_active) {
-        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
-            gmem_tiled_copy_MaskBias,
+        FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
+            gmem_tiled_copy_Bias,
             tBiasgBias, tBiassBias,
-            tBiascBias,
-            binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
+            tBiascBias, tBiaspBias,
+            binfo.actual_seqlen_q - m_block * kBlockM
         );
+        // Because copy_bias currently uses scalar loads, we need to sync here.
+        // TODO: Remove sync after fixing to vectorized loads.
+        __syncthreads();
    }

    if (!Kernel_traits::Is_V_in_regs) {
@@ -780,13 +799,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
         __syncthreads();
         // Write dS to dBias
-        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
-            gmem_tiled_copy_MaskBias,
+        FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/false>(
+            gmem_tiled_copy_Bias,
             tBiassBias, tdBiasgdBias,
-            tBiascBias,
-            binfo.actual_seqlen_q - m_block * kBlockM,
-            binfo.actual_seqlen_k - n_block * kBlockN
+            tBiascBias, tBiaspBias,
+            binfo.actual_seqlen_q - m_block * kBlockM
         );
+        // Because copy_bias currently uses scalar loads, we need to sync here.
+        // TODO: Remove sync after fixing to vectorized loads.
+        __syncthreads();
 
         // if (cute::thread0()) { print(tPrP); }
         // Layout p_l = tPrP.layout();
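For the "Write dS to dBias" step, the reason the dS tile in `tBiassBias` can be written out directly as the bias gradient is the usual chain rule, assuming (as the forward pass suggests) that the bias is added elementwise to the pre-softmax scores:

```latex
S_{ij} = \mathrm{scale}\,(q_i \cdot k_j) + B_{ij}
\quad\Longrightarrow\quad
\frac{\partial \mathcal{L}}{\partial B_{ij}}
  = \frac{\partial \mathcal{L}}{\partial S_{ij}} \cdot \frac{\partial S_{ij}}{\partial B_{ij}}
  = \mathrm{d}S_{ij}
```

So each tile of dBias is exactly the dS tile just staged in shared memory; any broadcasting of the bias over batch or heads would additionally require a reduction outside this block.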
@@ -810,21 +831,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         if (m_block > m_block_min) {
             // Advance gMask
             tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride));
-            FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
-                gmem_tiled_copy_MaskBias,
-                tMaskgMask, tMasksMask,
-                tMaskcMask,
-                binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
-            );
+            // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
+            //     gmem_tiled_copy_Mask,
+            //     tMaskgMask, tMasksMask,
+            //     tMaskcMask, tMaskpMask,
+            //     binfo.actual_seqlen_q - (m_block - 1) * kBlockM
+            // );
             // FLASH_NAMESPACE::cp_async_fence();
             // FLASH_NAMESPACE::cp_async_wait<0>();
-            __syncthreads();
+            // // Do OR-reduce on the mask to see if any active threads for next iteration
 
-            // Do OR-reduce on the mask to see if any active threads for next iteration
-            any_active_local_next = false;
-            #pragma unroll
-            for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local_next |= (tSsMask_copy_view(i) != Element(0)); }
-            any_active_next = __syncthreads_or(any_active_local_next);
+            FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
+                gmem_tiled_copy_Mask,
+                tMaskgMask, tMasksMask,
+                any_active_next,
+                tMaskcMask, tMaskpMask,
+                binfo.actual_seqlen_q - (m_block - 1) * kBlockM
+            );
+            // We don't need to syncthreads here because copy_mask is already or_syncthreads.
 
             // Advance gdO
             tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride));
@@ -926,12 +950,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
             tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
             if (any_active_next) {
-                FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
-                    gmem_tiled_copy_MaskBias,
+                FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
+                    gmem_tiled_copy_Bias,
                     tBiasgBias, tBiassBias,
-                    tBiascBias,
-                    binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
+                    tBiascBias, tBiaspBias,
+                    binfo.actual_seqlen_q - (m_block - 1) * kBlockM
                 );
+                // Because copy_bias currently uses scalar loads, we need to sync here.
+                // TODO: Remove sync after fixing to vectorized loads.
+                __syncthreads();
             }
         }
 
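Taken together, `any_active` and `any_active_next` turn the mask load into a cheap block-skip test: the current tile's work is gated on `any_active`, while the next tile's mask is prefetched and OR-reduced into `any_active_next`. Below is a much-simplified, hypothetical skeleton of that flow; the real loop also prefetches dO/Q/LSE, double-buffers shared memory, and so on, and the roll-over assignment at the end is assumed from context rather than shown in these hunks.

```cpp
// Hypothetical skeleton only - not the actual kernel body.
__device__ void colblock_loop_sketch(int m_block_max, int m_block_min) {
    bool any_active = true;        // does the current mask tile contain any active element?
    bool any_active_next = true;   // same question for the tile prefetched for the next step
    for (int m_block = m_block_max - 1; m_block >= m_block_min; --m_block) {
        if (any_active) {
            // load bias tile, compute S = scale*Q*K^T + bias, apply mask,
            // form dS, write dS to dBias, accumulate dK/dV and dQ
        }
        if (m_block > m_block_min) {
            // advance gMask and prefetch the next tile:
            // copy_mask_with_or_reduce(..., any_active_next, ...);
        }
        any_active = any_active_next;   // roll the flag forward (assumed)
    }
}
```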