
Commit 3b7b57b

Separates mask and bias memory operations for performance
Splits the unified GmemTiledCopyMaskBias into separate GmemTiledCopyMask and GmemTiledCopyBias operations to enable independent optimization of memory access patterns. Introduces specialized copy_mask_with_or_reduce and copy_bias functions that replace generic copy_MN calls, allowing for better memory coalescing and reduced synchronization overhead. Adds predicate tensor allocation for bounds checking on the N dimension to improve memory safety and enable more efficient vectorized operations in future optimizations.
1 parent a0475b2 commit 3b7b57b
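
For reference, the call sites introduced in the diff below imply roughly the following helper signatures. These declarations are inferred from the calls only, not taken from the repository; the actual definitions of copy_mask_with_or_reduce and copy_bias live in the kernel's helper headers and may differ in template parameters, argument types, and naming.

    // Hypothetical declarations, inferred from the call sites in this diff.
    // All parameter names and template types here are illustrative assumptions.
    template <bool Is_even_MN, bool Clear_OOB_MN, typename To_type,
              typename TiledCopy, typename TensorG, typename TensorS,
              typename TensorC, typename TensorP>
    __device__ void copy_mask_with_or_reduce(
        TiledCopy const &gmem_tiled_copy_Mask,
        TensorG const &gMask, TensorS &sMask,        // gmem source tile, smem destination tile
        bool &any_active,                            // out: block-wide OR of the loaded mask tile
        TensorC const &cMask, TensorP const &pMask,  // identity coordinates, N-dimension predicates
        int max_M);                                  // valid rows remaining along M

    template <bool Is_even_MN, bool Clear_OOB_MN,
              typename TiledCopy, typename TensorS, typename TensorD,
              typename TensorC, typename TensorP>
    __device__ void copy_bias(
        TiledCopy const &gmem_tiled_copy_Bias,
        TensorS const &src, TensorD &dst,            // gmem->smem on load, smem->gmem when writing dBias
        TensorC const &cBias, TensorP const &pBias,  // identity coordinates, N-dimension predicates
        int max_M);                                  // valid rows remaining along M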

File tree

1 file changed (+70 −43 lines)

csrc/flash_dmattn/src/flash_bwd_kernel.h

Lines changed: 70 additions & 43 deletions
@@ -272,9 +272,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // Global to Shared Memory operation
     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
-    typename Kernel_traits::GmemTiledCopyMaskBias gmem_tiled_copy_MaskBias;
-    auto gmem_thr_copy_Mask = gmem_tiled_copy_MaskBias.get_thread_slice(tidx);
-    auto gmem_thr_copy_Bias = gmem_tiled_copy_MaskBias.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask;
+    typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias;
+    auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx);
+    auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx);
     using GmemTiledCopydO = std::conditional_t<Is_first, typename Kernel_traits::GmemTiledCopydO, typename Kernel_traits::GmemTiledCopyQKV>;
     GmemTiledCopydO gmem_tiled_copy_dO;
     auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
@@ -417,9 +418,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
     }
 
+    // Allocate predicate tensors for N
+    Tensor tMaskpMask = make_tensor<bool>(make_shape(size<2>(tMasksMask)));
+    Tensor tBiaspBias = make_tensor<bool>(make_shape(size<2>(tBiassBias)));
+
+    // Set predicates for n bounds
+    if (!Is_even_MN) {
+        #pragma unroll
+        for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); }
+        #pragma unroll
+        for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); }
+    }
+
 
     // Prologue
 
+    bool any_active = true; // to be updated later for current iteration
+    bool any_active_next = true; // to be updated later for next iteration
+
     // We'll advance gdQ, gdQaccum and gdBias before the 1st read/write.
     tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride;
     tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;
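
The predicate tensors added above (tMaskpMask, tBiaspBias) mark which columns of the tile are in bounds along N. As a rough illustration of how such a predicate tensor is typically consumed, here is a minimal, hypothetical sketch in the style of the CuTe-based copy helpers used by flash-attention kernels; it is not the repository's copy_bias, and the fragment layout (CPY, CPY_M, CPY_N) and all names are assumptions.

    #include <cute/tensor.hpp>

    // Hypothetical predicated copy: rows are gated by comparing the M coordinate
    // against max_M, columns by the precomputed N-dimension predicate tensor.
    template <bool Is_even_MN, bool Clear_OOB_MN,
              typename TiledCopy, typename TensorS, typename TensorD,
              typename TensorC, typename TensorP>
    __device__ void copy_bias_sketch(TiledCopy const &tiled_copy,
                                     TensorS const &src, TensorD &dst,
                                     TensorC const &coord, TensorP const &pred_N,
                                     int max_M) {
        using cute::_;
        #pragma unroll
        for (int m = 0; m < cute::size<1>(src); ++m) {
            if (Is_even_MN || cute::get<0>(coord(0, m, 0)) < max_M) {   // row in bounds
                #pragma unroll
                for (int n = 0; n < cute::size<2>(src); ++n) {
                    if (Is_even_MN || pred_N(n)) {                      // column in bounds
                        cute::copy(tiled_copy, src(_, m, n), dst(_, m, n));
                    } else if (Clear_OOB_MN) {
                        cute::clear(dst(_, m, n));                      // zero out-of-bounds columns
                    }
                }
            } else if (Clear_OOB_MN) {
                cute::clear(dst(_, m, _));                              // zero whole out-of-bounds row
            }
        }
    }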
@@ -554,24 +570,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK);
     // // if (cute::thread(1, 0)) { print(tKrK); }
 
-    FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
-        gmem_tiled_copy_MaskBias,
-        tMaskgMask, tMasksMask,
-        tMaskcMask,
-        binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
-    );
+    // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
+    //     gmem_tiled_copy_Mask,
+    //     tMaskgMask, tMasksMask,
+    //     tMaskcMask, tMaskpMask,
+    //     binfo.actual_seqlen_q - m_block * kBlockM
+    // );
     // cute::cp_async_fence();
     // FLASH_NAMESPACE::cp_async_wait<0>();
-    __syncthreads();
+    // // Do OR-reduce on the mask to see if any active threads
 
-    // Do OR-reduce on the mask to see if any active threads
-    Tensor tSsMask_copy_view = smem_thr_copy_PdS.retile_S(tSsMask);
-    bool any_active_local = false;
-    bool any_active_local_next = false; // to be updated later for next iteration
-    #pragma unroll
-    for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); }
-    bool any_active = __syncthreads_or(any_active_local);
-    bool any_active_next = false; // to be updated later for next iteration
+    FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
+        gmem_tiled_copy_Mask,
+        tMaskgMask, tMasksMask,
+        any_active,
+        tMaskcMask, tMaskpMask,
+        binfo.actual_seqlen_q - m_block * kBlockM
+    );
+    // We don't need to syncthreads here because copy_mask is already or_syncthreads.
 
     FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
         gmem_tiled_copy_QKV,
@@ -581,12 +597,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     );
 
     if (any_active) {
-        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
-            gmem_tiled_copy_MaskBias,
+        FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
+            gmem_tiled_copy_Bias,
             tBiasgBias, tBiassBias,
-            tBiascBias,
-            binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
+            tBiascBias, tBiaspBias,
+            binfo.actual_seqlen_q - m_block * kBlockM
         );
+        // Because copy_bias currently uses scalar loads, we need to sync here.
+        // TODO: Remove sync after fixing to vectorized loads.
+        __syncthreads();
     }
 
     if (!Kernel_traits::Is_V_in_regs) {
@@ -780,13 +799,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
     __syncthreads();
     // Write dS to dBias
-    FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
-        gmem_tiled_copy_MaskBias,
+    FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/false>(
+        gmem_tiled_copy_Bias,
         tBiassBias, tdBiasgdBias,
-        tBiascBias,
-        binfo.actual_seqlen_q - m_block * kBlockM,
-        binfo.actual_seqlen_k - n_block * kBlockN
+        tBiascBias, tBiaspBias,
+        binfo.actual_seqlen_q - m_block * kBlockM
     );
+    // Because copy_bias currently uses scalar loads, we need to sync here.
+    // TODO: Remove sync after fixing to vectorized loads.
+    __syncthreads();
 
     // if (cute::thread0()) { print(tPrP); }
     // Layout p_l = tPrP.layout();
@@ -810,21 +831,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (m_block > m_block_min) {
         // Advance gMask
         tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride));
-        FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true, /*Bool_to_Element=*/true, Element>(
-            gmem_tiled_copy_MaskBias,
-            tMaskgMask, tMasksMask,
-            tMaskcMask,
-            binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
-        );
+        // FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
+        //     gmem_tiled_copy_Mask,
+        //     tMaskgMask, tMasksMask,
+        //     tMaskcMask, tMaskpMask,
+        //     binfo.actual_seqlen_q - (m_block - 1) * kBlockM
+        // );
         // FLASH_NAMESPACE::cp_async_fence();
         // FLASH_NAMESPACE::cp_async_wait<0>();
-        __syncthreads();
+        // // Do OR-reduce on the mask to see if any active threads for next iteration
 
-        // Do OR-reduce on the mask to see if any active threads for next iteration
-        any_active_local_next = false;
-        #pragma unroll
-        for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local_next |= (tSsMask_copy_view(i) != Element(0)); }
-        any_active_next = __syncthreads_or(any_active_local_next);
+        FLASH_NAMESPACE::copy_mask_with_or_reduce<Is_even_MN, /*Clear_OOB_MN=*/true, /*To_type=*/Element>(
+            gmem_tiled_copy_Mask,
+            tMaskgMask, tMasksMask,
+            any_active_next,
+            tMaskcMask, tMaskpMask,
+            binfo.actual_seqlen_q - (m_block - 1) * kBlockM
+        );
+        // We don't need to syncthreads here because copy_mask is already or_syncthreads.
 
         // Advance gdO
         tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride));
@@ -926,12 +950,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
         tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
         if (any_active_next) {
-            FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
-                gmem_tiled_copy_MaskBias,
+            FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
+                gmem_tiled_copy_Bias,
                 tBiasgBias, tBiassBias,
-                tBiascBias,
-                binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
+                tBiascBias, tBiaspBias,
+                binfo.actual_seqlen_q - (m_block - 1) * kBlockM
             );
+            // Because copy_bias currently uses scalar loads, we need to sync here.
+            // TODO: Remove sync after fixing to vectorized loads.
+            __syncthreads();
         }
     }
 
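One detail worth noting from the diff: the removed prologue code computed the block-wide "any active" flag with an explicit reduction over the shared-memory mask tile followed by __syncthreads_or, while the new copy_mask_with_or_reduce helper is described in the in-code comments as folding that reduction and its barrier into the copy itself. For readers unfamiliar with the primitive, a minimal sketch of the block-wide OR it relies on:

    // Block-wide OR: __syncthreads_or is a barrier that returns non-zero in
    // every thread of the block if any thread passed a non-zero predicate.
    __device__ bool block_any(bool local_flag) {
        return __syncthreads_or(local_flag ? 1 : 0) != 0;
    }

This mirrors the removed line `bool any_active = __syncthreads_or(any_active_local);`, where any_active_local was each thread's OR over its fragment of the mask tile.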