diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py index 024674f..804433f 100644 --- a/benchmarks/backward_equivalence.py +++ b/benchmarks/backward_equivalence.py @@ -191,7 +191,7 @@ def dynamic_mask_attention_python( value_states = repeat_kv(value_states, num_queries_per_kv) attn_mask = repeat_kv(attn_mask, num_queries_per_kv) attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv) - + attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) attn_weights = attn_weights * scaling + attn_bias # Apply scaling and zoh attn_weights = F.softmax(attn_weights, dim=-1) # Softmax normalization diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index b4a657e..4108cb4 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -31,6 +31,8 @@ void set_params_fprop( const size_t seqlen_k_rounded, const size_t h, const size_t h_k, + const size_t h_mask, + const size_t h_bias, const size_t d, const size_t d_rounded, // device pointers @@ -107,6 +109,8 @@ void set_params_fprop( params.h = h; params.h_k = h_k; params.h_h_k_ratio = h / h_k; + params.h_mask = h_mask; + params.h_bias = h_bias; params.seqlen_q = seqlen_q; params.seqlen_k = seqlen_k; params.seqlen_q_rounded = seqlen_q_rounded; @@ -150,6 +154,8 @@ void set_params_dgrad( const size_t seqlen_k_rounded, const size_t h, const size_t h_k, + const size_t h_mask, + const size_t h_bias, const size_t d, const size_t d_rounded, // device pointers @@ -179,7 +185,7 @@ void set_params_dgrad( ) { set_params_fprop( params, - b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, + b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, h_mask, h_bias, d, d_rounded, q, k, v, mask, bias, out, cu_seqlens_q_d, cu_seqlens_k_d, @@ -341,8 +347,8 @@ mha_fwd( at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k - const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k + at::Tensor &mask, // batch_size x num_heads x seqlen_q x seqlen_k or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k + at::Tensor &bias, // batch_size x num_heads x seqlen_q x seqlen_k, or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const float softmax_scale, bool is_causal, @@ -380,10 +386,14 @@ mha_fwd( const int head_size = sizes[3]; const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); + int num_heads_mask = mask.size(1); + int num_heads_bias = bias.size(1); TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, "Number of heads in mask must be 1, h_k or h"); + TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == 
num_heads, "Number of heads in bias must be 1, h_k or h"); // causal=true is the same as causal=false in this case if (seqlen_q == 1) { is_causal = false; } @@ -392,12 +402,26 @@ mha_fwd( // H/t Daniel Haziza const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; const int ngroups = num_heads / num_heads_k; - at::Tensor mask_view = mask; - at::Tensor bias_view = bias; + const int orig_num_heads_mask = num_heads_mask; + const int orig_num_heads_bias = num_heads_bias; if (seqlenq_ngroups_swapped) { q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); - mask_view = mask.expand({batch_size, num_heads_k, ngroups, seqlen_k}); - bias_view = bias.expand({batch_size, num_heads_k, ngroups, seqlen_k}); + if (num_heads_mask == 1) { + mask = mask.expand({batch_size, 1, ngroups, seqlen_k}); + } else if (num_heads_mask == num_heads_k) { + mask = mask.expand({batch_size, num_heads_k, ngroups, seqlen_k}); + } else { // num_heads_mask == num_heads + mask = mask.reshape({batch_size, num_heads_k, ngroups, seqlen_k}); + } + if (num_heads_bias == 1) { + bias = bias.expand({batch_size, 1, ngroups, seqlen_k}); + } else if (num_heads_bias == num_heads_k) { + bias = bias.expand({batch_size, num_heads_k, ngroups, seqlen_k}); + } else { // num_heads_bias == num_heads + bias = bias.reshape({batch_size, num_heads_k, ngroups, seqlen_k}); + } + num_heads_mask = (num_heads_mask == num_heads) ? num_heads_k : num_heads_mask; + num_heads_bias = (num_heads_bias == num_heads) ? num_heads_k : num_heads_bias; seqlen_q = ngroups; num_heads = num_heads_k; } @@ -405,8 +429,20 @@ mha_fwd( CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(mask_view, batch_size, num_heads_k, seqlen_q, seqlen_k); - CHECK_SHAPE(bias_view, batch_size, num_heads_k, seqlen_q, seqlen_k); + if (num_heads_mask == 1) { + CHECK_SHAPE(mask, batch_size, 1, seqlen_q, seqlen_k); + } else if (num_heads_mask == num_heads_k) { + CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k); + } else { + CHECK_SHAPE(mask, batch_size, num_heads, seqlen_q, seqlen_k); + } + if (num_heads_bias == 1) { + CHECK_SHAPE(bias, batch_size, 1, seqlen_q, seqlen_k); + } else if (num_heads_bias == num_heads_k) { + CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k); + } else { + CHECK_SHAPE(bias, batch_size, num_heads, seqlen_q, seqlen_k); + } at::Tensor out; if (out_.has_value()) { @@ -444,9 +480,9 @@ mha_fwd( batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, + num_heads, num_heads_k, num_heads_mask, num_heads_bias, head_size, head_size_rounded, - q, k, v, mask_view, bias_view, out, + q, k, v, mask, bias, out, /*cu_seqlens_q_d=*/nullptr, /*cu_seqlens_k_d=*/nullptr, /*seqused_k=*/nullptr, @@ -477,231 +513,242 @@ mha_fwd( out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); - } - return {out, softmax_lse, p}; -} - -std::vector -mha_varlen_fwd( - at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. 
- const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - const at::Tensor &mask, // total_q x num_heads_k x max_seqlen_k - const at::Tensor &bias, // total_q x num_heads_k x max_seqlen_k - std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. - std::optional &leftpad_k_, // batch_size - std::optional &block_table_, // batch_size x max_num_blocks_per_seq - int max_seqlen_q, - const int max_seqlen_k, - const float softmax_scale, - const bool zero_tensors, - bool is_causal, - const float softcap, - const bool return_softmax -) { - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; - auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); - bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); - - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); - TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); - TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); - - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); - CHECK_DEVICE(cu_seqlens_q); - CHECK_DEVICE(cu_seqlens_k); - - at::Tensor block_table; - // const bool paged_KV = block_table_.has_value(); - const bool paged_KV = false; // TODO: Temporarily disable Paged KV, because some bugs are still being fixed. - if (paged_KV) { - block_table = block_table_.value(); - CHECK_DEVICE(block_table); - TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); - TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); - } - - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - CHECK_CONTIGUOUS(cu_seqlens_q); - CHECK_CONTIGUOUS(cu_seqlens_k); - - const auto sizes = q.sizes(); - - const int batch_size = cu_seqlens_q.numel() - 1; - int num_heads = sizes[1]; - const int head_size = sizes[2]; - const int num_heads_k = paged_KV ? k.size(2) : k.size(1); - - const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); - const int num_blocks = !paged_KV ? 0 : k.size(0); - const int page_block_size = !paged_KV ? 
1 : k.size(1); - TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); - - if (max_seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case - - void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); - - // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case - // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; - const int ngroups = num_heads / num_heads_k; - if (seqlenq_ngroups_swapped) { - q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); - max_seqlen_q = ngroups; - num_heads = num_heads_k; - cu_seqlens_q_d = nullptr; - } - - const int total_q = q.sizes()[0]; - - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); - TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - CHECK_SHAPE(q, total_q, num_heads, head_size); - if (!paged_KV) { - const int total_k = k.size(0); - CHECK_SHAPE(k, total_k, num_heads_k, head_size); - CHECK_SHAPE(v, total_k, num_heads_k, head_size); - CHECK_SHAPE(mask, total_q, num_heads_k, max_seqlen_k); - CHECK_SHAPE(bias, total_q, num_heads_k, max_seqlen_k); - } else { - CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size); - CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size); - CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); - } - - CHECK_SHAPE(cu_seqlens_q, batch_size + 1); - CHECK_SHAPE(cu_seqlens_k, batch_size + 1); - if (seqused_k.has_value()){ - auto seqused_k_ = seqused_k.value(); - TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); - TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); - TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); - CHECK_SHAPE(seqused_k_, batch_size); - } - - at::Tensor out; - if (out_.has_value()) { - out = out_.value(); - TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); - CHECK_DEVICE(out); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, sizes[0], sizes[1], head_size); - if (seqlenq_ngroups_swapped) { - out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); + if (orig_num_heads_mask == 1 || orig_num_heads_mask == num_heads_k) { + mask = mask.narrow(2, 0, 1); + } else { // orig_num_heads_mask == num_heads + mask = mask.reshape({batch_size, num_heads_k * seqlen_q, 1, seqlen_k}); + } + if (orig_num_heads_bias == 1 || orig_num_heads_bias == num_heads_k) { + bias = bias.narrow(2, 0, 1); + } else { // orig_num_heads_bias == num_heads + bias = bias.reshape({batch_size, num_heads_k * seqlen_q, 1, seqlen_k}); } - } else { - out = torch::empty_like(q); - } - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 
32 : 64); - const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); - - auto opts = q.options(); - auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); - at::Tensor p; - - if (return_softmax) { - p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); - } else { - p = torch::empty({ 0 }, opts); - } - - if (zero_tensors) { - out.zero_(); - softmax_lse.fill_(-std::numeric_limits::infinity()); - if (return_softmax) { p.zero_(); } - } - - Flash_fwd_params params; - set_params_fprop( - params, - batch_size, - max_seqlen_q, max_seqlen_k, - seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, - head_size, head_size_rounded, - q, k, v, mask, bias, out, - cu_seqlens_q_d, - cu_seqlens_k.data_ptr(), - seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, - return_softmax ? p.data_ptr() : nullptr, - softmax_lse.data_ptr(), - softmax_scale, - is_causal, - softcap, - seqlenq_ngroups_swapped, - /*unpadded_lse*/true - ); - params.total_q = total_q; - - if (paged_KV) { - params.block_table = block_table.data_ptr(); - params.block_table_batch_stride = block_table.stride(0); - params.k_batch_stride = k.stride(0); - params.v_batch_stride = v.stride(0); - } - params.page_block_size = page_block_size; - // Keep references to these tensors to extend their lifetime - at::Tensor softmax_lse_accum, out_accum; - if (seqlenq_ngroups_swapped) { - // Only apply split-k for decoding - std::tie(softmax_lse_accum, out_accum) = - set_params_splitkv( - params, batch_size, num_heads, head_size, - max_seqlen_k, max_seqlen_q, head_size_rounded, - /*num_splits*/ 0, get_num_sm(get_current_device()), opts - ); - } - - if (leftpad_k_.has_value()) { - auto leftpad_k = leftpad_k_.value(); - TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); - TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); - CHECK_DEVICE(leftpad_k); - CHECK_CONTIGUOUS(leftpad_k); - CHECK_SHAPE(leftpad_k, batch_size); - params.leftpad_k = static_cast(leftpad_k.data_ptr()); - } - - if (max_seqlen_k > 0) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); - run_mha_fwd(params, stream, paged_KV); - } else { - // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. - out.zero_(); - softmax_lse.fill_(std::numeric_limits::infinity()); - } - - if (seqlenq_ngroups_swapped) { - int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size}; - int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size}; - out = out.reshape(size_before).transpose(1, 2).reshape(size_after); - q = q.reshape(size_before).transpose(1, 2).reshape(size_after); - softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size}); } - return {out, softmax_lse, p}; } +// TODO: At present, we don't have a good strategy to handle the mask and bias of the varlen variant. +// std::vector +// mha_varlen_fwd( +// at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i +// const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. +// const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. 
+// const at::Tensor &mask, // total_q x num_heads_k x max_seqlen_k +// const at::Tensor &bias, // total_q x num_heads_k x max_seqlen_k +// std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i +// const at::Tensor &cu_seqlens_q, // b+1 +// const at::Tensor &cu_seqlens_k, // b+1 +// std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. +// std::optional &leftpad_k_, // batch_size +// std::optional &block_table_, // batch_size x max_num_blocks_per_seq +// int max_seqlen_q, +// const int max_seqlen_k, +// const float softmax_scale, +// const bool zero_tensors, +// bool is_causal, +// const float softcap, +// const bool return_softmax +// ) { +// // Otherwise the kernel will be launched from cuda:0 device +// at::cuda::CUDAGuard device_guard{q.device()}; +// auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); +// bool is_sm8x_min = cc_major >= 8; +// TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); + +// auto q_dtype = q.dtype(); +// TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); +// TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); +// TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); +// TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); +// TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); +// TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); +// TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + +// CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); +// CHECK_DEVICE(cu_seqlens_q); +// CHECK_DEVICE(cu_seqlens_k); + +// at::Tensor block_table; +// // const bool paged_KV = block_table_.has_value(); +// const bool paged_KV = false; // TODO: Temporarily disable Paged KV, because some bugs are still being fixed. +// if (paged_KV) { +// block_table = block_table_.value(); +// CHECK_DEVICE(block_table); +// TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); +// TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); +// } + +// TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// CHECK_CONTIGUOUS(cu_seqlens_q); +// CHECK_CONTIGUOUS(cu_seqlens_k); + +// const auto sizes = q.sizes(); + +// const int batch_size = cu_seqlens_q.numel() - 1; +// int num_heads = sizes[1]; +// const int head_size = sizes[2]; +// const int num_heads_k = paged_KV ? k.size(2) : k.size(1); + +// const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); +// const int num_blocks = !paged_KV ? 0 : k.size(0); +// const int page_block_size = !paged_KV ? 
1 : k.size(1); +// TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); + +// if (max_seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case + +// void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); + +// // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case +// // H/t Daniel Haziza +// const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; +// const int ngroups = num_heads / num_heads_k; +// if (seqlenq_ngroups_swapped) { +// q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); +// max_seqlen_q = ngroups; +// num_heads = num_heads_k; +// cu_seqlens_q_d = nullptr; +// } + +// const int total_q = q.sizes()[0]; + +// TORCH_CHECK(batch_size > 0, "batch size must be positive"); +// TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); +// TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); +// TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + +// CHECK_SHAPE(q, total_q, num_heads, head_size); +// if (!paged_KV) { +// const int total_k = k.size(0); +// CHECK_SHAPE(k, total_k, num_heads_k, head_size); +// CHECK_SHAPE(v, total_k, num_heads_k, head_size); +// CHECK_SHAPE(mask, total_q, num_heads_k, max_seqlen_k); +// CHECK_SHAPE(bias, total_q, num_heads_k, max_seqlen_k); +// } else { +// CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size); +// CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size); +// CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); +// } + +// CHECK_SHAPE(cu_seqlens_q, batch_size + 1); +// CHECK_SHAPE(cu_seqlens_k, batch_size + 1); +// if (seqused_k.has_value()){ +// auto seqused_k_ = seqused_k.value(); +// TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); +// TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); +// TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); +// CHECK_SHAPE(seqused_k_, batch_size); +// } + +// at::Tensor out; +// if (out_.has_value()) { +// out = out_.value(); +// TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); +// CHECK_DEVICE(out); +// TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); +// CHECK_SHAPE(out, sizes[0], sizes[1], head_size); +// if (seqlenq_ngroups_swapped) { +// out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); +// } +// } else { +// out = torch::empty_like(q); +// } + +// auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; +// const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 
32 : 64); +// const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); +// const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + +// auto opts = q.options(); +// auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); +// at::Tensor p; + +// if (return_softmax) { +// p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); +// } else { +// p = torch::empty({ 0 }, opts); +// } + +// if (zero_tensors) { +// out.zero_(); +// softmax_lse.fill_(-std::numeric_limits::infinity()); +// if (return_softmax) { p.zero_(); } +// } + +// Flash_fwd_params params; +// set_params_fprop( +// params, +// batch_size, +// max_seqlen_q, max_seqlen_k, +// seqlen_q_rounded, seqlen_k_rounded, +// num_heads, num_heads_k, +// head_size, head_size_rounded, +// q, k, v, mask, bias, out, +// cu_seqlens_q_d, +// cu_seqlens_k.data_ptr(), +// seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, +// return_softmax ? p.data_ptr() : nullptr, +// softmax_lse.data_ptr(), +// softmax_scale, +// is_causal, +// softcap, +// seqlenq_ngroups_swapped, +// /*unpadded_lse*/true +// ); +// params.total_q = total_q; + +// if (paged_KV) { +// params.block_table = block_table.data_ptr(); +// params.block_table_batch_stride = block_table.stride(0); +// params.k_batch_stride = k.stride(0); +// params.v_batch_stride = v.stride(0); +// } +// params.page_block_size = page_block_size; +// // Keep references to these tensors to extend their lifetime +// at::Tensor softmax_lse_accum, out_accum; +// if (seqlenq_ngroups_swapped) { +// // Only apply split-k for decoding +// std::tie(softmax_lse_accum, out_accum) = +// set_params_splitkv( +// params, batch_size, num_heads, head_size, +// max_seqlen_k, max_seqlen_q, head_size_rounded, +// /*num_splits*/ 0, get_num_sm(get_current_device()), opts +// ); +// } + +// if (leftpad_k_.has_value()) { +// auto leftpad_k = leftpad_k_.value(); +// TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); +// TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); +// CHECK_DEVICE(leftpad_k); +// CHECK_CONTIGUOUS(leftpad_k); +// CHECK_SHAPE(leftpad_k, batch_size); +// params.leftpad_k = static_cast(leftpad_k.data_ptr()); +// } + +// if (max_seqlen_k > 0) { +// auto stream = at::cuda::getCurrentCUDAStream().stream(); +// run_mha_fwd(params, stream, paged_KV); +// } else { +// // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. 
+// out.zero_(); +// softmax_lse.fill_(std::numeric_limits::infinity()); +// } + +// if (seqlenq_ngroups_swapped) { +// int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size}; +// int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size}; +// out = out.reshape(size_before).transpose(1, 2).reshape(size_after); +// q = q.reshape(size_before).transpose(1, 2).reshape(size_after); +// softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size}); +// } + +// return {out, softmax_lse, p}; +// } + void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { FP16_SWITCH(!params.is_bf16, [&] { HEADDIM_SWITCH(params.d, [&] { @@ -718,8 +765,8 @@ mha_bwd( const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k - const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k + const at::Tensor &mask, // batch_size x num_heads x seqlen_q x seqlen_k or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k + const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x num_heads x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &softmax_lse, // b x h x seqlen_q std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size @@ -774,10 +821,14 @@ mha_bwd( const int head_size = sizes[3]; const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); + const int num_heads_mask = mask.size(1); + const int num_heads_bias = bias.size(1); TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, "Number of heads in mask must be 1, h_k or h"); + TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, "Number of heads in bias must be 1, h_k or h"); auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 
32 : 64); @@ -787,8 +838,20 @@ mha_bwd( CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k); - CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k); + if (num_heads_mask == 1) { + CHECK_SHAPE(mask, batch_size, 1, seqlen_q, seqlen_k); + } else if (num_heads_mask == num_heads_k) { + CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k); + } else { + CHECK_SHAPE(mask, batch_size, num_heads, seqlen_q, seqlen_k); + } + if (num_heads_bias == 1) { + CHECK_SHAPE(bias, batch_size, 1, seqlen_q, seqlen_k); + } else if (num_heads_bias == num_heads_k) { + CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k); + } else { + CHECK_SHAPE(bias, batch_size, num_heads, seqlen_q, seqlen_k); + } CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size); @@ -825,9 +888,21 @@ mha_bwd( TORCH_CHECK(dbias.dtype() == q_dtype, "dbias must have the same dtype as q"); CHECK_DEVICE(dbias); TORCH_CHECK(dbias.stride(-1) == 1, "dbias must have contiguous last dimension"); - CHECK_SHAPE(dbias, batch_size, num_heads_k, seqlen_q, seqlen_k); + if (num_heads_bias == 1) { + CHECK_SHAPE(dbias, batch_size, 1, seqlen_q, seqlen_k); + } else if (num_heads_bias == num_heads_k) { + CHECK_SHAPE(dbias, batch_size, num_heads_k, seqlen_q, seqlen_k); + } else { + CHECK_SHAPE(dbias, batch_size, num_heads, seqlen_q, seqlen_k); + } } else { - dbias = torch::empty({batch_size, num_heads_k, seqlen_q, seqlen_k}, opts); + if (num_heads_bias == 1) { + dbias = torch::empty({batch_size, 1, seqlen_q, seqlen_k}, opts); + } else if (num_heads_bias == num_heads_k) { + dbias = torch::empty({batch_size, num_heads_k, seqlen_q, seqlen_k}, opts); + } else { + dbias = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts); + } } // bool loop = seqlen_k > blocksize_c; @@ -852,10 +927,13 @@ mha_bwd( if (num_heads_k != num_heads) { // MQA / GQA dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - dbias_expanded = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts); } else { dk_expanded = dk; dv_expanded = dv; + } + if (num_heads_bias != num_heads) { + dbias_expanded = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts); + } else { dbias_expanded = dbias; } @@ -866,7 +944,7 @@ mha_bwd( batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, + num_heads, num_heads_k, num_heads_mask, num_heads_bias, head_size, head_size_rounded, q, k, v, mask, bias, out, dout, dq, dk_expanded, dv_expanded, dbias_expanded, @@ -903,233 +981,237 @@ mha_bwd( if (num_heads_k != num_heads) { at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); - at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_k, num_heads / num_heads_k, seqlen_q, seqlen_k}), {2}); + } + // For MQA/GQA or num_heads_bias != num_heads, we also need to sum dbias across the heads + if (num_heads_bias != num_heads) { + at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2}); } return { 
dq, dk, dv, dbias, softmax_d }; } -std::vector -mha_varlen_bwd( - const at::Tensor &dout, // total_q x num_heads, x head_size - const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &mask, // total_q x num_heads_k x max_seqlen_k - const at::Tensor &bias, // total_q x num_heads_k x max_seqlen_k - const at::Tensor &out, // total_q x num_heads x head_size - const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp - std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - std::optional &dbias_, // total_q x num_heads_k x max_seqlen_k - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - const int max_seqlen_q, - const int max_seqlen_k, // max sequence length to choose the kernel - const float softmax_scale, - const bool zero_tensors, - const bool is_causal, - const float softcap, - const bool deterministic -) { - - #ifdef FLASHATTENTION_DISABLE_BACKWARD - TORCH_CHECK(false, "This flash dynamic mask attention build does not support backward."); - #endif - - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; - - auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); - bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); - TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); - TORCH_CHECK(bias.dtype() == q_dtype, "query and bias must have the same dtype"); - TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); - TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); - - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); - CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); - CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); - - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); - TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); - CHECK_CONTIGUOUS(cu_seqlens_q); - 
CHECK_CONTIGUOUS(cu_seqlens_k); - - const auto sizes = q.sizes(); - auto opts = q.options(); - - const int total_q = sizes[0]; - const int batch_size = cu_seqlens_q.numel() - 1; - const int num_heads = sizes[1]; - const int head_size = sizes[2]; - const int total_k = k.size(0); - const int num_heads_k = k.size(1); - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); - TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); - const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); - - CHECK_SHAPE(q, total_q, num_heads, head_size); - CHECK_SHAPE(k, total_k, num_heads_k, head_size); - CHECK_SHAPE(v, total_k, num_heads_k, head_size); - CHECK_SHAPE(mask, total_q, num_heads_k, max_seqlen_k); - CHECK_SHAPE(bias, total_q, num_heads_k, max_seqlen_k); - CHECK_SHAPE(out, total_q, num_heads, head_size); - CHECK_SHAPE(dout, total_q, num_heads, head_size); - CHECK_SHAPE(cu_seqlens_q, batch_size + 1); - CHECK_SHAPE(cu_seqlens_k, batch_size + 1); - - at::Tensor dq, dk, dv, dbias; - if (dq_.has_value()) { - dq = dq_.value(); - TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); - CHECK_DEVICE(dq); - TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); - CHECK_SHAPE(dq, total_q, num_heads, head_size); - } else { - dq = torch::empty_like(q); - } - if (dk_.has_value()) { - dk = dk_.value(); - TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); - CHECK_DEVICE(dk); - TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); - CHECK_SHAPE(dk, total_k, num_heads_k, head_size); - } else { - dk = torch::empty_like(k); - } - if (dv_.has_value()) { - dv = dv_.value(); - TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); - CHECK_DEVICE(dv); - TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); - CHECK_SHAPE(dv, total_k, num_heads_k, head_size); - } else { - dv = torch::empty_like(v); - } - if (dbias_.has_value()) { - dbias = dbias_.value(); - TORCH_CHECK(dbias.dtype() == q_dtype, "dbias must have the same dtype as q"); - CHECK_DEVICE(dbias); - TORCH_CHECK(dbias.stride(-1) == 1, "dbias must have contiguous last dimension"); - CHECK_SHAPE(dbias, total_q, num_heads_k, max_seqlen_k); - } else { - dbias = torch::empty({total_q, num_heads_k, max_seqlen_k}, opts); - } - - // bool loop = max_seqlen_k > blocksize_c; - // TODO: change later, for now set to true for simplicity - bool loop = true; +// TODO: At present, we don't have a good strategy to handle the mask and bias of the varlen variant. 
+// std::vector +// mha_varlen_bwd( +// const at::Tensor &dout, // total_q x num_heads, x head_size +// const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i +// const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i +// const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i +// const at::Tensor &mask, // total_q x num_heads_k x max_seqlen_k +// const at::Tensor &bias, // total_q x num_heads_k x max_seqlen_k +// const at::Tensor &out, // total_q x num_heads x head_size +// const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp +// std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i +// std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i +// std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i +// std::optional &dbias_, // total_q x num_heads_k x max_seqlen_k +// const at::Tensor &cu_seqlens_q, // b+1 +// const at::Tensor &cu_seqlens_k, // b+1 +// const int max_seqlen_q, +// const int max_seqlen_k, // max sequence length to choose the kernel +// const float softmax_scale, +// const bool zero_tensors, +// const bool is_causal, +// const float softcap, +// const bool deterministic +// ) { + +// #ifdef FLASHATTENTION_DISABLE_BACKWARD +// TORCH_CHECK(false, "This flash dynamic mask attention build does not support backward."); +// #endif + +// // Otherwise the kernel will be launched from cuda:0 device +// at::cuda::CUDAGuard device_guard{q.device()}; + +// auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); +// bool is_sm8x_min = cc_major >= 8; +// TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); + +// auto stream = at::cuda::getCurrentCUDAStream().stream(); + +// auto q_dtype = q.dtype(); +// TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); +// TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); +// TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); +// TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); +// TORCH_CHECK(bias.dtype() == q_dtype, "query and bias must have the same dtype"); +// TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); +// TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); +// TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); +// TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + +// CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); +// CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); +// CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); + +// TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); +// TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); +// TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous 
last dimension"); +// CHECK_CONTIGUOUS(cu_seqlens_q); +// CHECK_CONTIGUOUS(cu_seqlens_k); + +// const auto sizes = q.sizes(); +// auto opts = q.options(); + +// const int total_q = sizes[0]; +// const int batch_size = cu_seqlens_q.numel() - 1; +// const int num_heads = sizes[1]; +// const int head_size = sizes[2]; +// const int total_k = k.size(0); +// const int num_heads_k = k.size(1); +// TORCH_CHECK(batch_size > 0, "batch size must be positive"); +// TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); +// TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256"); +// TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + +// auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; +// const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); +// const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); +// const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + +// CHECK_SHAPE(q, total_q, num_heads, head_size); +// CHECK_SHAPE(k, total_k, num_heads_k, head_size); +// CHECK_SHAPE(v, total_k, num_heads_k, head_size); +// CHECK_SHAPE(mask, total_q, num_heads_k, max_seqlen_k); +// CHECK_SHAPE(bias, total_q, num_heads_k, max_seqlen_k); +// CHECK_SHAPE(out, total_q, num_heads, head_size); +// CHECK_SHAPE(dout, total_q, num_heads, head_size); +// CHECK_SHAPE(cu_seqlens_q, batch_size + 1); +// CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + +// at::Tensor dq, dk, dv, dbias; +// if (dq_.has_value()) { +// dq = dq_.value(); +// TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); +// CHECK_DEVICE(dq); +// TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); +// CHECK_SHAPE(dq, total_q, num_heads, head_size); +// } else { +// dq = torch::empty_like(q); +// } +// if (dk_.has_value()) { +// dk = dk_.value(); +// TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); +// CHECK_DEVICE(dk); +// TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); +// CHECK_SHAPE(dk, total_k, num_heads_k, head_size); +// } else { +// dk = torch::empty_like(k); +// } +// if (dv_.has_value()) { +// dv = dv_.value(); +// TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); +// CHECK_DEVICE(dv); +// TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); +// CHECK_SHAPE(dv, total_k, num_heads_k, head_size); +// } else { +// dv = torch::empty_like(v); +// } +// if (dbias_.has_value()) { +// dbias = dbias_.value(); +// TORCH_CHECK(dbias.dtype() == q_dtype, "dbias must have the same dtype as q"); +// CHECK_DEVICE(dbias); +// TORCH_CHECK(dbias.stride(-1) == 1, "dbias must have contiguous last dimension"); +// CHECK_SHAPE(dbias, total_q, num_heads_k, max_seqlen_k); +// } else { +// dbias = torch::empty({total_q, num_heads_k, max_seqlen_k}, opts); +// } + +// // bool loop = max_seqlen_k > blocksize_c; +// // TODO: change later, for now set to true for simplicity +// bool loop = true; - auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat)); - at::Tensor dq_accum; - if (loop) { - // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded) - // because that would be too large if there is a very long sequence and the rest of the sequences are short. - // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded). 
- // Note that 128 is the max block size on the seqlen_q dimension. - // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to - // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will - // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally - // allowed to do. So we won't have to do any bound checking, and performance should stay the same. - // Same holds for softmax_d, since LSE is stored in unpadded format. - if (!deterministic) { - dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); - } else { - const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads); - dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); - } - } - - at::Tensor dk_expanded, dv_expanded, dbias_expanded; - if (num_heads_k != num_heads) { // MQA / GQA - dk_expanded = torch::empty({total_k, num_heads, head_size}, opts); - dv_expanded = torch::empty({total_k, num_heads, head_size}, opts); - dbias_expanded = torch::empty({total_q, num_heads, max_seqlen_k}, opts); - } else { - dk_expanded = dk; - dv_expanded = dv; - dbias_expanded = dbias; - } - - if( zero_tensors ) { - dq.zero_(); - dk_expanded.zero_(); - dv_expanded.zero_(); - dbias_expanded.zero_(); - softmax_d.zero_(); - } - - Flash_bwd_params params; - - set_params_dgrad( - params, - batch_size, - max_seqlen_q, max_seqlen_k, - seqlen_q_rounded, seqlen_k_rounded, - num_heads, num_heads_k, - head_size, head_size_rounded, - q, k, v, mask, bias, out, - dout, dq, dk_expanded, dv_expanded, dbias_expanded, - cu_seqlens_q.data_ptr(), - cu_seqlens_k.data_ptr(), - loop ? dq_accum.data_ptr() : nullptr, - nullptr, - nullptr, - softmax_lse.data_ptr(), - softmax_d.data_ptr(), - softmax_scale, - is_causal, - softcap, - deterministic, - /*unpadded_lse*/true - ); - params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0); - params.total_q = total_q; - - auto launch = &run_mha_bwd; - - if (max_seqlen_q > 0) { - launch(params, stream); - } else { - // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. - dk_expanded.zero_(); - dv_expanded.zero_(); - dbias_expanded.zero_(); - softmax_d.zero_(); - } - - // For MQA/GQA we need to sum dK and dV across the groups - if (num_heads_k != num_heads) { - at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); - at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); - at::sum_out(dbias, at::reshape(dbias_expanded, {total_q, num_heads_k, num_heads / num_heads_k, max_seqlen_k}), {2}); - } - - return { dq, dk, dv, dbias, softmax_d }; -} +// auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat)); +// at::Tensor dq_accum; +// if (loop) { +// // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded) +// // because that would be too large if there is a very long sequence and the rest of the sequences are short. +// // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded). +// // Note that 128 is the max block size on the seqlen_q dimension. +// // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to +// // cu_seqlens[i + 1] * 128 * i - 1. 
This ensures that the i-th sequence and (i + 1)-th sequence will +// // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally +// // allowed to do. So we won't have to do any bound checking, and performance should stay the same. +// // Same holds for softmax_d, since LSE is stored in unpadded format. +// if (!deterministic) { +// dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); +// } else { +// const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads); +// dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); +// } +// } + +// at::Tensor dk_expanded, dv_expanded, dbias_expanded; +// if (num_heads_k != num_heads) { // MQA / GQA +// dk_expanded = torch::empty({total_k, num_heads, head_size}, opts); +// dv_expanded = torch::empty({total_k, num_heads, head_size}, opts); +// dbias_expanded = torch::empty({total_q, num_heads, max_seqlen_k}, opts); +// } else { +// dk_expanded = dk; +// dv_expanded = dv; +// dbias_expanded = dbias; +// } + +// if( zero_tensors ) { +// dq.zero_(); +// dk_expanded.zero_(); +// dv_expanded.zero_(); +// dbias_expanded.zero_(); +// softmax_d.zero_(); +// } + +// Flash_bwd_params params; + +// set_params_dgrad( +// params, +// batch_size, +// max_seqlen_q, max_seqlen_k, +// seqlen_q_rounded, seqlen_k_rounded, +// num_heads, num_heads_k, +// head_size, head_size_rounded, +// q, k, v, mask, bias, out, +// dout, dq, dk_expanded, dv_expanded, dbias_expanded, +// cu_seqlens_q.data_ptr(), +// cu_seqlens_k.data_ptr(), +// loop ? dq_accum.data_ptr() : nullptr, +// nullptr, +// nullptr, +// softmax_lse.data_ptr(), +// softmax_d.data_ptr(), +// softmax_scale, +// is_causal, +// softcap, +// deterministic, +// /*unpadded_lse*/true +// ); +// params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0); +// params.total_q = total_q; + +// auto launch = &run_mha_bwd; + +// if (max_seqlen_q > 0) { +// launch(params, stream); +// } else { +// // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. +// dk_expanded.zero_(); +// dv_expanded.zero_(); +// dbias_expanded.zero_(); +// softmax_d.zero_(); +// } + +// // For MQA/GQA we need to sum dK and dV across the groups +// if (num_heads_k != num_heads) { +// at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); +// at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); +// at::sum_out(dbias, at::reshape(dbias_expanded, {total_q, num_heads_k, num_heads / num_heads_k, max_seqlen_k}), {2}); +// } + +// return { dq, dk, dv, dbias, softmax_d }; +// } } // namespace FLASH_NAMESPACE diff --git a/csrc/flash_dmattn/src/flash.h b/csrc/flash_dmattn/src/flash.h index c1cb7f4..a533cc3 100644 --- a/csrc/flash_dmattn/src/flash.h +++ b/csrc/flash_dmattn/src/flash.h @@ -50,6 +50,9 @@ struct Mask_params { index_t mask_batch_stride; // Stride between batches of attention mask index_t mask_head_stride; // Stride between heads of attention mask index_t mask_row_stride; // Stride between rows of attention mask + + // The number of heads in the mask. 
+ int h_mask; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -61,6 +64,9 @@ struct Bias_params { index_t bias_batch_stride; // Stride between batches of attention bias index_t bias_head_stride; // Stride between heads of attention bias index_t bias_row_stride; // Stride between rows of attention bias + + // The number of heads in the bias. + int h_bias; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h index 723265b..2ba7d2a 100644 --- a/csrc/flash_dmattn/src/flash_bwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_bwd_kernel.h @@ -107,10 +107,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in + n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); const index_t row_offset_mask = binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb) - + (bidh / params.h_h_k_ratio) * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN; + + h_idx_mask * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN; + const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); const index_t row_offset_bias = binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb) - + (bidh / params.h_h_k_ratio) * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN; + + h_idx_bias * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN; const index_t row_offset_dbias = binfo.bias_offset(params.dbias_batch_stride, params.dbias_row_stride, bidb) + bidh * params.dbias_head_stride + (m_block_max - 1) * kBlockM * params.dbias_row_stride + n_block * kBlockN; const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index 79dcbfb..f79e477 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -136,6 +136,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // might save us 1 register (we just need n_block instead of both n_block and n_block_max). const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); + const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? 
(bidh / params.h_h_k_ratio) : bidh); // Global memory tensor configuration Tensor mQ = make_tensor( @@ -170,21 +172,21 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi ); // (kBlockN, kHeadDim, nblocksN) Tensor mMask = make_tensor( make_gmem_ptr(reinterpret_cast(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)), - make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k), + make_shape(params.h_mask, binfo.actual_seqlen_q, binfo.actual_seqlen_k), make_stride(params.mask_head_stride, params.mask_row_stride, _1{}) ); Tensor gMask = local_tile( - mMask(bidh / params.h_h_k_ratio, _, _), + mMask(h_idx_mask, _, _), Shape, Int>{}, make_coord(m_block, _) ); // (kBlockM, kBlockN, nblocksN) Tensor mBias = make_tensor( make_gmem_ptr(reinterpret_cast(params.bias_ptr) + binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb)), - make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k), + make_shape(params.h_bias, binfo.actual_seqlen_q, binfo.actual_seqlen_k), make_stride(params.bias_head_stride, params.bias_row_stride, _1{}) ); Tensor gBias = local_tile( - mBias(bidh / params.h_h_k_ratio, _, _), + mBias(h_idx_bias, _, _), Shape, Int>{}, make_coord(m_block, _) ); // (kBlockM, kBlockN, nblocksN) @@ -840,16 +842,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); const index_t col_offset_mask = (block_table == nullptr) ? binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb_cache) - + (bidh / params.h_h_k_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1) * kBlockN + + h_idx_mask * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1) * kBlockN : binfo.q_offset(/*batch_stride=*/index_t(0), params.mask_row_stride, bidb_cache) - + (bidh / params.h_h_k_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset; + + h_idx_mask * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset; + const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); const index_t col_offset_bias = (block_table == nullptr) ? 
binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb_cache) - + (bidh / params.h_h_k_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1) * kBlockN + + h_idx_bias * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1) * kBlockN : binfo.q_offset(/*batch_stride=*/index_t(0), params.bias_row_stride, bidb_cache) - + (bidh / params.h_h_k_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset; + + h_idx_bias * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset; // Global memory tensor configuration Tensor mQ = make_tensor( diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index c4ac38d..733b6e1 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -361,10 +361,14 @@ def flash_dmattn_func( key: torch.Tensor. The key tensor of shape (batch_size, seqlen, nheads_k, headdim) value: torch.Tensor. The value tensor of shape (batch_size, seqlen, nheads_k, headdim) attn_mask: torch.Tensor, optional. The attention mask boolean tensor of - shape (batch_size, nheads_k, seqlen_q, seqlen_k) to apply to the attention scores. + shape (batch_size, nheads, seqlen_q, seqlen_k) to apply to the attention scores. + Also supports shape (batch_size, nheads_k, seqlen_q, seqlen_k) or + (batch_size, 1, seqlen_q, seqlen_k) for MQA/GQA. If None, no mask is applied. attn_bias: torch.Tensor, optional. The attention bias float tensor of - shape (batch_size, nheads_k, seqlen_q, seqlen_k) to add to the attention scores. + shape (batch_size, nheads, seqlen_q, seqlen_k) to add to the attention scores. + Also supports shape (batch_size, nheads_k, seqlen_q, seqlen_k) or + (batch_size, 1, seqlen_q, seqlen_k) for MQA/GQA. If None, no bias is applied. is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). scale: float. The scaling of QK^T before applying softmax. diff --git a/flash_dmattn/integrations/flash_dynamic_mask_attention.py b/flash_dmattn/integrations/flash_dynamic_mask_attention.py index 7d718e2..898950b 100644 --- a/flash_dmattn/integrations/flash_dynamic_mask_attention.py +++ b/flash_dmattn/integrations/flash_dynamic_mask_attention.py @@ -29,8 +29,8 @@ def flash_dynamic_mask_attention_forward( query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim). key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim). value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim). - attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, num_kv_heads, query_len, key_len). - attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, num_kv_heads, query_len, key_len). + attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, num_heads, query_len, key_len), also supports (batch_size, num_kv_heads, query_len, key_len) or (batch_size, 1, query_len, key_len) for MQA/GQA. + attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, num_heads, query_len, key_len), also supports (batch_size, num_kv_heads, query_len, key_len) or (batch_size, 1, query_len, key_len) for MQA/GQA. 
 scaling (Optional[float]): The scaling factor for the attention scores.
 softcap (Optional[float]): The softcap value for the attention scores.
 **kwargs: Additional keyword arguments.
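For reviewers, a minimal usage sketch of the broadened mask/bias head semantics introduced by this patch (mask/bias heads may now be 1, num_heads_k, or num_heads). This is not part of the diff: the import path and keyword names (`flash_dmattn_func`, `attn_mask`, `attn_bias`, `is_causal`, `scale`) are assumed from the docstrings in `flash_dmattn_interface.py`, and the shapes follow the new head-broadcast checks added in `mha_fwd`.

```python
# Usage sketch only -- assumes flash_dmattn_func is exported at the package root
# and accepts the keyword arguments documented in flash_dmattn_interface.py.
import torch
from flash_dmattn import flash_dmattn_func  # assumed export path

batch, seqlen_q, seqlen_k = 2, 128, 128
nheads, nheads_k, headdim = 8, 2, 64          # GQA: nheads % nheads_k == 0
device, dtype = "cuda", torch.bfloat16        # fp16/bf16 are the supported dtypes

q = torch.randn(batch, seqlen_q, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen_k, nheads_k, headdim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen_k, nheads_k, headdim, device=device, dtype=dtype)

# With this patch, the mask/bias head dimension may be 1, nheads_k, or nheads;
# here the mask is broadcast over all heads and the bias is per KV head.
attn_mask = torch.ones(batch, 1, seqlen_q, seqlen_k, device=device, dtype=torch.bool)
attn_bias = torch.zeros(batch, nheads_k, seqlen_q, seqlen_k, device=device, dtype=dtype)

out = flash_dmattn_func(
    q, k, v,
    attn_mask=attn_mask,
    attn_bias=attn_bias,
    is_causal=True,
    scale=headdim ** -0.5,
)
print(out.shape)  # expected: (batch, seqlen_q, nheads, headdim)
```

The backward path mirrors this: `dbias` keeps the bias's own head count (1, h_k, or h), and is reduced over the query-head groups only when `num_heads_bias != num_heads`, per the `at::sum_out` change in `mha_bwd`.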