From e69b1c71cd6533a3f8c571b7e5c6162cdefc69d1 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Mon, 15 Sep 2025 21:35:53 +0800 Subject: [PATCH 1/9] Improves numerical stability and initialization Uses in-place nan_to_num_ operation for better memory efficiency. Updates tensor sanitization to use dtype-specific infinity bounds instead of fixed values, preventing potential overflow issues. Changes tensor initialization from empty_like to zeros_like to ensure deterministic starting values for gradients. Fixes bias padding value from minimum float to zero for better numerical behavior. Enhances documentation to clarify support for flexible mask and bias head dimensions in MQA/GQA scenarios. --- flash_dmattn/flash_dmattn_interface.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 733b6e1..1d2ac22 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -14,7 +14,7 @@ def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: def _sanitize_tensors(*tensors: Optional[torch.Tensor], nan: float = 0.0, posinf: float = 1e6, neginf: float = -1e6) -> None: for t in tensors: if t is not None and isinstance(t, torch.Tensor): - torch.nan_to_num(t, nan=nan, posinf=posinf, neginf=neginf, out=t) + torch.nan_to_num_(t, nan=nan, posinf=posinf, neginf=neginf) def _get_block_size_n(device, head_dim, is_causal): @@ -95,7 +95,7 @@ def _flash_dmattn_forward( softcap, return_softmax, ) - _sanitize_tensors(out) + _sanitize_tensors(out, nan=0.0, posinf=torch.finfo(out.dtype).max, neginf=torch.finfo(out.dtype).min) return out, softmax_lse, S_dmask @@ -170,7 +170,7 @@ def _flash_dmattn_backward( softcap, deterministic, ) - _sanitize_tensors(dq, dk, dv, dbias) + _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=torch.finfo(dq.dtype).max, neginf=torch.finfo(dq.dtype).min) return softmax_d @@ -255,7 +255,7 @@ def forward( if mask is not None: mask = torch.nn.functional.pad(mask, [0, 128 - seqlen_k % 128], value=False) if bias is not None: - bias = torch.nn.functional.pad(bias, [0, 128 - seqlen_k % 128], value=torch.finfo(bias.dtype).min) + bias = torch.nn.functional.pad(bias, [0, 128 - seqlen_k % 128], value=0.0) out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward( q, @@ -288,7 +288,7 @@ def backward( *args: Any, ): q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors - dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias) + dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias) head_size_og = dout.size(3) dout_padded = dout @@ -339,11 +339,16 @@ def flash_dmattn_func( return_attn_probs: Optional[bool] = None, ): """ - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + Supports multi-query attention and grouped-query attention (MQA/GQA) by passing in KV with fewer heads than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + Similarity, also supports attn_mask and attn_bias with head dimension of 1, nheads_k or nheads for MQA/GQA. + For example, if Q has 6 heads, K, V have 2 heads, then attn_mask and attn_bias can have head dimension + of 1, 2 or 6. 
If it is 1, all heads use the same mask/bias; if it is 2, head 0, 1, 2 of Q use head 0 + of mask/bias, head 3, 4, 5 of Q use head 1 of mask/bias. If it is 6, each head uses its own mask/bias. + If is_causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: 1 1 1 1 0 From c82f7dc69d622465806dd543c12110b40ddad765 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Tue, 16 Sep 2025 16:32:23 +0800 Subject: [PATCH 2/9] Removes sequence length padding logic Eliminates unnecessary padding of key and value tensors to multiples of 128 in sequence length dimension. Removes associated context saving and gradient unpadding operations that are no longer needed without the sequence length padding. Simplifies the forward and backward pass implementation by removing conditional padding logic for masks and biases. --- flash_dmattn/flash_dmattn_interface.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 1d2ac22..6e32b61 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -227,8 +227,6 @@ def forward( return_softmax: Optional[bool], is_grad_enabled: bool = True, ): - # q, k, v are expected to be of shape (batch_size, seqlen, num_heads, head_size) - seqlen_k = k.shape[1] is_grad = is_grad_enabled and any( x.requires_grad for x in [q, k, v] ) @@ -249,14 +247,6 @@ def forward( k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - if seqlen_k % 128 != 0: - k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 128 - seqlen_k % 128]) - v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 128 - seqlen_k % 128]) - if mask is not None: - mask = torch.nn.functional.pad(mask, [0, 128 - seqlen_k % 128], value=False) - if bias is not None: - bias = torch.nn.functional.pad(bias, [0, 128 - seqlen_k % 128], value=0.0) - out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward( q, k, @@ -271,7 +261,6 @@ def forward( if is_grad: ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse) - ctx.seqlen_k = seqlen_k ctx.softmax_scale = softmax_scale ctx.is_causal = is_causal ctx.softcap = softcap @@ -318,11 +307,6 @@ def backward( dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - if ctx.seqlen_k % 128 != 0: - dk = dk[:, : ctx.seqlen_k, :, :] - dv = dv[:, : ctx.seqlen_k, :, :] - dbias = dbias[..., : ctx.seqlen_k] - return dq, dk, dv, None, dbias, None, None, None, None, None, None From 8a3bb04789cdbb7b8ed8524a16e15e224fcd0b9e Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Tue, 16 Sep 2025 16:35:59 +0800 Subject: [PATCH 3/9] Fixes out-of-bounds access in copy function Replaces vectorized copy with element-wise assignment to prevent memory access violations when bounds checking is disabled. Changes predicate handling to use dedicated predicate tensor instead of coordinate-based bounds checking for improved safety. Updates default Clear_OOB_MN to false and removes max_N parameter as bounds checking now relies on predicate tensor. 
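As a rough standalone illustration of the hazard this change addresses (plain CUDA, no CuTe; the kernel, names, and sizes below are invented for the sketch), a 128-bit vectorized load over a partially valid tail column group can touch memory past the end of the buffer, while a per-column predicate with element-wise assignments, mirroring the new predicate_N path, stays in bounds:

    #include <cuda_runtime.h>
    #include <cstdio>

    // Tail-safe column copy: a float4-wide load over the last, partially valid
    // group of columns could read past n_valid; a per-column predicate with
    // scalar assignments never does, at the cost of narrower transactions.
    __global__ void copy_cols_predicated(const float* src, float* dst,
                                         int n_valid, int n_padded) {
        int col = blockIdx.x * blockDim.x + threadIdx.x;
        if (col >= n_padded) return;
        bool pred = col < n_valid;          // plays the role of predicate_N(n)
        dst[col] = pred ? src[col] : 0.0f;  // in-bounds copy; padding lanes written as 0 here
    }

    int main() {
        const int n_valid = 70, n_padded = 128;
        float *src, *dst;
        cudaMalloc(&src, n_valid * sizeof(float));   // nothing allocated past n_valid
        cudaMalloc(&dst, n_padded * sizeof(float));
        copy_cols_predicated<<<1, n_padded>>>(src, dst, n_valid, n_padded);
        cudaDeviceSynchronize();
        printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
        cudaFree(src); cudaFree(dst);
        return 0;
    }

The scalar fallback trades bandwidth for safety; the "TODO: Remove sync after fixing to vectorized loads" comments added later in this series track restoring the wider loads once the tail handling is safe.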
--- csrc/flash_dmattn/src/utils.h | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/csrc/flash_dmattn/src/utils.h b/csrc/flash_dmattn/src/utils.h index be28c10..f5c42f5 100644 --- a/csrc/flash_dmattn/src/utils.h +++ b/csrc/flash_dmattn/src/utils.h @@ -521,15 +521,16 @@ __forceinline__ __device__ void copy( //////////////////////////////////////////////////////////////////////////////////////////////////// template < - bool Is_even_MN=true, bool Clear_OOB_MN=true, bool Bool_to_Element=false, typename To_type=void, - typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, - typename Engine2, typename Layout2 + bool Is_even_MN=true, bool Clear_OOB_MN=false, bool Bool_to_Element=false, typename To_type=void, + // typename TiledCopy, + typename Engine0, typename Layout0, typename Engine1, typename Layout1, + typename Engine2, typename Layout2, typename Engine3, typename Layout3 > __forceinline__ __device__ void copy_MN( - TiledCopy tiled_copy, + // TiledCopy tiled_copy, Tensor const &S, Tensor &D, - Tensor const &identity_MN, - const int max_M=0, const int max_N=0 + Tensor const &identity_MN, Tensor const &predicate_N, + const int max_M=0 ) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); // (MMA, MMA_M, MMA_N) CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); // (MMA, MMA_M, MMA_N) @@ -542,14 +543,19 @@ __forceinline__ __device__ void copy_MN( if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) { #pragma unroll for (int n = 0; n < size<2>(S); ++n) { - if (Is_even_MN || get<1>(identity_MN(0, m, n)) < max_N) { + if (Is_even_MN || predicate_N(n)) { if constexpr (Bool_to_Element) { #pragma unroll for (int i = 0; i < size<0>(S); ++i) { D(i, m, n) = static_cast(S(i, m, n)) ? To_type(1) : To_type(0); } } else { - cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); + // Using vectorized load will cause out-of-bounds access when !Is_even_MN && !predicate_N(n) + // cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); + #pragma unroll + for (int i = 0; i < size<0>(S); ++i) { + D(i, m, n) = S(i, m, n); + } } } else if (Clear_OOB_MN) { cute::clear(D(_, m, n)); From e23b08f8f5636c875e3023e0202995b33a21874b Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Wed, 17 Sep 2025 11:46:12 +0800 Subject: [PATCH 4/9] Adds head size padding and comments sequence length padding Improves memory alignment by ensuring head dimensions are padded to multiples of 8 for 16-bit memory allocations. Comments out sequence length padding implementation for future consideration, including corresponding mask and bias padding logic in both forward and backward passes. 
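The multiple-of-8 requirement follows from the 16-bit element size: 8 fp16/bf16 values are exactly one 16-byte (128-bit) vector, so a padded row can be moved with full-width global accesses and no scalar tail; the wrapper then slices the result back with out_padded[..., :head_size_og]. A minimal sketch of the arithmetic (the head size below is a made-up example):

    #include <cuda_fp16.h>
    #include <cstdio>

    int main() {
        int head_dim = 70;                                // hypothetical odd head size
        int padded   = (head_dim % 8 == 0)
                     ? head_dim
                     : head_dim + (8 - head_dim % 8);     // what F.pad(q, [0, 8 - d % 8]) yields: 72
        size_t row_bytes = padded * sizeof(__half);       // 144 bytes
        printf("head_dim=%d -> padded=%d, row = %zu bytes = %zu x 16-byte vectors\n",
               head_dim, padded, row_bytes, row_bytes / 16);
        return 0;
    }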
--- flash_dmattn/flash_dmattn_interface.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 6e32b61..9fa64d2 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -241,11 +241,20 @@ def forward( if return_softmax is None: return_softmax = False + # Padding to multiple of 8 for 16-bit memory allocations head_size_og = q.size(3) if head_size_og % 8 != 0: q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + # seqlen_k_og = k.shape[1] + # if seqlen_k_og % 8 != 0: + # k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8]) + # v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8]) + # if mask is not None: + # mask = torch.nn.functional.pad(mask, [0, 8 - seqlen_k_og % 8], value=False) + # if bias is not None: + # bias = torch.nn.functional.pad(bias, [0, 8 - seqlen_k_og % 8], value=0.0) out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward( q, @@ -265,6 +274,7 @@ def forward( ctx.is_causal = is_causal ctx.softcap = softcap ctx.deterministic = deterministic + # ctx.seqlen_k_og = seqlen_k_og out = out_padded[..., :head_size_og] @@ -307,6 +317,11 @@ def backward( dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] + # if ctx.seqlen_k_og % 8 != 0: + # dk = dk[:, : ctx.seqlen_k_og, :, :] + # dv = dv[:, : ctx.seqlen_k_og, :, :] + # dbias = dbias[..., : ctx.seqlen_k_og] + return dq, dk, dv, None, dbias, None, None, None, None, None, None From 510aaf522e7c80f1a86636734f9f8ab72ce47127 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Wed, 17 Sep 2025 12:49:05 +0800 Subject: [PATCH 5/9] Adds mask reduction utility function Implements a device function that performs logical OR reduction across mask tensor elements and synchronizes the result across thread blocks using warp-level primitives. Enables efficient sparse attention pattern processing by allowing threads to collectively determine if any mask elements are active within a given region. 
--- csrc/flash_dmattn/src/utils.h | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/csrc/flash_dmattn/src/utils.h b/csrc/flash_dmattn/src/utils.h index f5c42f5..0f00038 100644 --- a/csrc/flash_dmattn/src/utils.h +++ b/csrc/flash_dmattn/src/utils.h @@ -333,6 +333,25 @@ __forceinline__ __device__ void sparse_gemm_rs( } } +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template +__forceinline__ __device__ void mask_or_reduce( + Tensor &tSsMask, + bool &active, + ThrCopy smem_thr_copy_Mask +) { + Tensor tSsMask_copy_view = smem_thr_copy_Mask.retile_D(tSsMask); + bool active_local = false; + #pragma unroll + for (int i = 0; i < size(tSsMask_copy_view); ++i) { + active_local |= tSsMask_copy_view(i); + } + active = __syncthreads_or(active_local); +} + + //////////////////////////////////////////////////////////////////////////////////////////////////// // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) From a148a3a9c2907c27cfc27704e90912a21d1e348c Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Wed, 17 Sep 2025 12:52:01 +0800 Subject: [PATCH 6/9] Refactors copy function into specialized variants Splits the generic copy_MN function into four specialized functions: - copy_MN for basic tensor copying with tiled copy operations - copy_mask for masked copying operations - copy_mask_with_or_reduce for copying with OR reduction and block activity tracking - copy_bias for bias-specific copying with element-wise assignment Removes the Bool_to_Element template parameter and related conditional logic, simplifying the codebase by creating purpose-specific functions instead of a single overloaded function with multiple behaviors. --- csrc/flash_dmattn/src/utils.h | 141 ++++++++++++++++++++++++++++++---- 1 file changed, 126 insertions(+), 15 deletions(-) diff --git a/csrc/flash_dmattn/src/utils.h b/csrc/flash_dmattn/src/utils.h index 0f00038..52fb7a0 100644 --- a/csrc/flash_dmattn/src/utils.h +++ b/csrc/flash_dmattn/src/utils.h @@ -540,13 +540,132 @@ __forceinline__ __device__ void copy( //////////////////////////////////////////////////////////////////////////////////////////////////// template < - bool Is_even_MN=true, bool Clear_OOB_MN=false, bool Bool_to_Element=false, typename To_type=void, - // typename TiledCopy, + bool Is_even_MN=true, bool Clear_OOB_MN=false, + typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2, typename Engine3, typename Layout3 > __forceinline__ __device__ void copy_MN( - // TiledCopy tiled_copy, + TiledCopy tiled_copy, + Tensor const &S, Tensor &D, + Tensor const &identity_MN, Tensor const &predicate_N, + const int max_M=0 +) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); // (MMA, MMA_M, MMA_N) + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); // (MMA, MMA_M, MMA_N) + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_N + + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) { + #pragma unroll + for (int n = 0; n < size<2>(S); ++n) { + if (Is_even_MN || predicate_N(n)) { + cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, n)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } +} + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + bool Is_even_MN=true, bool Clear_OOB_MN=false, + typename TiledCopy, + typename Engine0, typename Layout0, typename Engine1, typename Layout1, + typename Engine2, typename Layout2, typename Engine3, typename Layout3 +> +__forceinline__ __device__ void copy_mask( + TiledCopy tiled_copy, + Tensor const &S, Tensor &D, + Tensor const &identity_MN, Tensor const &predicate_N, + const int max_M=0 +) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); // (MMA, MMA_M, MMA_N) + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); // (MMA, MMA_M, MMA_N) + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_N + + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) { + #pragma unroll + for (int n = 0; n < size<2>(S); ++n) { + if (Is_even_MN || predicate_N(n)) { + cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, n)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + bool Is_even_MN=true, bool Clear_OOB_MN=false, typename To_type=void, + typename TiledCopy, + typename Engine0, typename Layout0, typename Engine1, typename Layout1, + typename Engine2, typename Layout2, typename Engine3, typename Layout3 +> +__forceinline__ __device__ void copy_mask_with_or_reduce( + TiledCopy tiled_copy, + Tensor const &S, Tensor &D, + bool &block_active, + Tensor const &identity_MN, Tensor const &predicate_N, + const int max_M=0 +) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); // (MMA, MMA_M, MMA_N) + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); // (MMA, MMA_M, MMA_N) + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_N + + bool any_active = false; + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_M) { + #pragma unroll + for (int n = 0; n < size<2>(S); ++n) { + if (Is_even_MN || predicate_N(n)) { + #pragma unroll + for (int i = 0; i < size<0>(S); ++i) { + any_active |= S(i, m, n); + D(i, m, n) = static_cast(S(i, m, n)); + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, n)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } + + block_active = __syncthreads_or(any_active); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + bool Is_even_MN=true, bool Clear_OOB_MN=false, + typename TiledCopy, + typename Engine0, typename Layout0, typename Engine1, typename Layout1, + typename Engine2, typename Layout2, typename Engine3, typename Layout3 +> +__forceinline__ __device__ void copy_bias( + TiledCopy tiled_copy, Tensor const &S, Tensor &D, Tensor const &identity_MN, Tensor const &predicate_N, const int max_M=0 @@ -563,18 +682,10 @@ __forceinline__ __device__ void copy_MN( #pragma unroll for (int n = 0; n < size<2>(S); ++n) { if (Is_even_MN || predicate_N(n)) { - if constexpr (Bool_to_Element) { - #pragma unroll - for (int i = 0; i < size<0>(S); ++i) { - D(i, m, n) = static_cast(S(i, m, n)) ? 
To_type(1) : To_type(0); - } - } else { - // Using vectorized load will cause out-of-bounds access when !Is_even_MN && !predicate_N(n) - // cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); - #pragma unroll - for (int i = 0; i < size<0>(S); ++i) { - D(i, m, n) = S(i, m, n); - } + // cute::copy(tiled_copy, S(_, m, n), D(_, m, n)); + #pragma unroll + for (int i = 0; i < size<0>(S); ++i) { + D(i, m, n) = S(i, m, n); } } else if (Clear_OOB_MN) { cute::clear(D(_, m, n)); From a0475b295e33fbd2329e32ee50f85af592fe3be6 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Wed, 17 Sep 2025 12:52:44 +0800 Subject: [PATCH 7/9] Separates mask and bias memory operations in attention kernel Refactors combined mask-bias memory operations into separate dedicated operations to improve performance and maintainability. Introduces specialized copy functions for mask and bias operations with proper bounds checking and OR-reduction for mask activity detection. Removes redundant synchronization points by leveraging built-in synchronization in the new copy functions. Adds predicate tensor allocation for proper boundary handling in both regular and split-KV attention kernels. --- csrc/flash_dmattn/src/flash_fwd_kernel.h | 331 +++++++++++++---------- 1 file changed, 195 insertions(+), 136 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index f79e477..5b701bf 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -230,8 +230,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Global to Shared Memory operation typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyMaskBias gmem_tiled_copy_MaskBias; - auto gmem_thr_copy_MaskBias = gmem_tiled_copy_MaskBias.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask; + auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias; + auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx); Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); @@ -239,10 +241,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - Tensor tMaskgMask = gmem_thr_copy_MaskBias.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N, nblocksN) - Tensor tMasksMask = gmem_thr_copy_MaskBias.partition_D(sMask); - Tensor tBiasgBias = gmem_thr_copy_MaskBias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N, nblocksN) - Tensor tBiassBias = gmem_thr_copy_MaskBias.partition_D(sBias); + Tensor tMaskgMask = gmem_thr_copy_Mask.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N, nblocksN) + Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask); + Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N, nblocksN) + Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias); // Matrix Multiply Accumulate typename Kernel_traits::TiledMma tiled_mma; @@ -300,8 +302,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Repeat the partitioning with identity layouts Tensor tQcQ = 
gmem_thr_copy_QKV.partition_S(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k) - Tensor tMaskcMask = gmem_thr_copy_MaskBias.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) - Tensor tBiascBias = gmem_thr_copy_MaskBias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) + Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) + Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); @@ -310,18 +312,32 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Set predicates for k bounds if (!Is_even_K) { #pragma unroll - for (int k = 0; k < size(tQpQ); ++k) { - tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; - } + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } #pragma unroll - for (int k = 0; k < size(tKVpKV); ++k) { - tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; - } + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Reverse iteration over N blocks + int n_block = n_block_max - 1; + + // Allocate predicate tensors for n + Tensor tMaskpMask = make_tensor(make_shape(size<2>(tMasksMask))); + Tensor tBiaspBias = make_tensor(make_shape(size<2>(tBiassBias))); + + // Set predicates for n bounds + if (!Is_even_MN) { + #pragma unroll + for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } + #pragma unroll + for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } } // Prologue + bool any_active = true; // to be updated later for current iteration + bool any_active_next = true; // to be updated later for next iteration + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, @@ -344,27 +360,25 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); __syncthreads(); } - // Reverse iteration over N blocks - int n_block = n_block_max - 1; - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask(_, _, _, n_block), tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN - ); + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask(_, _, _, n_block), tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); // cute::cp_async_fence(); // FLASH_NAMESPACE::cp_async_wait<0>(); - __syncthreads(); + // // Do OR-reduce on the mask to see if any active threads - // Do OR-reduce on the mask to see if any active threads - Tensor tSsMask_copy_view = smem_thr_copy_Mask.retile_S(tSsMask); - bool any_active_local = false; - bool any_active_local_next = false; // to be updated later for next iteration - #pragma unroll - for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); } - bool any_active = __syncthreads_or(any_active_local); - bool any_active_next = false; // to be updated later for next iteration + FLASH_NAMESPACE::copy_mask_with_or_reduce( + 
gmem_tiled_copy_Mask, + tMaskgMask(_, _, _, n_block), tMasksMask, + any_active, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. if (any_active) { @@ -374,12 +388,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, tBiasgBias(_, _, _, n_block), tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); cute::cp_async_fence(); } // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } @@ -472,21 +489,24 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } if (n_block > n_block_min) { - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask(_, _, _, n_block - 1), tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN - ); + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask(_, _, _, n_block - 1), tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); // cute::cp_async_fence(); // FLASH_NAMESPACE::cp_async_wait<0>(); - __syncthreads(); + // // Do OR-reduce on the mask to see if any active threads for next iteration. - // Do OR-reduce on the mask to see if any active threads for next iteration - any_active_local_next = false; - #pragma unroll - for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local_next |= (tSsMask_copy_view(i) != Element(0)); } - any_active_next = __syncthreads_or(any_active_local_next); + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask(_, _, _, n_block - 1), tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. if (any_active_next) { FLASH_NAMESPACE::copy( @@ -494,12 +514,15 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV ); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, tBiasgBias(_, _, _, n_block - 1), tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
cute::cp_async_fence(); @@ -595,34 +618,40 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } if (n_block > n_block_min) { - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask(_, _, _, n_block - 1), tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN - ); + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask(_, _, _, n_block - 1), tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); // cute::cp_async_fence(); // FLASH_NAMESPACE::cp_async_wait<0>(); - __syncthreads(); - - // Do OR-reduce on the mask to see if any active threads for next iteration - any_active_local_next = false; - #pragma unroll - for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local_next |= (tSsMask_copy_view(i) != Element(0)); } - any_active_next = __syncthreads_or(any_active_local_next); + // // Do OR-reduce on the mask to see if any active threads for next iteration. + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask(_, _, _, n_block - 1), tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads + if (any_active_next) { FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV ); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, tBiasgBias(_, _, _, n_block - 1), tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
cute::cp_async_fence(); @@ -920,8 +949,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Global to Shared Memory operation typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyMaskBias gmem_tiled_copy_MaskBias; - auto gmem_thr_copy_MaskBias = gmem_tiled_copy_MaskBias.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask; + auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias; + auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx); + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); @@ -929,10 +961,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - Tensor tMaskgMask = gmem_thr_copy_MaskBias.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) - Tensor tMasksMask = gmem_thr_copy_MaskBias.partition_D(sMask); - Tensor tBiasgBias = gmem_thr_copy_MaskBias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) - Tensor tBiassBias = gmem_thr_copy_MaskBias.partition_D(sBias); + Tensor tMaskgMask = gmem_thr_copy_Mask.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) + Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask); + Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) + Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias); // Matrix Multiply Accumulate typename Kernel_traits::TiledMma tiled_mma; @@ -972,26 +1004,40 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Repeat the partitioning with identity layouts Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k) - Tensor tMaskcMask = gmem_thr_copy_MaskBias.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) - Tensor tBiascBias = gmem_thr_copy_MaskBias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) + Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) + Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) + // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + // Set predicates for k bounds if (!Is_even_K) { #pragma unroll - for (int k = 0; k < size(tQpQ); ++k) { - tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; - } + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } #pragma unroll - for (int k = 0; k < size(tKVpKV); ++k) { - tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; - } + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } } + int n_block = n_block_max - 1; + + // Allocate predicate tensors for m and n bounds + Tensor tMaskpMask = make_tensor(make_shape(size<2>(tMasksMask))); + Tensor tBiaspBias = make_tensor(make_shape(size<2>(tBiassBias))); + + // Set predicates for n bounds + if (!Is_even_MN) { + #pragma unroll + for (int n = 0; n < size(tMaskpMask); ++n) 
{ tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < binfo.actual_seqlen_k - n_block * kBlockN; } + #pragma unroll + for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < binfo.actual_seqlen_k - n_block * kBlockN; } + } // Prologue + bool any_active = true; // to be updated later for current iteration + bool any_active_next = true; // Load the first Q, K, V, Mask, Bias tiles + // Read Q from gmem to smem // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs FLASH_NAMESPACE::copy( @@ -1001,26 +1047,24 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons binfo.actual_seqlen_q - m_block * kBlockM ); - int n_block = n_block_max - 1; - - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask, tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN - ); + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); // cute::cp_async_fence(); // FLASH_NAMESPACE::cp_async_wait<0>(); - __syncthreads(); + // // Do OR-reduce on the mask to see if any active threads - // Do OR-reduce on the mask to see if any active threads for next iteration - Tensor tSsMask_copy_view = smem_thr_copy_Mask.retile_S(tSsMask); - bool any_active_local = false; - bool any_active_local_next = false; // to be updated later for next iteration - #pragma unroll - for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); } - bool any_active = __syncthreads_or(any_active_local); - bool any_active_next = false; // to be updated later for next iteration + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. if (any_active) { @@ -1030,12 +1074,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. 
+ __syncthreads(); cute::cp_async_fence(); } @@ -1150,21 +1197,24 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur); tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur); } - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask, tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN - ); + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); // cute::cp_async_fence(); // FLASH_NAMESPACE::cp_async_wait<0>(); - __syncthreads(); - - // Do OR-reduce on the mask to see if any active threads for next iteration - any_active_local_next = false; - #pragma unroll - for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local_next |= (tSsMask_copy_view(i) != Element(0)); } - any_active_next = __syncthreads_or(any_active_local_next); + // // Do OR-reduce on the mask to see if any active threads for next iteration. + + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. if (any_active_next) { FLASH_NAMESPACE::copy( @@ -1172,12 +1222,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKgK, tKsK, tKVcKV, tKVpKV ); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -1291,21 +1344,24 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur); tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur); } - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask, tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN - ); + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); // cute::cp_async_fence(); // FLASH_NAMESPACE::cp_async_wait<0>(); - __syncthreads(); + // // Do OR-reduce on the mask to see if any active threads for next iteration. 
- // Do OR-reduce on the mask to see if any active threads for next iteration - any_active_local_next = false; - #pragma unroll - for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local_next |= (tSsMask_copy_view(i) != Element(0)); } - any_active_next = __syncthreads_or(any_active_local_next); + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. if (any_active_next) { FLASH_NAMESPACE::copy( @@ -1313,12 +1369,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKgK, tKsK, tKVcKV, tKVpKV ); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); From 3b7b57be38b9181cb8d7afedb0bfcc029e300cfc Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Wed, 17 Sep 2025 12:52:58 +0800 Subject: [PATCH 8/9] Separates mask and bias memory operations for performance Splits the unified GmemTiledCopyMaskBias into separate GmemTiledCopyMask and GmemTiledCopyBias operations to enable independent optimization of memory access patterns. Introduces specialized copy_mask_with_or_reduce and copy_bias functions that replace generic copy_MN calls, allowing for better memory coalescing and reduced synchronization overhead. Adds predicate tensor allocation for bounds checking on the N dimension to improve memory safety and enable more efficient vectorized operations in future optimizations. 
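The new column predicates encode how much of the current K tile still lies inside the real sequence, i.e. whether a column index is below max(0, actual_seqlen_k - n_block * kBlockN) for the tile being loaded. A tiny host-side sketch of that bound for a tail tile (tile size and sequence length are made-up values):

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int kBlockN = 128, actual_seqlen_k = 300;                        // hypothetical
        const int n_block = (actual_seqlen_k + kBlockN - 1) / kBlockN - 1;     // last K tile: index 2
        const int max_n   = std::max(0, actual_seqlen_k - n_block * kBlockN);  // 300 - 256 = 44 valid columns
        const int cols[]  = {0, 43, 44, 127};
        for (int col : cols) {
            bool pred = col < max_n;   // per-column analogue of tMaskpMask(n) / tBiaspBias(n)
            printf("col %3d -> %s\n", col, pred ? "load" : "skip");
        }
        return 0;
    }

Splitting the copies also fixes the ordering: the mask tile is loaded and OR-reduced first, and the bias tile is only fetched when the reduction reports an active block.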
--- csrc/flash_dmattn/src/flash_bwd_kernel.h | 113 ++++++++++++++--------- 1 file changed, 70 insertions(+), 43 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h index 2ba7d2a..dc93b76 100644 --- a/csrc/flash_dmattn/src/flash_bwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_bwd_kernel.h @@ -272,9 +272,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Global to Shared Memory operation typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyMaskBias gmem_tiled_copy_MaskBias; - auto gmem_thr_copy_Mask = gmem_tiled_copy_MaskBias.get_thread_slice(tidx); - auto gmem_thr_copy_Bias = gmem_tiled_copy_MaskBias.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask; + typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias; + auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx); + auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx); using GmemTiledCopydO = std::conditional_t; GmemTiledCopydO gmem_tiled_copy_dO; auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx); @@ -417,9 +418,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } } + // Allocate predicate tensors for N + Tensor tMaskpMask = make_tensor(make_shape(size<2>(tMasksMask))); + Tensor tBiaspBias = make_tensor(make_shape(size<2>(tBiassBias))); + + // Set predicates for n bounds + if (!Is_even_MN) { + #pragma unroll + for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } + #pragma unroll + for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } + } + // Prologue + bool any_active = true; // to be updated later for current iteration + bool any_active_next = true; // to be updated later for next iteration + // We'll advance gdQ, gdQaccum and gdBias before the 1st read/write. 
tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride; tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded; @@ -554,24 +570,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK); // // if (cute::thread(1, 0)) { print(tKrK); } - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask, tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN - ); + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); // cute::cp_async_fence(); // FLASH_NAMESPACE::cp_async_wait<0>(); - __syncthreads(); + // // Do OR-reduce on the mask to see if any active threads - // Do OR-reduce on the mask to see if any active threads - Tensor tSsMask_copy_view = smem_thr_copy_PdS.retile_S(tSsMask); - bool any_active_local = false; - bool any_active_local_next = false; // to be updated later for next iteration - #pragma unroll - for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local |= (tSsMask_copy_view(i) != Element(0)); } - bool any_active = __syncthreads_or(any_active_local); - bool any_active_next = false; // to be updated later for next iteration + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, @@ -581,12 +597,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in ); if (any_active) { - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); } if (!Kernel_traits::Is_V_in_regs) { @@ -780,13 +799,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); __syncthreads(); // Write dS to dBias - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, tBiassBias, tdBiasgdBias, - tBiascBias, - binfo.actual_seqlen_q - m_block * kBlockM, - binfo.actual_seqlen_k - n_block * kBlockN + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. 
+ __syncthreads(); // if (cute::thread0()) { print(tPrP); } // Layout p_l = tPrP.layout(); @@ -810,21 +831,24 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in if (m_block > m_block_min) { // Advance gMask tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride)); - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, - tMaskgMask, tMasksMask, - tMaskcMask, - binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN - ); + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - (m_block - 1) * kBlockM + // ); // FLASH_NAMESPACE::cp_async_fence(); // FLASH_NAMESPACE::cp_async_wait<0>(); - __syncthreads(); + // // Do OR-reduce on the mask to see if any active threads for next iteration - // Do OR-reduce on the mask to see if any active threads for next iteration - any_active_local_next = false; - #pragma unroll - for (int i = 0; i < size(tSsMask_copy_view); ++i) { any_active_local_next |= (tSsMask_copy_view(i) != Element(0)); } - any_active_next = __syncthreads_or(any_active_local_next); + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - (m_block - 1) * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. // Advance gdO tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride)); @@ -926,12 +950,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride)); tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride)); if (any_active_next) { - FLASH_NAMESPACE::copy_MN( - gmem_tiled_copy_MaskBias, + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, tBiasgBias, tBiassBias, - tBiascBias, - binfo.actual_seqlen_q - (m_block - 1) * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - (m_block - 1) * kBlockM ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); } } From e9f9fccebd6053c99a3981110cebb556478acb47 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Wed, 17 Sep 2025 12:53:17 +0800 Subject: [PATCH 9/9] Separates mask and bias copy operations into distinct types Splits the combined GmemTiledCopyMaskBias type into separate GmemTiledCopyMask and GmemTiledCopyBias types in both forward and backward kernel traits. This separation improves code clarity and allows for independent handling of mask and bias copy operations, enabling more flexible memory access patterns and potential optimizations. 
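A stand-in sketch of the flexibility argument (plain C++, not the CuTe make_tiled_copy expressions; the policy struct and its names are invented for illustration). Both aliases currently keep the old combined type's value layout of 8 values of the 16-bit element per read, but with separate types either side can later change its element type or vector width without touching the other:

    #include <cuda_fp16.h>
    #include <cstdio>

    template <class Elem, int kValsPerRead>
    struct CopyPolicySketch {
        using Element = Elem;
        static constexpr int vals_per_read  = kValsPerRead;
        static constexpr int bytes_per_read = kValsPerRead * int(sizeof(Elem));
    };

    // Today both read 8 halfs per access, like the old GmemTiledCopyMaskBias;
    // a byte-sized mask could, for example, later become
    // CopyPolicySketch<unsigned char, 16> without changing the bias path.
    using MaskCopySketch = CopyPolicySketch<__half, 8>;
    using BiasCopySketch = CopyPolicySketch<__half, 8>;

    int main() {
        printf("mask read: %d B, bias read: %d B\n",
               MaskCopySketch::bytes_per_read, BiasCopySketch::bytes_per_read);
        return 0;
    }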
--- csrc/flash_dmattn/src/kernel_traits.h | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/csrc/flash_dmattn/src/kernel_traits.h b/csrc/flash_dmattn/src/kernel_traits.h index a574675..8f648e1 100644 --- a/csrc/flash_dmattn/src/kernel_traits.h +++ b/csrc/flash_dmattn/src/kernel_traits.h @@ -183,7 +183,14 @@ struct Flash_fwd_kernel_traits : public Base { Layout>{} ) ); // Val layout, 8 vals per read - using GmemTiledCopyMaskBias = decltype( + using GmemTiledCopyMask = decltype( + make_tiled_copy( + Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>{} + ) + ); // Val layout, 8 vals per read + using GmemTiledCopyBias = decltype( make_tiled_copy( Copy_Atom, Element>{}, GmemLayoutAtom{}, @@ -442,7 +449,14 @@ struct Flash_bwd_kernel_traits : public Base { Layout>{} ) ); // Val layout, 8 vals per store - using GmemTiledCopyMaskBias = decltype( + using GmemTiledCopyMask = decltype( + make_tiled_copy( + Copy_Atom, elem_type>{}, + GmemLayoutAtom{}, + Layout>{} + ) + ); // Val layout, 8 vals per read + using GmemTiledCopyBias = decltype( make_tiled_copy( Copy_Atom, elem_type>{}, GmemLayoutAtom{},