diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h index 2ba7d2a..dc93b76 100644 --- a/csrc/flash_dmattn/src/flash_bwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_bwd_kernel.h @@ -272,9 +272,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Global to Shared Memory operation typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyMaskBias gmem_tiled_copy_MaskBias; - auto gmem_thr_copy_Mask = gmem_tiled_copy_MaskBias.get_thread_slice(tidx); - auto gmem_thr_copy_Bias = gmem_tiled_copy_MaskBias.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask; + typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias; + auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx); + auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx); using GmemTiledCopydO = std::conditional_t; GmemTiledCopydO gmem_tiled_copy_dO; auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx); @@ -417,9 +418,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } } + // Allocate predicate tensors for N + Tensor tMaskpMask = make_tensor(make_shape(size<2>(tMasksMask))); + Tensor tBiaspBias = make_tensor(make_shape(size<2>(tBiassBias))); + + // Set predicates for n bounds + if (!Is_even_MN) { + #pragma unroll + for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } + #pragma unroll + for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } + } + // Prologue + bool any_active = true; // to be updated later for current iteration + bool any_active_next = true; // to be updated later for next iteration + // We'll advance gdQ, gdQaccum and gdBias before the 1st read/write. 
tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride; tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded; @@ -554,24 +570,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK); // // if (cute::thread(1, 0)) { print(tKrK); } - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask, tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN - ); + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); // cute::cp_async_fence(); // FLASH_NAMESPACE::cp_async_wait<0>(); - __syncthreads(); + // // Do OR-reduce on the mask to see if any active threads - // Do OR-reduce on the mask to see if any active threads - Tensor tSsMask_copy_view = smem_thr_copy_PdS.retile_S(tSsMask); - bool any_active_local = false; - bool any_active_local_next = false; // to be updated later for next iteration - #pragma unroll - for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); } - bool any_active = __syncthreads_or(any_active_local); - bool any_active_next = false; // to be updated later for next iteration + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, @@ -581,12 +597,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in ); if (any_active) { - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); } if (!Kernel_traits::Is_V_in_regs) { @@ -780,13 +799,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); __syncthreads(); // Write dS to dBias - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, tBiassBias, tdBiasgdBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, - binfo.actual_seqlen_k - n_block * kBlockN + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. 
+ __syncthreads(); // if (cute::thread0()) { print(tPrP); } // Layout p_l = tPrP.layout(); @@ -810,21 +831,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in if (m_block > m_block_min) { // Advance gMask tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride)); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask, tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN - ); + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - (m_block - 1) * kBlockM + // ); // FLASH_NAMESPACE::cp_async_fence(); // FLASH_NAMESPACE::cp_async_wait<0>(); - __syncthreads(); + // // Do OR-reduce on the mask to see if any active threads for next iteration - // Do OR-reduce on the mask to see if any active threads for next iteration - any_active_local_next = false; - #pragma unroll - for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local_next |= (tSsMask_copy_view(i) != Element(0)); } - any_active_next = __syncthreads_or(any_active_local_next); + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - (m_block - 1) * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. // Advance gdO tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride)); @@ -926,12 +950,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride)); tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride)); if (any_active_next) { - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - (m_block - 1) * kBlockM ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. 
+ __syncthreads(); } } diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index f79e477..5b701bf 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -230,8 +230,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Global to Shared Memory operation typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyMaskBias gmem_tiled_copy_MaskBias; - auto gmem_thr_copy_MaskBias = gmem_tiled_copy_MaskBias.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask; + auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias; + auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx); Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); @@ -239,10 +241,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - Tensor tMaskgMask = gmem_thr_copy_MaskBias.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N, nblocksN) - Tensor tMasksMask = gmem_thr_copy_MaskBias.partition_D(sMask); - Tensor tBiasgBias = gmem_thr_copy_MaskBias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N, nblocksN) - Tensor tBiassBias = gmem_thr_copy_MaskBias.partition_D(sBias); + Tensor tMaskgMask = gmem_thr_copy_Mask.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N, nblocksN) + Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask); + Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N, nblocksN) + Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias); // Matrix Multiply Accumulate typename Kernel_traits::TiledMma tiled_mma; @@ -300,8 +302,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Repeat the partitioning with identity layouts Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k) - Tensor tMaskcMask = gmem_thr_copy_MaskBias.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) - Tensor tBiascBias = gmem_thr_copy_MaskBias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) + Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) + Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); @@ -310,18 +312,32 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Set predicates for k bounds if (!Is_even_K) { #pragma unroll - for (int k = 0; k < size(tQpQ); ++k) { - tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; - } + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } #pragma unroll - for (int k = 0; k < size(tKVpKV); ++k) { - tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; - } + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Reverse iteration over N 
blocks + int n_block = n_block_max - 1; + + // Allocate predicate tensors for n + Tensor tMaskpMask = make_tensor(make_shape(size<2>(tMasksMask))); + Tensor tBiaspBias = make_tensor(make_shape(size<2>(tBiassBias))); + + // Set predicates for n bounds + if (!Is_even_MN) { + #pragma unroll + for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } + #pragma unroll + for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } } // Prologue + bool any_active = true; // to be updated later for current iteration + bool any_active_next = true; // to be updated later for next iteration + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, @@ -344,27 +360,25 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); __syncthreads(); } - // Reverse iteration over N blocks - int n_block = n_block_max - 1; - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask(_, _, _, n_block), tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN - ); + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask(_, _, _, n_block), tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); // cute::cp_async_fence(); // FLASH_NAMESPACE::cp_async_wait<0>(); - __syncthreads(); + // // Do OR-reduce on the mask to see if any active threads - // Do OR-reduce on the mask to see if any active threads - Tensor tSsMask_copy_view = smem_thr_copy_Mask.retile_S(tSsMask); - bool any_active_local = false; - bool any_active_local_next = false; // to be updated later for next iteration - #pragma unroll - for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); } - bool any_active = __syncthreads_or(any_active_local); - bool any_active_next = false; // to be updated later for next iteration + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask(_, _, _, n_block), tMasksMask, + any_active, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. if (any_active) { @@ -374,12 +388,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, tBiasgBias(_, _, _, n_block), tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. 
+ __syncthreads(); cute::cp_async_fence(); } // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } @@ -472,21 +489,24 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } if (n_block > n_block_min) { - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask(_, _, _, n_block - 1), tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN - ); + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask(_, _, _, n_block - 1), tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); // cute::cp_async_fence(); // FLASH_NAMESPACE::cp_async_wait<0>(); - __syncthreads(); + // // Do OR-reduce on the mask to see if any active threads for next iteration. - // Do OR-reduce on the mask to see if any active threads for next iteration - any_active_local_next = false; - #pragma unroll - for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local_next |= (tSsMask_copy_view(i) != Element(0)); } - any_active_next = __syncthreads_or(any_active_local_next); + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask(_, _, _, n_block - 1), tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. if (any_active_next) { FLASH_NAMESPACE::copy( @@ -494,12 +514,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV ); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, tBiasgBias(_, _, _, n_block - 1), tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -595,34 +618,40 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } if (n_block > n_block_min) { - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask(_, _, _, n_block - 1), tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN - ); + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask(_, _, _, n_block - 1), tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); // cute::cp_async_fence(); // FLASH_NAMESPACE::cp_async_wait<0>(); - __syncthreads(); - - // Do OR-reduce on the mask to see if any active threads for next iteration - any_active_local_next = false; - #pragma unroll - for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local_next |= (tSsMask_copy_view(i) != Element(0)); } - any_active_next = __syncthreads_or(any_active_local_next); + // // Do OR-reduce on the mask to see if any active threads for next iteration. 
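The __syncthreads() that now follows each copy_bias call (and the TODO about switching to vectorized loads) comes down to how the two staging paths are tracked: cp.async transfers are ordered by cp_async_fence()/cp_async_wait, whereas the plain element-wise stores copy_bias currently issues are not tracked by anything, so a block-wide barrier is needed before other threads may read the bias tile from shared memory. Below is a minimal standalone sketch of that distinction using the CUDA pipeline primitives rather than the CuTe copy machinery in this diff; all names and sizes are illustrative only.

```cuda
// Hedged sketch: contrast a tracked async gmem->smem copy with untracked
// scalar staging, which is the situation the extra __syncthreads() handles.
#include <cuda_pipeline.h>

__global__ void stage_two_tiles(const float* __restrict__ gA,
                                const float* __restrict__ gB,
                                float* out, int n) {
    extern __shared__ float smem[];   // first n floats: A tile, next n floats: B tile
    float* sA = smem;
    float* sB = smem + n;

    // Path 1: asynchronous copies, tracked by the pipeline (cp.async-style).
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        __pipeline_memcpy_async(&sA[i], &gA[i], sizeof(float));
    }
    __pipeline_commit();

    // Path 2: scalar gmem -> register -> smem stores, analogous to the current
    // copy_bias; nothing tracks their completion across the block.
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        sB[i] = gB[i];
    }

    __pipeline_wait_prior(0);  // this thread's async copies have landed in smem
    __syncthreads();           // makes all threads' stores (both paths) visible
                               // before any thread reads the shared tiles

    float acc = 0.0f;
    for (int i = threadIdx.x; i < n; i += blockDim.x) { acc += sA[i] * sB[i]; }
    atomicAdd(out, acc);
}
```

Once the bias copy goes through a tracked, vectorized path, that extra barrier can presumably be dropped, which is what the TODO notes.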
+ FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask(_, _, _, n_block - 1), tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads + if (any_active_next) { FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV ); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, tBiasgBias(_, _, _, n_block - 1), tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -920,8 +949,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Global to Shared Memory operation typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyMaskBias gmem_tiled_copy_MaskBias; - auto gmem_thr_copy_MaskBias = gmem_tiled_copy_MaskBias.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask; + auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias; + auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx); + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); @@ -929,10 +961,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - Tensor tMaskgMask = gmem_thr_copy_MaskBias.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) - Tensor tMasksMask = gmem_thr_copy_MaskBias.partition_D(sMask); - Tensor tBiasgBias = gmem_thr_copy_MaskBias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) - Tensor tBiassBias = gmem_thr_copy_MaskBias.partition_D(sBias); + Tensor tMaskgMask = gmem_thr_copy_Mask.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) + Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask); + Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) + Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias); // Matrix Multiply Accumulate typename Kernel_traits::TiledMma tiled_mma; @@ -972,26 +1004,40 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Repeat the partitioning with identity layouts Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k) - Tensor tMaskcMask = gmem_thr_copy_MaskBias.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) - Tensor tBiascBias = gmem_thr_copy_MaskBias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) + Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) + Tensor tBiascBias = 
gmem_thr_copy_Bias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) + // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + // Set predicates for k bounds if (!Is_even_K) { #pragma unroll - for (int k = 0; k < size(tQpQ); ++k) { - tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; - } + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } #pragma unroll - for (int k = 0; k < size(tKVpKV); ++k) { - tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; - } + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } } + int n_block = n_block_max - 1; + + // Allocate predicate tensors for m and n bounds + Tensor tMaskpMask = make_tensor(make_shape(size<2>(tMasksMask))); + Tensor tBiaspBias = make_tensor(make_shape(size<2>(tBiassBias))); + + // Set predicates for n bounds + if (!Is_even_MN) { + #pragma unroll + for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < binfo.actual_seqlen_k - n_block * kBlockN; } + #pragma unroll + for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < binfo.actual_seqlen_k - n_block * kBlockN; } + } // Prologue + bool any_active = true; // to be updated later for current iteration + bool any_active_next = true; // Load the first Q, K, V, Mask, Bias tiles + // Read Q from gmem to smem // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs FLASH_NAMESPACE::copy( @@ -1001,26 +1047,24 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons binfo.actual_seqlen_q - m_block * kBlockM ); - int n_block = n_block_max - 1; - - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask, tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN - ); + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); // cute::cp_async_fence(); // FLASH_NAMESPACE::cp_async_wait<0>(); - __syncthreads(); + // // Do OR-reduce on the mask to see if any active threads - // Do OR-reduce on the mask to see if any active threads for next iteration - Tensor tSsMask_copy_view = smem_thr_copy_Mask.retile_S(tSsMask); - bool any_active_local = false; - bool any_active_local_next = false; // to be updated later for next iteration - #pragma unroll - for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); } - bool any_active = __syncthreads_or(any_active_local); - bool any_active_next = false; // to be updated later for next iteration + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. 
if (any_active) { @@ -1030,12 +1074,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); cute::cp_async_fence(); } @@ -1150,21 +1197,24 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur); tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur); } - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask, tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN - ); + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); // cute::cp_async_fence(); // FLASH_NAMESPACE::cp_async_wait<0>(); - __syncthreads(); - - // Do OR-reduce on the mask to see if any active threads for next iteration - any_active_local_next = false; - #pragma unroll - for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local_next |= (tSsMask_copy_view(i) != Element(0)); } - any_active_next = __syncthreads_or(any_active_local_next); + // // Do OR-reduce on the mask to see if any active threads for next iteration. + + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. if (any_active_next) { FLASH_NAMESPACE::copy( @@ -1172,12 +1222,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKgK, tKsK, tKVcKV, tKVpKV ); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
cute::cp_async_fence(); @@ -1291,21 +1344,24 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur); tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur); } - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask, tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN - ); + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); // cute::cp_async_fence(); // FLASH_NAMESPACE::cp_async_wait<0>(); - __syncthreads(); + // // Do OR-reduce on the mask to see if any active threads for next iteration. - // Do OR-reduce on the mask to see if any active threads for next iteration - any_active_local_next = false; - #pragma unroll - for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local_next |= (tSsMask_copy_view(i) != Element(0)); } - any_active_next = __syncthreads_or(any_active_local_next); + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. if (any_active_next) { FLASH_NAMESPACE::copy( @@ -1313,12 +1369,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKgK, tKsK, tKVcKV, tKVpKV ); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
cute::cp_async_fence(); diff --git a/csrc/flash_dmattn/src/kernel_traits.h b/csrc/flash_dmattn/src/kernel_traits.h index a574675..8f648e1 100644 --- a/csrc/flash_dmattn/src/kernel_traits.h +++ b/csrc/flash_dmattn/src/kernel_traits.h @@ -183,7 +183,14 @@ struct Flash_fwd_kernel_traits : public Base { Layout>{} ) ); // Val layout, 8 vals per read - using GmemTiledCopyMaskBias = decltype( + using GmemTiledCopyMask = decltype( + make_tiled_copy( + Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>{} + ) + ); // Val layout, 8 vals per read + using GmemTiledCopyBias = decltype( make_tiled_copy( Copy_Atom, Element>{}, GmemLayoutAtom{}, @@ -442,7 +449,14 @@ struct Flash_bwd_kernel_traits : public Base { Layout>{} ) ); // Val layout, 8 vals per store - using GmemTiledCopyMaskBias = decltype( + using GmemTiledCopyMask = decltype( + make_tiled_copy( + Copy_Atom, elem_type>{}, + GmemLayoutAtom{}, + Layout>{} + ) + ); // Val layout, 8 vals per read + using GmemTiledCopyBias = decltype( make_tiled_copy( Copy_Atom, elem_type>{}, GmemLayoutAtom{}, diff --git a/csrc/flash_dmattn/src/utils.h b/csrc/flash_dmattn/src/utils.h index be28c10..52fb7a0 100644 --- a/csrc/flash_dmattn/src/utils.h +++ b/csrc/flash_dmattn/src/utils.h @@ -333,6 +333,25 @@ __forceinline__ __device__ void sparse_gemm_rs( } } +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template +__forceinline__ __device__ void mask_or_reduce( + Tensor &tSsMask, + bool &active, + ThrCopy smem_thr_copy_Mask +) { + Tensor tSsMask_copy_view = smem_thr_copy_Mask.retile_D(tSsMask); + bool active_local = false; + #pragma unroll + for (int i = 0; i < size(tSsMask_copy_view); ++i) { + active_local |= tSsMask_copy_view(i); + } + active = __syncthreads_or(active_local); +} + + //////////////////////////////////////////////////////////////////////////////////////////////////// // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) @@ -521,15 +540,135 @@ __forceinline__ __device__ void copy( //////////////////////////////////////////////////////////////////////////////////////////////////// template < - bool Is_even_MN=true, bool Clear_OOB_MN=true, bool Bool_to_Element=false, typename To_type=void, - typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, - typename Engine2, typename Layout2 + bool Is_even_MN=true, bool Clear_OOB_MN=false, + typename TiledCopy, + typename Engine0, typename Layout0, typename Engine1, typename Layout1, + typename Engine2, typename Layout2, typename Engine3, typename Layout3 > __forceinline__ __device__ void copy_MN( TiledCopy tiled_copy, Tensor const &S, Tensor &D, - Tensor const &identity_MN, - const int max_M=0, const int max_N=0 + Tensor const &identity_MN, Tensor const &predicate_N, + const int max_M=0 +) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); // (MMA, MMA_M, MMA_N) + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); // (MMA, MMA_M, MMA_N) + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_N + + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) { + #pragma unroll + for (int n = 0; n < size<2>(S); ++n) { + if (Is_even_MN || predicate_N(n)) { + cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, n)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } +} 
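The mask_or_reduce helper above and the copy_mask_with_or_reduce helper defined further below fold the per-block occupancy check into the copy itself: each thread OR-reduces the mask fragment it handles, and __syncthreads_or() then acts as both a barrier and a block-wide OR, so every thread receives the same any-active answer without a separate __syncthreads(). A standalone sketch of that pattern (illustrative only, not the repo's helpers):

```cuda
// Hedged sketch of the block-wide OR-reduce behind mask_or_reduce /
// copy_mask_with_or_reduce: __syncthreads_or() is a barrier plus reduction.
__global__ void tile_any_active(const float* __restrict__ mask_tile,
                                int tile_elems, int* any_active_out) {
    bool active_local = false;
    // Each thread ORs over the mask elements it is responsible for,
    // e.g. while staging them from global to shared memory.
    for (int i = threadIdx.x; i < tile_elems; i += blockDim.x) {
        active_local |= (mask_tile[i] != 0.0f);
    }
    // Barrier + OR in one step: nonzero iff any thread passed a nonzero value.
    bool any_active = __syncthreads_or(active_local);

    if (!any_active) {
        // Every thread agrees the tile is fully masked, so the whole block can
        // skip the K/V loads and GEMMs for it (the any_active fast path).
        if (threadIdx.x == 0) { *any_active_out = 0; }
        return;
    }
    if (threadIdx.x == 0) { *any_active_out = 1; }
    // ... proceed with the work for an active tile ...
}
```

Because the vote is taken during the gmem-to-smem mask copy, the kernels above can drop the old load-then-rescan sequence (copy to smem, __syncthreads(), re-read the smem tile, __syncthreads_or) and still obtain coherent any_active / any_active_next flags for the current and prefetched iterations.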
+ +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + bool Is_even_MN=true, bool Clear_OOB_MN=false, + typename TiledCopy, + typename Engine0, typename Layout0, typename Engine1, typename Layout1, + typename Engine2, typename Layout2, typename Engine3, typename Layout3 +> +__forceinline__ __device__ void copy_mask( + TiledCopy tiled_copy, + Tensor const &S, Tensor &D, + Tensor const &identity_MN, Tensor const &predicate_N, + const int max_M=0 +) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); // (MMA, MMA_M, MMA_N) + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); // (MMA, MMA_M, MMA_N) + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_N + + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) { + #pragma unroll + for (int n = 0; n < size<2>(S); ++n) { + if (Is_even_MN || predicate_N(n)) { + cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, n)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + bool Is_even_MN=true, bool Clear_OOB_MN=false, typename To_type=void, + typename TiledCopy, + typename Engine0, typename Layout0, typename Engine1, typename Layout1, + typename Engine2, typename Layout2, typename Engine3, typename Layout3 +> +__forceinline__ __device__ void copy_mask_with_or_reduce( + TiledCopy tiled_copy, + Tensor const &S, Tensor &D, + bool &block_active, + Tensor const &identity_MN, Tensor const &predicate_N, + const int max_M=0 +) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); // (MMA, MMA_M, MMA_N) + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); // (MMA, MMA_M, MMA_N) + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_N + + bool any_active = false; + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) { + #pragma unroll + for (int n = 0; n < size<2>(S); ++n) { + if (Is_even_MN || predicate_N(n)) { + #pragma unroll + for (int i = 0; i < size<0>(S); ++i) { + any_active |= S(i, m, n); + D(i, m, n) = static_cast(S(i, m, n)); + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, n)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } + + block_active = __syncthreads_or(any_active); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + bool Is_even_MN=true, bool Clear_OOB_MN=false, + typename TiledCopy, + typename Engine0, typename Layout0, typename Engine1, typename Layout1, + typename Engine2, typename Layout2, typename Engine3, typename Layout3 +> +__forceinline__ __device__ void copy_bias( + TiledCopy tiled_copy, + Tensor const &S, Tensor &D, + Tensor const &identity_MN, Tensor const &predicate_N, + const int max_M=0 ) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); // (MMA, MMA_M, MMA_N) CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); // (MMA, MMA_M, MMA_N) @@ -542,14 +681,11 @@ __forceinline__ __device__ void copy_MN( if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) { #pragma unroll for (int n = 0; n < size<2>(S); ++n) { - if (Is_even_MN || get<1>(identity_MN(0, m, n)) < max_N) { - 
if constexpr (Bool_to_Element) { - #pragma unroll - for (int i = 0; i < size<0>(S); ++i) { - D(i, m, n) = static_cast(S(i, m, n)) ? To_type(1) : To_type(0); - } - } else { - cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); + if (Is_even_MN || predicate_N(n)) { + // cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); + #pragma unroll + for (int i = 0; i < size<0>(S); ++i) { + D(i, m, n) = S(i, m, n); } } else if (Clear_OOB_MN) { cute::clear(D(_, m, n)); diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 733b6e1..9fa64d2 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -14,7 +14,7 @@ def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: def _sanitize_tensors(*tensors: Optional[torch.Tensor], nan: float = 0.0, posinf: float = 1e6, neginf: float = -1e6) -> None: for t in tensors: if t is not None and isinstance(t, torch.Tensor): - torch.nan_to_num(t, nan=nan, posinf=posinf, neginf=neginf, out=t) + torch.nan_to_num_(t, nan=nan, posinf=posinf, neginf=neginf) def _get_block_size_n(device, head_dim, is_causal): @@ -95,7 +95,7 @@ def _flash_dmattn_forward( softcap, return_softmax, ) - _sanitize_tensors(out) + _sanitize_tensors(out, nan=0.0, posinf=torch.finfo(out.dtype).max, neginf=torch.finfo(out.dtype).min) return out, softmax_lse, S_dmask @@ -170,7 +170,7 @@ def _flash_dmattn_backward( softcap, deterministic, ) - _sanitize_tensors(dq, dk, dv, dbias) + _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=torch.finfo(dq.dtype).max, neginf=torch.finfo(dq.dtype).min) return softmax_d @@ -227,8 +227,6 @@ def forward( return_softmax: Optional[bool], is_grad_enabled: bool = True, ): - # q, k, v are expected to be of shape (batch_size, seqlen, num_heads, head_size) - seqlen_k = k.shape[1] is_grad = is_grad_enabled and any( x.requires_grad for x in [q, k, v] ) @@ -243,19 +241,20 @@ def forward( if return_softmax is None: return_softmax = False + # Padding to multiple of 8 for 16-bit memory allocations head_size_og = q.size(3) if head_size_og % 8 != 0: q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - if seqlen_k % 128 != 0: - k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 128 - seqlen_k % 128]) - v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 128 - seqlen_k % 128]) - if mask is not None: - mask = torch.nn.functional.pad(mask, [0, 128 - seqlen_k % 128], value=False) - if bias is not None: - bias = torch.nn.functional.pad(bias, [0, 128 - seqlen_k % 128], value=torch.finfo(bias.dtype).min) + # seqlen_k_og = k.shape[1] + # if seqlen_k_og % 8 != 0: + # k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8]) + # v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8]) + # if mask is not None: + # mask = torch.nn.functional.pad(mask, [0, 8 - seqlen_k_og % 8], value=False) + # if bias is not None: + # bias = torch.nn.functional.pad(bias, [0, 8 - seqlen_k_og % 8], value=0.0) out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward( q, @@ -271,11 +270,11 @@ def forward( if is_grad: ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse) - ctx.seqlen_k = seqlen_k ctx.softmax_scale = softmax_scale ctx.is_causal = is_causal ctx.softcap = softcap ctx.deterministic = deterministic + # ctx.seqlen_k_og = seqlen_k_og out = out_padded[..., :head_size_og] @@ -288,7 +287,7 @@ def backward( *args: Any, ): q, k, v, mask, bias, out, softmax_lse 
= ctx.saved_tensors - dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias) + dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias) head_size_og = dout.size(3) dout_padded = dout @@ -318,10 +317,10 @@ def backward( dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - if ctx.seqlen_k % 128 != 0: - dk = dk[:, : ctx.seqlen_k, :, :] - dv = dv[:, : ctx.seqlen_k, :, :] - dbias = dbias[..., : ctx.seqlen_k] + # if ctx.seqlen_k_og % 8 != 0: + # dk = dk[:, : ctx.seqlen_k_og, :, :] + # dv = dv[:, : ctx.seqlen_k_og, :, :] + # dbias = dbias[..., : ctx.seqlen_k_og] return dq, dk, dv, None, dbias, None, None, None, None, None, None @@ -339,11 +338,16 @@ def flash_dmattn_func( return_attn_probs: Optional[bool] = None, ): """ - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + Supports multi-query attention and grouped-query attention (MQA/GQA) by passing in KV with fewer heads than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + Similarly, it also supports attn_mask and attn_bias with a head dimension of 1, nheads_k, or nheads for MQA/GQA. + For example, if Q has 6 heads, K, V have 2 heads, then attn_mask and attn_bias can have head dimension + of 1, 2 or 6. If it is 1, all heads use the same mask/bias; if it is 2, head 0, 1, 2 of Q use head 0 + of mask/bias, head 3, 4, 5 of Q use head 1 of mask/bias. If it is 6, each head uses its own mask/bias. + If is_causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: 1 1 1 1 0
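A usage-level sketch of the head-broadcast behavior described in the updated docstring. This is illustrative only: the keyword names follow the docstring (attn_mask, attn_bias, is_causal), the mask/bias layout is assumed to be (batch, heads, seqlen_q, seqlen_k) with Q/K/V in (batch, seqlen, heads, head_dim), and the import path is assumed; check the actual interface before relying on these details.

```python
import torch
from flash_dmattn import flash_dmattn_func  # import path assumed for illustration

batch, seqlen_q, seqlen_k = 2, 128, 256
nheads, nheads_k, head_dim = 6, 2, 64
device, dtype = "cuda", torch.bfloat16

q = torch.randn(batch, seqlen_q, nheads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen_k, nheads_k, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen_k, nheads_k, head_dim, device=device, dtype=dtype)

# The mask/bias head dimension may be 1 (shared by all heads), nheads_k (one
# per KV head, shared by the Q heads in that group), or nheads (one per Q
# head). Here we use the per-KV-head variant; shapes and dtypes are assumptions.
attn_mask = torch.ones(batch, nheads_k, seqlen_q, seqlen_k, device=device, dtype=torch.bool)
attn_bias = torch.zeros(batch, nheads_k, seqlen_q, seqlen_k, device=device, dtype=dtype)

out = flash_dmattn_func(q, k, v, attn_mask=attn_mask, attn_bias=attn_bias, is_causal=True)
print(out.shape)  # expected: (batch, seqlen_q, nheads, head_dim)
```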