
Commit 1cca349

[BUG FIX] Prevent mask/bias materialization; avoid OOB for irregular seqlen
2 parents cb78583 + e9f9fcc commit 1cca349

5 files changed: +454 -214 lines

csrc/flash_dmattn/src/flash_bwd_kernel.h

Lines changed: 70 additions & 43 deletions
@@ -272,9 +272,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // Global to Shared Memory operation
     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
-    typename Kernel_traits::GmemTiledCopyMaskBias gmem_tiled_copy_MaskBias;
-    auto gmem_thr_copy_Mask = gmem_tiled_copy_MaskBias.get_thread_slice(tidx);
-    auto gmem_thr_copy_Bias = gmem_tiled_copy_MaskBias.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask;
+    typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias;
+    auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx);
+    auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx);
     using GmemTiledCopydO = std::conditional_t<Is_first, typename Kernel_traits::GmemTiledCopydO, typename Kernel_traits::GmemTiledCopyQKV>;
     GmemTiledCopydO gmem_tiled_copy_dO;
     auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
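Splitting GmemTiledCopyMaskBias into GmemTiledCopyMask and GmemTiledCopyBias lets each tile use a copy atom sized for its own element type: the mask is read as 1-byte bools while the bias is read as Element. A minimal sketch of how such traits could be declared with CuTe is below; the thread/value layouts and the Element alias are illustrative assumptions, not the repository's actual Kernel_traits definitions.

// Illustrative sketch only: assumed layouts, not the actual Kernel_traits members.
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>

using Element = cutlass::half_t;  // assumed element type

// The bool mask and the Element bias get independent copy atoms, so each can
// pick an access pattern appropriate for its own element size.
using GmemTiledCopyMask = decltype(cute::make_tiled_copy(
    cute::Copy_Atom<cute::UniversalCopy<bool>, bool>{},
    cute::Layout<cute::Shape<cute::_16, cute::_8>>{},   // assumed 16x8 thread layout
    cute::Layout<cute::Shape<cute::_1, cute::_1>>{}));  // one value per thread per copy
using GmemTiledCopyBias = decltype(cute::make_tiled_copy(
    cute::Copy_Atom<cute::UniversalCopy<Element>, Element>{},
    cute::Layout<cute::Shape<cute::_16, cute::_8>>{},
    cute::Layout<cute::Shape<cute::_1, cute::_1>>{}));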
@@ -417,9 +418,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
     }

+    // Allocate predicate tensors for N
+    Tensor tMaskpMask = make_tensor<bool>(make_shape(size<2>(tMasksMask)));
+    Tensor tBiaspBias = make_tensor<bool>(make_shape(size<2>(tBiassBias)));
+
+    // Set predicates for n bounds
+    if (!Is_even_MN) {
+        #pragma unroll
+        for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); }
+        #pragma unroll
+        for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); }
+    }
+

     // Prologue

+    bool any_active = true; // to be updated later for current iteration
+    bool any_active_next = true; // to be updated later for next iteration
+
     // We'll advance gdQ, gdQaccum and gdBias before the 1st read/write.
     tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride;
     tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;
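The important detail in the new predicate loops is the clamp std::max(0, binfo.actual_seqlen_k - n_block * kBlockN): for irregular sequence lengths the residual column count of the last N-block can be smaller than kBlockN, and for a block that lies entirely past the end of the sequence it would be negative. The standalone host-side snippet below, with made-up sizes, only illustrates how that clamped limit drives the per-column predicate; it is not kernel code.

// Standalone illustration with hypothetical sizes, not kernel code.
#include <algorithm>
#include <cstdio>

int main() {
    const int kBlockN = 64;
    const int actual_seqlen_k = 100;  // irregular: not a multiple of kBlockN
    for (int n_block = 0; n_block < 3; ++n_block) {
        // Same clamp as the predicate loops above: never let the limit go negative.
        const int limit = std::max(0, actual_seqlen_k - n_block * kBlockN);
        int active = 0;
        for (int n = 0; n < kBlockN; ++n) { active += (n < limit) ? 1 : 0; }
        std::printf("n_block %d: %d of %d columns in bounds\n", n_block, active, kBlockN);
    }
    // Prints 64, 36, 0. The clamp keeps the limit non-negative, so the predicate
    // (and anything that later treats the limit as an element count) never sees
    // a negative value for blocks that are fully out of range.
    return 0;
}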
@@ -554,24 +570,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK);
     // // if (cute::thread(1, 0)) { print(tKrK); }

-    FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
-        gmem_tiled_copy_MaskBias,
-        tMaskgMask, tMasksMask,
-        tMaskcMask,
-        binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
-    );
+    // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
+    //     gmem_tiled_copy_Mask,
+    //     tMaskgMask, tMasksMask,
+    //     tMaskcMask, tMaskpMask,
+    //     binfo.actual_seqlen_q - m_block * kBlockM
+    // );
     // cute::cp_async_fence();
     // FLASH_NAMESPACE::cp_async_wait<0>();
-    __syncthreads();
+    // // Do OR-reduce on the mask to see if any active threads

-    // Do OR-reduce on the mask to see if any active threads
-    Tensor tSsMask_copy_view = smem_thr_copy_PdS.retile_S(tSsMask);
-    bool any_active_local = false;
-    bool any_active_local_next = false; // to be updated later for next iteration
-    #pragma unroll
-    for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); }
-    bool any_active = __syncthreads_or(any_active_local);
-    bool any_active_next = false; // to be updated later for next iteration
+    FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
+        gmem_tiled_copy_Mask,
+        tMaskgMask, tMasksMask,
+        any_active,
+        tMaskcMask, tMaskpMask,
+        binfo.actual_seqlen_q - m_block * kBlockM
+    );
+    // We don't need to syncthreads here because copy_mask is already or_syncthreads.

     FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
         gmem_tiled_copy_QKV,
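copy_mask_with_or_reduce replaces the old sequence of copy-mask-to-shared-memory, __syncthreads(), per-thread OR over the smem fragment, __syncthreads_or(): the block-wide "any position active" flag now comes out of the same barrier that publishes the shared-memory tile. The device function below is a simplified sketch of that behaviour, using flat arrays instead of CuTe tensors and hypothetical parameter names; it is not the actual FLASH_NAMESPACE helper.

// Simplified sketch of the assumed behaviour, not the real FLASH_NAMESPACE helper.
template <typename Element>
__device__ void copy_mask_with_or_reduce_sketch(
        const bool* gmem_mask,   // tile of the attention mask in global memory
        Element* smem_mask,      // destination tile in shared memory
        int tile_elems,          // total elements in the tile
        int valid_elems,         // in-bounds elements (handles irregular seqlen)
        bool& any_active) {      // block-wide OR over all mask values
    bool local_any = false;
    for (int i = threadIdx.x; i < tile_elems; i += blockDim.x) {
        const bool m = (i < valid_elems) ? gmem_mask[i] : false;  // clear OOB positions
        smem_mask[i] = m ? Element(1) : Element(0);               // bool -> Element
        local_any |= m;
    }
    // The barrier that makes smem_mask visible also reduces the flag,
    // so the caller does not need another __syncthreads().
    any_active = __syncthreads_or(local_any);
}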
@@ -581,12 +597,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     );

     if (any_active) {
-        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
-            gmem_tiled_copy_MaskBias,
+        FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
+            gmem_tiled_copy_Bias,
             tBiasgBias, tBiassBias,
-            tBiascBias,
-            binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
+            tBiascBias, tBiaspBias,
+            binfo.actual_seqlen_q - m_block * kBlockM
         );
+        // Because copy_bias currently uses scalar loads, we need to sync here.
+        // TODO: Remove sync after fixing to vectorized loads.
+        __syncthreads();
     }

     if (!Kernel_traits::Is_V_in_regs) {
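The bias tile is now loaded only when the OR-reduce found at least one active position in the current (M, N) tile, and it goes through a predicated copy_bias that takes the precomputed column predicate plus the remaining query-row count instead of a raw seqlen_k bound. Because the current helper issues plain scalar loads rather than cp.async, the caller must __syncthreads() before the shared-memory tile is consumed, which is what the added barrier above does. Below is a simplified sketch of such a predicated scalar copy with a hypothetical signature; it is not the actual FLASH_NAMESPACE::copy_bias.

// Simplified sketch with a hypothetical signature, not the real copy_bias.
template <typename Element, int kBlockM, int kBlockN>
__device__ void copy_bias_sketch(
        const Element* gmem_bias, int bias_row_stride,
        Element* smem_bias,                 // kBlockM x kBlockN tile, row-major
        const bool* col_in_bounds,          // per-column predicate (tBiaspBias in the diff)
        int max_row,                        // binfo.actual_seqlen_q - m_block * kBlockM
        bool clear_oob) {
    for (int idx = threadIdx.x; idx < kBlockM * kBlockN; idx += blockDim.x) {
        const int row = idx / kBlockN;
        const int col = idx % kBlockN;
        if (row < max_row && col_in_bounds[col]) {
            smem_bias[idx] = gmem_bias[row * bias_row_stride + col];  // scalar load/store
        } else if (clear_oob) {
            smem_bias[idx] = Element(0);                              // zero-fill OOB
        }
    }
    // No internal barrier: the caller synchronizes, as in the __syncthreads() above.
}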
@@ -780,13 +799,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
     __syncthreads();
     // Write dS to dBias
-    FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
-        gmem_tiled_copy_MaskBias,
+    FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/false>(
+        gmem_tiled_copy_Bias,
         tBiassBias, tdBiasgdBias,
-        tBiascBias,
-        binfo.actual_seqlen_q - m_block * kBlockM,
-        binfo.actual_seqlen_k - n_block * kBlockN
+        tBiascBias, tBiaspBias,
+        binfo.actual_seqlen_q - m_block * kBlockM
     );
+    // Because copy_bias currently uses scalar loads, we need to sync here.
+    // TODO: Remove sync after fixing to vectorized loads.
+    __syncthreads();

     // if (cute::thread0()) { print(tPrP); }
     // Layout p_l = tPrP.layout();
@@ -810,21 +831,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (m_block > m_block_min) {
         // Advance gMask
         tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride));
-        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
-            gmem_tiled_copy_MaskBias,
-            tMaskgMask, tMasksMask,
-            tMaskcMask,
-            binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
-        );
+        // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
+        //     gmem_tiled_copy_Mask,
+        //     tMaskgMask, tMasksMask,
+        //     tMaskcMask, tMaskpMask,
+        //     binfo.actual_seqlen_q - (m_block - 1) * kBlockM
+        // );
         // FLASH_NAMESPACE::cp_async_fence();
         // FLASH_NAMESPACE::cp_async_wait<0>();
-        __syncthreads();
+        // // Do OR-reduce on the mask to see if any active threads for next iteration

-        // Do OR-reduce on the mask to see if any active threads for next iteration
-        any_active_local_next = false;
-        #pragma unroll
-        for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local_next |= (tSsMask_copy_view(i) != Element(0)); }
-        any_active_next = __syncthreads_or(any_active_local_next);
+        FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
+            gmem_tiled_copy_Mask,
+            tMaskgMask, tMasksMask,
+            any_active_next,
+            tMaskcMask, tMaskpMask,
+            binfo.actual_seqlen_q - (m_block - 1) * kBlockM
+        );
+        // We don't need to syncthreads here because copy_mask is already or_syncthreads.

         // Advance gdO
         tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride));
@@ -926,12 +950,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
         tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
         if (any_active_next) {
-            FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
-                gmem_tiled_copy_MaskBias,
+            FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
+                gmem_tiled_copy_Bias,
                 tBiasgBias, tBiassBias,
-                tBiascBias,
-                binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
+                tBiascBias, tBiaspBias,
+                binfo.actual_seqlen_q - (m_block - 1) * kBlockM
             );
+            // Because copy_bias currently uses scalar loads, we need to sync here.
+            // TODO: Remove sync after fixing to vectorized loads.
+            __syncthreads();
         }
     }