From d18cdacd33f7cbe6162d208e0fbb7b7006a395cd Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 13 Sep 2025 18:36:45 +0800 Subject: [PATCH 1/7] Adds head count fields to mask and bias parameters Introduces h_mask and h_bias fields to track the number of heads in attention mask and bias structures respectively. Enables better head dimension management and validation in flash attention operations. --- csrc/flash_dmattn/src/flash.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/csrc/flash_dmattn/src/flash.h b/csrc/flash_dmattn/src/flash.h index c1cb7f4..a533cc3 100644 --- a/csrc/flash_dmattn/src/flash.h +++ b/csrc/flash_dmattn/src/flash.h @@ -50,6 +50,9 @@ struct Mask_params { index_t mask_batch_stride; // Stride between batches of attention mask index_t mask_head_stride; // Stride between heads of attention mask index_t mask_row_stride; // Stride between rows of attention mask + + // The number of heads in the mask. + int h_mask; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -61,6 +64,9 @@ struct Bias_params { index_t bias_batch_stride; // Stride between batches of attention bias index_t bias_head_stride; // Stride between heads of attention bias index_t bias_row_stride; // Stride between rows of attention bias + + // The number of heads in the bias. + int h_bias; }; //////////////////////////////////////////////////////////////////////////////////////////////////// From 22394378493a434e3918d2a01285cdce0f25d506 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 13 Sep 2025 18:37:33 +0800 Subject: [PATCH 2/7] Improves mask and bias head indexing flexibility Introduces dynamic head index calculation for mask and bias tensors to support different head configurations. Previously used fixed head ratio calculations, now supports three scenarios: - Single head broadcasting (h_mask/h_bias == 1) - Multi-head with ratio-based indexing (h_mask/h_bias == h_k) - Direct head indexing (fallback case) Enables more flexible attention masking and bias application across different multi-head attention configurations. --- csrc/flash_dmattn/src/flash_fwd_kernel.h | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index 79dcbfb..f79e477 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -136,6 +136,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // might save us 1 register (we just need n_block instead of both n_block and n_block_max). const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); + const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); // Global memory tensor configuration Tensor mQ = make_tensor( @@ -170,21 +172,21 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi ); // (kBlockN, kHeadDim, nblocksN) Tensor mMask = make_tensor( make_gmem_ptr(reinterpret_cast(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)), - make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k), + make_shape(params.h_mask, binfo.actual_seqlen_q, binfo.actual_seqlen_k), make_stride(params.mask_head_stride, params.mask_row_stride, _1{}) ); Tensor gMask = local_tile( - mMask(bidh / params.h_h_k_ratio, _, _), + mMask(h_idx_mask, _, _), Shape, Int>{}, make_coord(m_block, _) ); // (kBlockM, kBlockN, nblocksN) Tensor mBias = make_tensor( make_gmem_ptr(reinterpret_cast(params.bias_ptr) + binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb)), - make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k), + make_shape(params.h_bias, binfo.actual_seqlen_q, binfo.actual_seqlen_k), make_stride(params.bias_head_stride, params.bias_row_stride, _1{}) ); Tensor gBias = local_tile( - mBias(bidh / params.h_h_k_ratio, _, _), + mBias(h_idx_bias, _, _), Shape, Int>{}, make_coord(m_block, _) ); // (kBlockM, kBlockN, nblocksN) @@ -840,16 +842,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); const index_t col_offset_mask = (block_table == nullptr) ? binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb_cache) - + (bidh / params.h_h_k_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1) * kBlockN + + h_idx_mask * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1) * kBlockN : binfo.q_offset(/*batch_stride=*/index_t(0), params.mask_row_stride, bidb_cache) - + (bidh / params.h_h_k_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset; + + h_idx_mask * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset; + const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); const index_t col_offset_bias = (block_table == nullptr) ? binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb_cache) - + (bidh / params.h_h_k_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1) * kBlockN + + h_idx_bias * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1) * kBlockN : binfo.q_offset(/*batch_stride=*/index_t(0), params.bias_row_stride, bidb_cache) - + (bidh / params.h_h_k_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset; + + h_idx_bias * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset; // Global memory tensor configuration Tensor mQ = make_tensor( From 3bc89c86ef8897f1f4d22b846179677a92ee2833 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 13 Sep 2025 18:37:54 +0800 Subject: [PATCH 3/7] Adds flexible head indexing for mask and bias tensors Introduces conditional head index calculation for mask and bias operations based on tensor dimensions. Supports scenarios where mask/bias tensors can have single head (h=1), match key heads (h=h_k), or match query heads (h=h_q). Replaces hardcoded head index division with dynamic selection logic that adapts to different tensor head configurations in flash attention backward kernel. --- csrc/flash_dmattn/src/flash_bwd_kernel.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h index 723265b..2ba7d2a 100644 --- a/csrc/flash_dmattn/src/flash_bwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_bwd_kernel.h @@ -107,10 +107,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in + n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); const index_t row_offset_mask = binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb) - + (bidh / params.h_h_k_ratio) * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN; + + h_idx_mask * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN; + const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); const index_t row_offset_bias = binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb) - + (bidh / params.h_h_k_ratio) * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN; + + h_idx_bias * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN; const index_t row_offset_dbias = binfo.bias_offset(params.dbias_batch_stride, params.dbias_row_stride, bidb) + bidh * params.dbias_head_stride + (m_block_max - 1) * kBlockM * params.dbias_row_stride + n_block * kBlockN; const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) From 4e505b2aed29c62b9c35aff46c527e8bbcf029bd Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 13 Sep 2025 18:38:13 +0800 Subject: [PATCH 4/7] Supports flexible mask and bias head dimensions Adds support for mask and bias tensors with 1, num_heads_k, or num_heads dimensions instead of only num_heads_k. Enables more flexible attention patterns by allowing masks and biases to be broadcast across different head configurations. Updates parameter passing to track separate head counts for masks and biases, and adds appropriate validation checks. Temporarily disables variable-length attention variants to focus on core functionality improvements. --- csrc/flash_dmattn/flash_api.cpp | 1002 +++++++++++++++++-------------- 1 file changed, 542 insertions(+), 460 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index b4a657e..4108cb4 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -31,6 +31,8 @@ void set_params_fprop( const size_t seqlen_k_rounded, const size_t h, const size_t h_k, + const size_t h_mask, + const size_t h_bias, const size_t d, const size_t d_rounded, // device pointers @@ -107,6 +109,8 @@ void set_params_fprop( params.h = h; params.h_k = h_k; params.h_h_k_ratio = h / h_k; + params.h_mask = h_mask; + params.h_bias = h_bias; params.seqlen_q = seqlen_q; params.seqlen_k = seqlen_k; params.seqlen_q_rounded = seqlen_q_rounded; @@ -150,6 +154,8 @@ void set_params_dgrad( const size_t seqlen_k_rounded, const size_t h, const size_t h_k, + const size_t h_mask, + const size_t h_bias, const size_t d, const size_t d_rounded, // device pointers @@ -179,7 +185,7 @@ void set_params_dgrad( ) { set_params_fprop( params, - b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, + b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, h_mask, h_bias, d, d_rounded, q, k, v, mask, bias, out, cu_seqlens_q_d, cu_seqlens_k_d, @@ -341,8 +347,8 @@ mha_fwd( at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k - const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k + at::Tensor &mask, // batch_size x num_heads x seqlen_q x seqlen_k or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k + at::Tensor &bias, // batch_size x num_heads x seqlen_q x seqlen_k, or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const float softmax_scale, bool is_causal, @@ -380,10 +386,14 @@ mha_fwd( const int head_size = sizes[3]; const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); + int num_heads_mask = mask.size(1); + int num_heads_bias = bias.size(1); TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, "Number of heads in mask must be 1, h_k or h"); + TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, "Number of heads in bias must be 1, h_k or h"); // causal=true is the same as causal=false in this case if (seqlen_q == 1) { is_causal = false; } @@ -392,12 +402,26 @@ mha_fwd( // H/t Daniel Haziza const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; const int ngroups = num_heads / num_heads_k; - at::Tensor mask_view = mask; - at::Tensor bias_view = bias; + const int orig_num_heads_mask = num_heads_mask; + const int orig_num_heads_bias = num_heads_bias; if (seqlenq_ngroups_swapped) { q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); - mask_view = mask.expand({batch_size, num_heads_k, ngroups, seqlen_k}); - bias_view = bias.expand({batch_size, num_heads_k, ngroups, seqlen_k}); + if (num_heads_mask == 1) { + mask = mask.expand({batch_size, 1, ngroups, seqlen_k}); + } else if (num_heads_mask == num_heads_k) { + mask = mask.expand({batch_size, num_heads_k, ngroups, seqlen_k}); + } else { // num_heads_mask == num_heads + mask = mask.reshape({batch_size, num_heads_k, ngroups, seqlen_k}); + } + if (num_heads_bias == 1) { + bias = bias.expand({batch_size, 1, ngroups, seqlen_k}); + } else if (num_heads_bias == num_heads_k) { + bias = bias.expand({batch_size, num_heads_k, ngroups, seqlen_k}); + } else { // num_heads_bias == num_heads + bias = bias.reshape({batch_size, num_heads_k, ngroups, seqlen_k}); + } + num_heads_mask = (num_heads_mask == num_heads) ? num_heads_k : num_heads_mask; + num_heads_bias = (num_heads_bias == num_heads) ? num_heads_k : num_heads_bias; seqlen_q = ngroups; num_heads = num_heads_k; } @@ -405,8 +429,20 @@ mha_fwd( CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(mask_view, batch_size, num_heads_k, seqlen_q, seqlen_k); - CHECK_SHAPE(bias_view, batch_size, num_heads_k, seqlen_q, seqlen_k); + if (num_heads_mask == 1) { + CHECK_SHAPE(mask, batch_size, 1, seqlen_q, seqlen_k); + } else if (num_heads_mask == num_heads_k) { + CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k); + } else { + CHECK_SHAPE(mask, batch_size, num_heads, seqlen_q, seqlen_k); + } + if (num_heads_bias == 1) { + CHECK_SHAPE(bias, batch_size, 1, seqlen_q, seqlen_k); + } else if (num_heads_bias == num_heads_k) { + CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k); + } else { + CHECK_SHAPE(bias, batch_size, num_heads, seqlen_q, seqlen_k); + } at::Tensor out; if (out_.has_value()) { @@ -444,9 +480,9 @@ mha_fwd( batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, + num_heads, num_heads_k, num_heads_mask, num_heads_bias, head_size, head_size_rounded, - q, k, v, mask_view, bias_view, out, + q, k, v, mask, bias, out, /*cu_seqlens_q_d=*/nullptr, /*cu_seqlens_k_d=*/nullptr, /*seqused_k=*/nullptr, @@ -477,231 +513,242 @@ mha_fwd( out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); - } - return {out, softmax_lse, p}; -} - -std::vector -mha_varlen_fwd( - at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - const at::Tensor &mask, // total_q x num_heads_k x max_seqlen_k - const at::Tensor &bias, // total_q x num_heads_k x max_seqlen_k - std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. - std::optional &leftpad_k_, // batch_size - std::optional &block_table_, // batch_size x max_num_blocks_per_seq - int max_seqlen_q, - const int max_seqlen_k, - const float softmax_scale, - const bool zero_tensors, - bool is_causal, - const float softcap, - const bool return_softmax -) { - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; - auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); - bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); - - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); - TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); - TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); - - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); - CHECK_DEVICE(cu_seqlens_q); - CHECK_DEVICE(cu_seqlens_k); - - at::Tensor block_table; - // const bool paged_KV = block_table_.has_value(); - const bool paged_KV = false; // TODO: Temporarily disable Paged KV, because some bugs are still being fixed. - if (paged_KV) { - block_table = block_table_.value(); - CHECK_DEVICE(block_table); - TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); - TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); - } - - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - CHECK_CONTIGUOUS(cu_seqlens_q); - CHECK_CONTIGUOUS(cu_seqlens_k); - - const auto sizes = q.sizes(); - - const int batch_size = cu_seqlens_q.numel() - 1; - int num_heads = sizes[1]; - const int head_size = sizes[2]; - const int num_heads_k = paged_KV ? k.size(2) : k.size(1); - - const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); - const int num_blocks = !paged_KV ? 0 : k.size(0); - const int page_block_size = !paged_KV ? 1 : k.size(1); - TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); - - if (max_seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case - - void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); - - // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case - // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; - const int ngroups = num_heads / num_heads_k; - if (seqlenq_ngroups_swapped) { - q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); - max_seqlen_q = ngroups; - num_heads = num_heads_k; - cu_seqlens_q_d = nullptr; - } - - const int total_q = q.sizes()[0]; - - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); - TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - CHECK_SHAPE(q, total_q, num_heads, head_size); - if (!paged_KV) { - const int total_k = k.size(0); - CHECK_SHAPE(k, total_k, num_heads_k, head_size); - CHECK_SHAPE(v, total_k, num_heads_k, head_size); - CHECK_SHAPE(mask, total_q, num_heads_k, max_seqlen_k); - CHECK_SHAPE(bias, total_q, num_heads_k, max_seqlen_k); - } else { - CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size); - CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size); - CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); - } - - CHECK_SHAPE(cu_seqlens_q, batch_size + 1); - CHECK_SHAPE(cu_seqlens_k, batch_size + 1); - if (seqused_k.has_value()){ - auto seqused_k_ = seqused_k.value(); - TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); - TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); - TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); - CHECK_SHAPE(seqused_k_, batch_size); - } - - at::Tensor out; - if (out_.has_value()) { - out = out_.value(); - TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); - CHECK_DEVICE(out); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, sizes[0], sizes[1], head_size); - if (seqlenq_ngroups_swapped) { - out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); + if (orig_num_heads_mask == 1 || orig_num_heads_mask == num_heads_k) { + mask = mask.narrow(2, 0, 1); + } else { // orig_num_heads_mask == num_heads + mask = mask.reshape({batch_size, num_heads_k * seqlen_q, 1, seqlen_k}); + } + if (orig_num_heads_bias == 1 || orig_num_heads_bias == num_heads_k) { + bias = bias.narrow(2, 0, 1); + } else { // orig_num_heads_bias == num_heads + bias = bias.reshape({batch_size, num_heads_k * seqlen_q, 1, seqlen_k}); } - } else { - out = torch::empty_like(q); - } - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); - const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); - - auto opts = q.options(); - auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); - at::Tensor p; - - if (return_softmax) { - p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); - } else { - p = torch::empty({ 0 }, opts); - } - - if (zero_tensors) { - out.zero_(); - softmax_lse.fill_(-std::numeric_limits::infinity()); - if (return_softmax) { p.zero_(); } - } - - Flash_fwd_params params; - set_params_fprop( - params, - batch_size, - max_seqlen_q, max_seqlen_k, - seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, - head_size, head_size_rounded, - q, k, v, mask, bias, out, - cu_seqlens_q_d, - cu_seqlens_k.data_ptr(), - seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, - return_softmax ? p.data_ptr() : nullptr, - softmax_lse.data_ptr(), - softmax_scale, - is_causal, - softcap, - seqlenq_ngroups_swapped, - /*unpadded_lse*/true - ); - params.total_q = total_q; - - if (paged_KV) { - params.block_table = block_table.data_ptr(); - params.block_table_batch_stride = block_table.stride(0); - params.k_batch_stride = k.stride(0); - params.v_batch_stride = v.stride(0); - } - params.page_block_size = page_block_size; - // Keep references to these tensors to extend their lifetime - at::Tensor softmax_lse_accum, out_accum; - if (seqlenq_ngroups_swapped) { - // Only apply split-k for decoding - std::tie(softmax_lse_accum, out_accum) = - set_params_splitkv( - params, batch_size, num_heads, head_size, - max_seqlen_k, max_seqlen_q, head_size_rounded, - /*num_splits*/ 0, get_num_sm(get_current_device()), opts - ); - } - - if (leftpad_k_.has_value()) { - auto leftpad_k = leftpad_k_.value(); - TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); - TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); - CHECK_DEVICE(leftpad_k); - CHECK_CONTIGUOUS(leftpad_k); - CHECK_SHAPE(leftpad_k, batch_size); - params.leftpad_k = static_cast(leftpad_k.data_ptr()); - } - - if (max_seqlen_k > 0) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); - run_mha_fwd(params, stream, paged_KV); - } else { - // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. - out.zero_(); - softmax_lse.fill_(std::numeric_limits::infinity()); - } - - if (seqlenq_ngroups_swapped) { - int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size}; - int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size}; - out = out.reshape(size_before).transpose(1, 2).reshape(size_after); - q = q.reshape(size_before).transpose(1, 2).reshape(size_after); - softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size}); } - return {out, softmax_lse, p}; } +// TODO: At present, we don't have a good strategy to handle the mask and bias of the varlen variant. +// std::vector +// mha_varlen_fwd( +// at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i +// const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. +// const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. +// const at::Tensor &mask, // total_q x num_heads_k x max_seqlen_k +// const at::Tensor &bias, // total_q x num_heads_k x max_seqlen_k +// std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i +// const at::Tensor &cu_seqlens_q, // b+1 +// const at::Tensor &cu_seqlens_k, // b+1 +// std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. +// std::optional &leftpad_k_, // batch_size +// std::optional &block_table_, // batch_size x max_num_blocks_per_seq +// int max_seqlen_q, +// const int max_seqlen_k, +// const float softmax_scale, +// const bool zero_tensors, +// bool is_causal, +// const float softcap, +// const bool return_softmax +// ) { +// // Otherwise the kernel will be launched from cuda:0 device +// at::cuda::CUDAGuard device_guard{q.device()}; +// auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); +// bool is_sm8x_min = cc_major >= 8; +// TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); + +// auto q_dtype = q.dtype(); +// TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); +// TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); +// TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); +// TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); +// TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); +// TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); +// TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + +// CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); +// CHECK_DEVICE(cu_seqlens_q); +// CHECK_DEVICE(cu_seqlens_k); + +// at::Tensor block_table; +// // const bool paged_KV = block_table_.has_value(); +// const bool paged_KV = false; // TODO: Temporarily disable Paged KV, because some bugs are still being fixed. +// if (paged_KV) { +// block_table = block_table_.value(); +// CHECK_DEVICE(block_table); +// TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); +// TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); +// } + +// TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// CHECK_CONTIGUOUS(cu_seqlens_q); +// CHECK_CONTIGUOUS(cu_seqlens_k); + +// const auto sizes = q.sizes(); + +// const int batch_size = cu_seqlens_q.numel() - 1; +// int num_heads = sizes[1]; +// const int head_size = sizes[2]; +// const int num_heads_k = paged_KV ? k.size(2) : k.size(1); + +// const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); +// const int num_blocks = !paged_KV ? 0 : k.size(0); +// const int page_block_size = !paged_KV ? 1 : k.size(1); +// TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); + +// if (max_seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case + +// void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); + +// // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case +// // H/t Daniel Haziza +// const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; +// const int ngroups = num_heads / num_heads_k; +// if (seqlenq_ngroups_swapped) { +// q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); +// max_seqlen_q = ngroups; +// num_heads = num_heads_k; +// cu_seqlens_q_d = nullptr; +// } + +// const int total_q = q.sizes()[0]; + +// TORCH_CHECK(batch_size > 0, "batch size must be positive"); +// TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); +// TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); +// TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + +// CHECK_SHAPE(q, total_q, num_heads, head_size); +// if (!paged_KV) { +// const int total_k = k.size(0); +// CHECK_SHAPE(k, total_k, num_heads_k, head_size); +// CHECK_SHAPE(v, total_k, num_heads_k, head_size); +// CHECK_SHAPE(mask, total_q, num_heads_k, max_seqlen_k); +// CHECK_SHAPE(bias, total_q, num_heads_k, max_seqlen_k); +// } else { +// CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size); +// CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size); +// CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); +// } + +// CHECK_SHAPE(cu_seqlens_q, batch_size + 1); +// CHECK_SHAPE(cu_seqlens_k, batch_size + 1); +// if (seqused_k.has_value()){ +// auto seqused_k_ = seqused_k.value(); +// TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); +// TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); +// TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); +// CHECK_SHAPE(seqused_k_, batch_size); +// } + +// at::Tensor out; +// if (out_.has_value()) { +// out = out_.value(); +// TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); +// CHECK_DEVICE(out); +// TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); +// CHECK_SHAPE(out, sizes[0], sizes[1], head_size); +// if (seqlenq_ngroups_swapped) { +// out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); +// } +// } else { +// out = torch::empty_like(q); +// } + +// auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; +// const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); +// const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); +// const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + +// auto opts = q.options(); +// auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); +// at::Tensor p; + +// if (return_softmax) { +// p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); +// } else { +// p = torch::empty({ 0 }, opts); +// } + +// if (zero_tensors) { +// out.zero_(); +// softmax_lse.fill_(-std::numeric_limits::infinity()); +// if (return_softmax) { p.zero_(); } +// } + +// Flash_fwd_params params; +// set_params_fprop( +// params, +// batch_size, +// max_seqlen_q, max_seqlen_k, +// seqlen_q_rounded, seqlen_k_rounded, +// num_heads, num_heads_k, +// head_size, head_size_rounded, +// q, k, v, mask, bias, out, +// cu_seqlens_q_d, +// cu_seqlens_k.data_ptr(), +// seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, +// return_softmax ? p.data_ptr() : nullptr, +// softmax_lse.data_ptr(), +// softmax_scale, +// is_causal, +// softcap, +// seqlenq_ngroups_swapped, +// /*unpadded_lse*/true +// ); +// params.total_q = total_q; + +// if (paged_KV) { +// params.block_table = block_table.data_ptr(); +// params.block_table_batch_stride = block_table.stride(0); +// params.k_batch_stride = k.stride(0); +// params.v_batch_stride = v.stride(0); +// } +// params.page_block_size = page_block_size; +// // Keep references to these tensors to extend their lifetime +// at::Tensor softmax_lse_accum, out_accum; +// if (seqlenq_ngroups_swapped) { +// // Only apply split-k for decoding +// std::tie(softmax_lse_accum, out_accum) = +// set_params_splitkv( +// params, batch_size, num_heads, head_size, +// max_seqlen_k, max_seqlen_q, head_size_rounded, +// /*num_splits*/ 0, get_num_sm(get_current_device()), opts +// ); +// } + +// if (leftpad_k_.has_value()) { +// auto leftpad_k = leftpad_k_.value(); +// TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); +// TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); +// CHECK_DEVICE(leftpad_k); +// CHECK_CONTIGUOUS(leftpad_k); +// CHECK_SHAPE(leftpad_k, batch_size); +// params.leftpad_k = static_cast(leftpad_k.data_ptr()); +// } + +// if (max_seqlen_k > 0) { +// auto stream = at::cuda::getCurrentCUDAStream().stream(); +// run_mha_fwd(params, stream, paged_KV); +// } else { +// // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. +// out.zero_(); +// softmax_lse.fill_(std::numeric_limits::infinity()); +// } + +// if (seqlenq_ngroups_swapped) { +// int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size}; +// int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size}; +// out = out.reshape(size_before).transpose(1, 2).reshape(size_after); +// q = q.reshape(size_before).transpose(1, 2).reshape(size_after); +// softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size}); +// } + +// return {out, softmax_lse, p}; +// } + void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { FP16_SWITCH(!params.is_bf16, [&] { HEADDIM_SWITCH(params.d, [&] { @@ -718,8 +765,8 @@ mha_bwd( const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k - const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k + const at::Tensor &mask, // batch_size x num_heads x seqlen_q x seqlen_k or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k + const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x num_heads x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &softmax_lse, // b x h x seqlen_q std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size @@ -774,10 +821,14 @@ mha_bwd( const int head_size = sizes[3]; const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); + const int num_heads_mask = mask.size(1); + const int num_heads_bias = bias.size(1); TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, "Number of heads in mask must be 1, h_k or h"); + TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, "Number of heads in bias must be 1, h_k or h"); auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); @@ -787,8 +838,20 @@ mha_bwd( CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k); - CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k); + if (num_heads_mask == 1) { + CHECK_SHAPE(mask, batch_size, 1, seqlen_q, seqlen_k); + } else if (num_heads_mask == num_heads_k) { + CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k); + } else { + CHECK_SHAPE(mask, batch_size, num_heads, seqlen_q, seqlen_k); + } + if (num_heads_bias == 1) { + CHECK_SHAPE(bias, batch_size, 1, seqlen_q, seqlen_k); + } else if (num_heads_bias == num_heads_k) { + CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k); + } else { + CHECK_SHAPE(bias, batch_size, num_heads, seqlen_q, seqlen_k); + } CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size); @@ -825,9 +888,21 @@ mha_bwd( TORCH_CHECK(dbias.dtype() == q_dtype, "dbias must have the same dtype as q"); CHECK_DEVICE(dbias); TORCH_CHECK(dbias.stride(-1) == 1, "dbias must have contiguous last dimension"); - CHECK_SHAPE(dbias, batch_size, num_heads_k, seqlen_q, seqlen_k); + if (num_heads_bias == 1) { + CHECK_SHAPE(dbias, batch_size, 1, seqlen_q, seqlen_k); + } else if (num_heads_bias == num_heads_k) { + CHECK_SHAPE(dbias, batch_size, num_heads_k, seqlen_q, seqlen_k); + } else { + CHECK_SHAPE(dbias, batch_size, num_heads, seqlen_q, seqlen_k); + } } else { - dbias = torch::empty({batch_size, num_heads_k, seqlen_q, seqlen_k}, opts); + if (num_heads_bias == 1) { + dbias = torch::empty({batch_size, 1, seqlen_q, seqlen_k}, opts); + } else if (num_heads_bias == num_heads_k) { + dbias = torch::empty({batch_size, num_heads_k, seqlen_q, seqlen_k}, opts); + } else { + dbias = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts); + } } // bool loop = seqlen_k > blocksize_c; @@ -852,10 +927,13 @@ mha_bwd( if (num_heads_k != num_heads) { // MQA / GQA dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - dbias_expanded = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts); } else { dk_expanded = dk; dv_expanded = dv; + } + if (num_heads_bias != num_heads) { + dbias_expanded = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts); + } else { dbias_expanded = dbias; } @@ -866,7 +944,7 @@ mha_bwd( batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, + num_heads, num_heads_k, num_heads_mask, num_heads_bias, head_size, head_size_rounded, q, k, v, mask, bias, out, dout, dq, dk_expanded, dv_expanded, dbias_expanded, @@ -903,233 +981,237 @@ mha_bwd( if (num_heads_k != num_heads) { at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); - at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_k, num_heads / num_heads_k, seqlen_q, seqlen_k}), {2}); + } + // For MQA/GQA or num_heads_bias != num_heads, we also need to sum dbias across the heads + if (num_heads_bias != num_heads) { + at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2}); } return { dq, dk, dv, dbias, softmax_d }; } -std::vector -mha_varlen_bwd( - const at::Tensor &dout, // total_q x num_heads, x head_size - const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &mask, // total_q x num_heads_k x max_seqlen_k - const at::Tensor &bias, // total_q x num_heads_k x max_seqlen_k - const at::Tensor &out, // total_q x num_heads x head_size - const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp - std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - std::optional &dbias_, // total_q x num_heads_k x max_seqlen_k - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - const int max_seqlen_q, - const int max_seqlen_k, // max sequence length to choose the kernel - const float softmax_scale, - const bool zero_tensors, - const bool is_causal, - const float softcap, - const bool deterministic -) { - - #ifdef FLASHATTENTION_DISABLE_BACKWARD - TORCH_CHECK(false, "This flash dynamic mask attention build does not support backward."); - #endif - - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; - - auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); - bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); - TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); - TORCH_CHECK(bias.dtype() == q_dtype, "query and bias must have the same dtype"); - TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); - TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); - - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); - CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); - CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); - - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); - TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); - CHECK_CONTIGUOUS(cu_seqlens_q); - CHECK_CONTIGUOUS(cu_seqlens_k); - - const auto sizes = q.sizes(); - auto opts = q.options(); - - const int total_q = sizes[0]; - const int batch_size = cu_seqlens_q.numel() - 1; - const int num_heads = sizes[1]; - const int head_size = sizes[2]; - const int total_k = k.size(0); - const int num_heads_k = k.size(1); - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); - TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); - const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); - - CHECK_SHAPE(q, total_q, num_heads, head_size); - CHECK_SHAPE(k, total_k, num_heads_k, head_size); - CHECK_SHAPE(v, total_k, num_heads_k, head_size); - CHECK_SHAPE(mask, total_q, num_heads_k, max_seqlen_k); - CHECK_SHAPE(bias, total_q, num_heads_k, max_seqlen_k); - CHECK_SHAPE(out, total_q, num_heads, head_size); - CHECK_SHAPE(dout, total_q, num_heads, head_size); - CHECK_SHAPE(cu_seqlens_q, batch_size + 1); - CHECK_SHAPE(cu_seqlens_k, batch_size + 1); - - at::Tensor dq, dk, dv, dbias; - if (dq_.has_value()) { - dq = dq_.value(); - TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); - CHECK_DEVICE(dq); - TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); - CHECK_SHAPE(dq, total_q, num_heads, head_size); - } else { - dq = torch::empty_like(q); - } - if (dk_.has_value()) { - dk = dk_.value(); - TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); - CHECK_DEVICE(dk); - TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); - CHECK_SHAPE(dk, total_k, num_heads_k, head_size); - } else { - dk = torch::empty_like(k); - } - if (dv_.has_value()) { - dv = dv_.value(); - TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); - CHECK_DEVICE(dv); - TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); - CHECK_SHAPE(dv, total_k, num_heads_k, head_size); - } else { - dv = torch::empty_like(v); - } - if (dbias_.has_value()) { - dbias = dbias_.value(); - TORCH_CHECK(dbias.dtype() == q_dtype, "dbias must have the same dtype as q"); - CHECK_DEVICE(dbias); - TORCH_CHECK(dbias.stride(-1) == 1, "dbias must have contiguous last dimension"); - CHECK_SHAPE(dbias, total_q, num_heads_k, max_seqlen_k); - } else { - dbias = torch::empty({total_q, num_heads_k, max_seqlen_k}, opts); - } - - // bool loop = max_seqlen_k > blocksize_c; - // TODO: change later, for now set to true for simplicity - bool loop = true; +// TODO: At present, we don't have a good strategy to handle the mask and bias of the varlen variant. +// std::vector +// mha_varlen_bwd( +// const at::Tensor &dout, // total_q x num_heads, x head_size +// const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i +// const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i +// const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i +// const at::Tensor &mask, // total_q x num_heads_k x max_seqlen_k +// const at::Tensor &bias, // total_q x num_heads_k x max_seqlen_k +// const at::Tensor &out, // total_q x num_heads x head_size +// const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp +// std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i +// std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i +// std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i +// std::optional &dbias_, // total_q x num_heads_k x max_seqlen_k +// const at::Tensor &cu_seqlens_q, // b+1 +// const at::Tensor &cu_seqlens_k, // b+1 +// const int max_seqlen_q, +// const int max_seqlen_k, // max sequence length to choose the kernel +// const float softmax_scale, +// const bool zero_tensors, +// const bool is_causal, +// const float softcap, +// const bool deterministic +// ) { + +// #ifdef FLASHATTENTION_DISABLE_BACKWARD +// TORCH_CHECK(false, "This flash dynamic mask attention build does not support backward."); +// #endif + +// // Otherwise the kernel will be launched from cuda:0 device +// at::cuda::CUDAGuard device_guard{q.device()}; + +// auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); +// bool is_sm8x_min = cc_major >= 8; +// TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); + +// auto stream = at::cuda::getCurrentCUDAStream().stream(); + +// auto q_dtype = q.dtype(); +// TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); +// TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); +// TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); +// TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); +// TORCH_CHECK(bias.dtype() == q_dtype, "query and bias must have the same dtype"); +// TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); +// TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); +// TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); +// TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + +// CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); +// CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); +// CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); + +// TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); +// TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); +// CHECK_CONTIGUOUS(cu_seqlens_q); +// CHECK_CONTIGUOUS(cu_seqlens_k); + +// const auto sizes = q.sizes(); +// auto opts = q.options(); + +// const int total_q = sizes[0]; +// const int batch_size = cu_seqlens_q.numel() - 1; +// const int num_heads = sizes[1]; +// const int head_size = sizes[2]; +// const int total_k = k.size(0); +// const int num_heads_k = k.size(1); +// TORCH_CHECK(batch_size > 0, "batch size must be positive"); +// TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); +// TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256"); +// TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + +// auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; +// const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); +// const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); +// const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + +// CHECK_SHAPE(q, total_q, num_heads, head_size); +// CHECK_SHAPE(k, total_k, num_heads_k, head_size); +// CHECK_SHAPE(v, total_k, num_heads_k, head_size); +// CHECK_SHAPE(mask, total_q, num_heads_k, max_seqlen_k); +// CHECK_SHAPE(bias, total_q, num_heads_k, max_seqlen_k); +// CHECK_SHAPE(out, total_q, num_heads, head_size); +// CHECK_SHAPE(dout, total_q, num_heads, head_size); +// CHECK_SHAPE(cu_seqlens_q, batch_size + 1); +// CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + +// at::Tensor dq, dk, dv, dbias; +// if (dq_.has_value()) { +// dq = dq_.value(); +// TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); +// CHECK_DEVICE(dq); +// TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); +// CHECK_SHAPE(dq, total_q, num_heads, head_size); +// } else { +// dq = torch::empty_like(q); +// } +// if (dk_.has_value()) { +// dk = dk_.value(); +// TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); +// CHECK_DEVICE(dk); +// TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); +// CHECK_SHAPE(dk, total_k, num_heads_k, head_size); +// } else { +// dk = torch::empty_like(k); +// } +// if (dv_.has_value()) { +// dv = dv_.value(); +// TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); +// CHECK_DEVICE(dv); +// TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); +// CHECK_SHAPE(dv, total_k, num_heads_k, head_size); +// } else { +// dv = torch::empty_like(v); +// } +// if (dbias_.has_value()) { +// dbias = dbias_.value(); +// TORCH_CHECK(dbias.dtype() == q_dtype, "dbias must have the same dtype as q"); +// CHECK_DEVICE(dbias); +// TORCH_CHECK(dbias.stride(-1) == 1, "dbias must have contiguous last dimension"); +// CHECK_SHAPE(dbias, total_q, num_heads_k, max_seqlen_k); +// } else { +// dbias = torch::empty({total_q, num_heads_k, max_seqlen_k}, opts); +// } + +// // bool loop = max_seqlen_k > blocksize_c; +// // TODO: change later, for now set to true for simplicity +// bool loop = true; - auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat)); - at::Tensor dq_accum; - if (loop) { - // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded) - // because that would be too large if there is a very long sequence and the rest of the sequences are short. - // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded). - // Note that 128 is the max block size on the seqlen_q dimension. - // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to - // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will - // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally - // allowed to do. So we won't have to do any bound checking, and performance should stay the same. - // Same holds for softmax_d, since LSE is stored in unpadded format. - if (!deterministic) { - dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); - } else { - const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads); - dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); - } - } - - at::Tensor dk_expanded, dv_expanded, dbias_expanded; - if (num_heads_k != num_heads) { // MQA / GQA - dk_expanded = torch::empty({total_k, num_heads, head_size}, opts); - dv_expanded = torch::empty({total_k, num_heads, head_size}, opts); - dbias_expanded = torch::empty({total_q, num_heads, max_seqlen_k}, opts); - } else { - dk_expanded = dk; - dv_expanded = dv; - dbias_expanded = dbias; - } - - if( zero_tensors ) { - dq.zero_(); - dk_expanded.zero_(); - dv_expanded.zero_(); - dbias_expanded.zero_(); - softmax_d.zero_(); - } - - Flash_bwd_params params; - - set_params_dgrad( - params, - batch_size, - max_seqlen_q, max_seqlen_k, - seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, - head_size, head_size_rounded, - q, k, v, mask, bias, out, - dout, dq, dk_expanded, dv_expanded, dbias_expanded, - cu_seqlens_q.data_ptr(), - cu_seqlens_k.data_ptr(), - loop ? dq_accum.data_ptr() : nullptr, - nullptr, - nullptr, - softmax_lse.data_ptr(), - softmax_d.data_ptr(), - softmax_scale, - is_causal, - softcap, - deterministic, - /*unpadded_lse*/true - ); - params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0); - params.total_q = total_q; - - auto launch = &run_mha_bwd; - - if (max_seqlen_q > 0) { - launch(params, stream); - } else { - // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. - dk_expanded.zero_(); - dv_expanded.zero_(); - dbias_expanded.zero_(); - softmax_d.zero_(); - } - - // For MQA/GQA we need to sum dK and dV across the groups - if (num_heads_k != num_heads) { - at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); - at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); - at::sum_out(dbias, at::reshape(dbias_expanded, {total_q, num_heads_k, num_heads / num_heads_k, max_seqlen_k}), {2}); - } - - return { dq, dk, dv, dbias, softmax_d }; -} +// auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat)); +// at::Tensor dq_accum; +// if (loop) { +// // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded) +// // because that would be too large if there is a very long sequence and the rest of the sequences are short. +// // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded). +// // Note that 128 is the max block size on the seqlen_q dimension. +// // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to +// // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will +// // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally +// // allowed to do. So we won't have to do any bound checking, and performance should stay the same. +// // Same holds for softmax_d, since LSE is stored in unpadded format. +// if (!deterministic) { +// dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); +// } else { +// const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads); +// dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); +// } +// } + +// at::Tensor dk_expanded, dv_expanded, dbias_expanded; +// if (num_heads_k != num_heads) { // MQA / GQA +// dk_expanded = torch::empty({total_k, num_heads, head_size}, opts); +// dv_expanded = torch::empty({total_k, num_heads, head_size}, opts); +// dbias_expanded = torch::empty({total_q, num_heads, max_seqlen_k}, opts); +// } else { +// dk_expanded = dk; +// dv_expanded = dv; +// dbias_expanded = dbias; +// } + +// if( zero_tensors ) { +// dq.zero_(); +// dk_expanded.zero_(); +// dv_expanded.zero_(); +// dbias_expanded.zero_(); +// softmax_d.zero_(); +// } + +// Flash_bwd_params params; + +// set_params_dgrad( +// params, +// batch_size, +// max_seqlen_q, max_seqlen_k, +// seqlen_q_rounded, seqlen_k_rounded, +// num_heads, num_heads_k, +// head_size, head_size_rounded, +// q, k, v, mask, bias, out, +// dout, dq, dk_expanded, dv_expanded, dbias_expanded, +// cu_seqlens_q.data_ptr(), +// cu_seqlens_k.data_ptr(), +// loop ? dq_accum.data_ptr() : nullptr, +// nullptr, +// nullptr, +// softmax_lse.data_ptr(), +// softmax_d.data_ptr(), +// softmax_scale, +// is_causal, +// softcap, +// deterministic, +// /*unpadded_lse*/true +// ); +// params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0); +// params.total_q = total_q; + +// auto launch = &run_mha_bwd; + +// if (max_seqlen_q > 0) { +// launch(params, stream); +// } else { +// // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. +// dk_expanded.zero_(); +// dv_expanded.zero_(); +// dbias_expanded.zero_(); +// softmax_d.zero_(); +// } + +// // For MQA/GQA we need to sum dK and dV across the groups +// if (num_heads_k != num_heads) { +// at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); +// at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); +// at::sum_out(dbias, at::reshape(dbias_expanded, {total_q, num_heads_k, num_heads / num_heads_k, max_seqlen_k}), {2}); +// } + +// return { dq, dk, dv, dbias, softmax_d }; +// } } // namespace FLASH_NAMESPACE From 105026125ae73da34a394683dd1cf4ce447e50f8 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 13 Sep 2025 18:38:21 +0800 Subject: [PATCH 5/7] Fix formatting in dynamic_mask_attention_python function --- benchmarks/backward_equivalence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py index 024674f..804433f 100644 --- a/benchmarks/backward_equivalence.py +++ b/benchmarks/backward_equivalence.py @@ -191,7 +191,7 @@ def dynamic_mask_attention_python( value_states = repeat_kv(value_states, num_queries_per_kv) attn_mask = repeat_kv(attn_mask, num_queries_per_kv) attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv) - + attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) attn_weights = attn_weights * scaling + attn_bias # Apply scaling and zoh attn_weights = F.softmax(attn_weights, dim=-1) # Softmax normalization From 9cf0f040305e7d867ede4105010549738d90c21e Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 13 Sep 2025 18:38:48 +0800 Subject: [PATCH 6/7] Updates attention mask and bias documentation for MQA/GQA Clarifies that attention mask and bias parameters support multiple tensor shapes to accommodate Multi-Query Attention (MQA) and Grouped Query Attention (GQA) patterns, in addition to the standard multi-head attention format. Adds explicit documentation for supported shapes including broadcast-compatible dimensions for flexible attention implementations. --- flash_dmattn/flash_dmattn_interface.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index c4ac38d..733b6e1 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -361,10 +361,14 @@ def flash_dmattn_func( key: torch.Tensor. The key tensor of shape (batch_size, seqlen, nheads_k, headdim) value: torch.Tensor. The value tensor of shape (batch_size, seqlen, nheads_k, headdim) attn_mask: torch.Tensor, optional. The attention mask boolean tensor of - shape (batch_size, nheads_k, seqlen_q, seqlen_k) to apply to the attention scores. + shape (batch_size, nheads, seqlen_q, seqlen_k) to apply to the attention scores. + Also supports shape (batch_size, nheads_k, seqlen_q, seqlen_k) or + (batch_size, 1, seqlen_q, seqlen_k) for MQA/GQA. If None, no mask is applied. attn_bias: torch.Tensor, optional. The attention bias float tensor of - shape (batch_size, nheads_k, seqlen_q, seqlen_k) to add to the attention scores. + shape (batch_size, nheads, seqlen_q, seqlen_k) to add to the attention scores. + Also supports shape (batch_size, nheads_k, seqlen_q, seqlen_k) or + (batch_size, 1, seqlen_q, seqlen_k) for MQA/GQA. If None, no bias is applied. is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). scale: float. The scaling of QK^T before applying softmax. From ccfd3ec81c715b291eea34d33ce4feb9e88b89c8 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 13 Sep 2025 18:39:08 +0800 Subject: [PATCH 7/7] Updates attention tensor shape documentation for MQA/GQA Clarifies that attention mask and bias tensors support multiple shape formats to accommodate Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) patterns in addition to the standard multi-head attention format. Adds explicit documentation for supported shapes: standard num_heads format, num_kv_heads format, and broadcast-compatible single head format. --- flash_dmattn/integrations/flash_dynamic_mask_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_dmattn/integrations/flash_dynamic_mask_attention.py b/flash_dmattn/integrations/flash_dynamic_mask_attention.py index 7d718e2..898950b 100644 --- a/flash_dmattn/integrations/flash_dynamic_mask_attention.py +++ b/flash_dmattn/integrations/flash_dynamic_mask_attention.py @@ -29,8 +29,8 @@ def flash_dynamic_mask_attention_forward( query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim). key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim). value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim). - attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, num_kv_heads, query_len, key_len). - attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, num_kv_heads, query_len, key_len). + attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, num_heads, query_len, key_len), also supports (batch_size, num_kv_heads, query_len, key_len) or (batch_size, 1, query_len, key_len) for MQA/GQA. + attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, num_heads, query_len, key_len), also supports (batch_size, num_kv_heads, query_len, key_len) or (batch_size, 1, query_len, key_len) for MQA/GQA. scaling (Optional[float]): The scaling factor for the attention scores. softcap (Optional[float]): The softcap value for the attention scores. **kwargs: Additional keyword arguments.