From 8a60204b7ee7874d18638581e7a7d8561f1ee8f8 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sun, 5 Oct 2025 14:28:52 +0800 Subject: [PATCH 01/15] Implement varlen MHA fwd with mask/bias support Adds a working variable-length attention forward path with boolean mask and additive bias, supporting broadcast across head dims. Enforces dtype/device/contiguity and shape checks, initializes outputs when requested, and handles empty key sequences. Keeps Paged KV disabled pending fixes, retains decoding optimization for single-token queries, and exposes optional softmax output for debugging/inspection. Improves usability on Ampere+ GPUs while maintaining correctness and constraints. --- csrc/flash_dmattn/flash_api.cpp | 401 ++++++++++++++++---------------- 1 file changed, 200 insertions(+), 201 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 215479e..33058cf 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -555,227 +555,226 @@ mha_fwd( return {out, softmax_lse, p}; } -// TODO: At present, we don't have a good strategy to handle the mask and bias of the varlen variant. -// std::vector -// mha_varlen_fwd( -// at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i -// const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. -// const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. -// const at::Tensor &mask, // total_q x num_heads_k x max_seqlen_k -// const at::Tensor &bias, // total_q x num_heads_k x max_seqlen_k -// std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i -// const at::Tensor &cu_seqlens_q, // b+1 -// const at::Tensor &cu_seqlens_k, // b+1 -// std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. -// std::optional &leftpad_k_, // batch_size -// std::optional &block_table_, // batch_size x max_num_blocks_per_seq -// int max_seqlen_q, -// const int max_seqlen_k, -// const float softmax_scale, -// const bool zero_tensors, -// bool is_causal, -// const float softcap, -// const bool return_softmax -// ) { -// // Otherwise the kernel will be launched from cuda:0 device -// at::cuda::CUDAGuard device_guard{q.device()}; -// auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); -// bool is_sm8x_min = cc_major >= 8; -// TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); +std::vector +mha_varlen_fwd( + at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &mask, // total_q x {1|num_heads_k|num_heads} x max_seqlen_k or total_k x {1|num_heads_k|num_heads} + const at::Tensor &bias, // total_q x {1|num_heads_k|num_heads} x max_seqlen_k or total_k x {1|num_heads_k|num_heads} + std::optional &out_, // total_q x num_heads x head_size + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. + std::optional &leftpad_k_, // batch_size + std::optional &block_table_, // batch_size x max_num_blocks_per_seq + int max_seqlen_q, + const int max_seqlen_k, + const float softmax_scale, + const bool zero_tensors, + bool is_causal, + const float softcap, + const bool return_softmax +) { + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + bool is_sm8x_min = cc_major >= 8; + TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); -// auto q_dtype = q.dtype(); -// TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); -// TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); -// TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); -// TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); -// TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); -// TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); -// TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); + TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); + CHECK_DEVICE(cu_seqlens_q); + CHECK_DEVICE(cu_seqlens_k); + + at::Tensor block_table; + // const bool paged_KV = block_table_.has_value(); + const bool paged_KV = false; // TODO: Temporarily disable Paged KV, because some bugs are still being fixed. + if (paged_KV) { + block_table = block_table_.value(); + CHECK_DEVICE(block_table); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + } -// CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); -// CHECK_DEVICE(cu_seqlens_q); -// CHECK_DEVICE(cu_seqlens_k); - -// at::Tensor block_table; -// // const bool paged_KV = block_table_.has_value(); -// const bool paged_KV = false; // TODO: Temporarily disable Paged KV, because some bugs are still being fixed. -// if (paged_KV) { -// block_table = block_table_.value(); -// CHECK_DEVICE(block_table); -// TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); -// TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); -// } + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); -// TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); -// CHECK_CONTIGUOUS(cu_seqlens_q); -// CHECK_CONTIGUOUS(cu_seqlens_k); + const auto sizes = q.sizes(); -// const auto sizes = q.sizes(); + const int batch_size = cu_seqlens_q.numel() - 1; + int num_heads = sizes[1]; + const int head_size = sizes[2]; + const int num_heads_k = paged_KV ? k.size(2) : k.size(1); -// const int batch_size = cu_seqlens_q.numel() - 1; -// int num_heads = sizes[1]; -// const int head_size = sizes[2]; -// const int num_heads_k = paged_KV ? k.size(2) : k.size(1); - -// const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); -// const int num_blocks = !paged_KV ? 0 : k.size(0); -// const int page_block_size = !paged_KV ? 1 : k.size(1); -// TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); - -// if (max_seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case - -// void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); - -// // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case -// // H/t Daniel Haziza -// const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; -// const int ngroups = num_heads / num_heads_k; -// if (seqlenq_ngroups_swapped) { -// q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); -// max_seqlen_q = ngroups; -// num_heads = num_heads_k; -// cu_seqlens_q_d = nullptr; -// } + const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); + const int num_blocks = !paged_KV ? 0 : k.size(0); + const int page_block_size = !paged_KV ? 1 : k.size(1); + TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); -// const int total_q = q.sizes()[0]; + if (max_seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case -// TORCH_CHECK(batch_size > 0, "batch size must be positive"); -// TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); -// TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); -// TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); -// CHECK_SHAPE(q, total_q, num_heads, head_size); -// if (!paged_KV) { -// const int total_k = k.size(0); -// CHECK_SHAPE(k, total_k, num_heads_k, head_size); -// CHECK_SHAPE(v, total_k, num_heads_k, head_size); -// CHECK_SHAPE(mask, total_q, num_heads_k, max_seqlen_k); -// CHECK_SHAPE(bias, total_q, num_heads_k, max_seqlen_k); -// } else { -// CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size); -// CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size); -// CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); -// } + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; + const int ngroups = num_heads / num_heads_k; + if (seqlenq_ngroups_swapped) { + q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); + max_seqlen_q = ngroups; + num_heads = num_heads_k; + cu_seqlens_q_d = nullptr; + } -// CHECK_SHAPE(cu_seqlens_q, batch_size + 1); -// CHECK_SHAPE(cu_seqlens_k, batch_size + 1); -// if (seqused_k.has_value()){ -// auto seqused_k_ = seqused_k.value(); -// TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); -// TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); -// TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); -// CHECK_SHAPE(seqused_k_, batch_size); -// } + const int total_q = q.sizes()[0]; -// at::Tensor out; -// if (out_.has_value()) { -// out = out_.value(); -// TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); -// CHECK_DEVICE(out); -// TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); -// CHECK_SHAPE(out, sizes[0], sizes[1], head_size); -// if (seqlenq_ngroups_swapped) { -// out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); -// } -// } else { -// out = torch::empty_like(q); -// } + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); + TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); -// auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; -// const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); -// const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); -// const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + CHECK_SHAPE(q, total_q, num_heads, head_size); + if (!paged_KV) { + const int total_k = k.size(0); + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size); + CHECK_SHAPE(mask, total_q, num_heads_k, max_seqlen_k); + CHECK_SHAPE(bias, total_q, num_heads_k, max_seqlen_k); + } else { + CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size); + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + } -// auto opts = q.options(); -// auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); -// at::Tensor p; + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + if (seqused_k.has_value()){ + auto seqused_k_ = seqused_k.value(); + TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); + TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); + TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); + CHECK_SHAPE(seqused_k_, batch_size); + } -// if (return_softmax) { -// p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); -// } else { -// p = torch::empty({ 0 }, opts); -// } + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, sizes[0], sizes[1], head_size); + if (seqlenq_ngroups_swapped) { + out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); + } + } else { + out = torch::empty_like(q); + } -// if (zero_tensors) { -// out.zero_(); -// softmax_lse.fill_(-std::numeric_limits::infinity()); -// if (return_softmax) { p.zero_(); } -// } + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); + const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); -// Flash_fwd_params params; -// set_params_fprop( -// params, -// batch_size, -// max_seqlen_q, max_seqlen_k, -// seqlen_q_rounded, seqlen_k_rounded, -// num_heads, num_heads_k, -// head_size, head_size_rounded, -// q, k, v, mask, bias, out, -// cu_seqlens_q_d, -// cu_seqlens_k.data_ptr(), -// seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, -// return_softmax ? p.data_ptr() : nullptr, -// softmax_lse.data_ptr(), -// softmax_scale, -// is_causal, -// softcap, -// seqlenq_ngroups_swapped, -// /*unpadded_lse*/true -// ); -// params.total_q = total_q; + auto opts = q.options(); + auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + at::Tensor p; -// if (paged_KV) { -// params.block_table = block_table.data_ptr(); -// params.block_table_batch_stride = block_table.stride(0); -// params.k_batch_stride = k.stride(0); -// params.v_batch_stride = v.stride(0); -// } -// params.page_block_size = page_block_size; -// // Keep references to these tensors to extend their lifetime -// at::Tensor softmax_lse_accum, out_accum; -// if (seqlenq_ngroups_swapped) { -// // Only apply split-k for decoding -// std::tie(softmax_lse_accum, out_accum) = -// set_params_splitkv( -// params, batch_size, num_heads, head_size, -// max_seqlen_k, max_seqlen_q, head_size_rounded, -// /*num_splits*/ 0, get_num_sm(get_current_device()), opts -// ); -// } + if (return_softmax) { + p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); + } else { + p = torch::empty({ 0 }, opts); + } -// if (leftpad_k_.has_value()) { -// auto leftpad_k = leftpad_k_.value(); -// TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); -// TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); -// CHECK_DEVICE(leftpad_k); -// CHECK_CONTIGUOUS(leftpad_k); -// CHECK_SHAPE(leftpad_k, batch_size); -// params.leftpad_k = static_cast(leftpad_k.data_ptr()); -// } + if (zero_tensors) { + out.zero_(); + softmax_lse.fill_(-std::numeric_limits::infinity()); + if (return_softmax) { p.zero_(); } + } -// if (max_seqlen_k > 0) { -// auto stream = at::cuda::getCurrentCUDAStream().stream(); -// run_mha_fwd(params, stream, paged_KV); -// } else { -// // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. -// out.zero_(); -// softmax_lse.fill_(std::numeric_limits::infinity()); -// } + Flash_fwd_params params; + set_params_fprop( + params, + batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, mask, bias, out, + cu_seqlens_q_d, + cu_seqlens_k.data_ptr(), + seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, + return_softmax ? p.data_ptr() : nullptr, + softmax_lse.data_ptr(), + softmax_scale, + is_causal, + softcap, + seqlenq_ngroups_swapped, + /*unpadded_lse*/true + ); + params.total_q = total_q; -// if (seqlenq_ngroups_swapped) { -// int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size}; -// int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size}; -// out = out.reshape(size_before).transpose(1, 2).reshape(size_after); -// q = q.reshape(size_before).transpose(1, 2).reshape(size_after); -// softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size}); -// } + if (paged_KV) { + params.block_table = block_table.data_ptr(); + params.block_table_batch_stride = block_table.stride(0); + params.k_batch_stride = k.stride(0); + params.v_batch_stride = v.stride(0); + } + params.page_block_size = page_block_size; + // Keep references to these tensors to extend their lifetime + at::Tensor softmax_lse_accum, out_accum; + if (seqlenq_ngroups_swapped) { + // Only apply split-k for decoding + std::tie(softmax_lse_accum, out_accum) = + set_params_splitkv( + params, batch_size, num_heads, head_size, + max_seqlen_k, max_seqlen_q, head_size_rounded, + /*num_splits*/ 0, get_num_sm(get_current_device()), opts + ); + } -// return {out, softmax_lse, p}; -// } + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); + CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + params.leftpad_k = static_cast(leftpad_k.data_ptr()); + } + + if (max_seqlen_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream, paged_KV); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + if (seqlenq_ngroups_swapped) { + int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size}; + int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size}; + out = out.reshape(size_before).transpose(1, 2).reshape(size_after); + q = q.reshape(size_before).transpose(1, 2).reshape(size_after); + softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size}); + } + + return {out, softmax_lse, p}; +} void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { FP16_SWITCH(!params.is_bf16, [&] { From 5076dbc5bdd167cbee6fb052dfa3ff49044377fe Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Mon, 6 Oct 2025 20:42:53 +0800 Subject: [PATCH 02/15] Zeroes NaN/Inf and makes bias grad contiguous MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enables sanitization in forward/backward to replace NaN and ±Inf with 0, improving numerical stability and preventing invalid value propagation. Stops treating bias gradient as an input and ensures the computed bias gradient is contiguous, avoiding layout issues in downstream kernels. --- flash_dmattn/flash_dmattn_interface.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 2c267d4..2db8ea4 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -95,7 +95,7 @@ def _flash_dmattn_forward( softcap, return_softmax, ) - # _sanitize_tensors(out, nan=0.0, posinf=torch.finfo(out.dtype).max, neginf=torch.finfo(out.dtype).min) + _sanitize_tensors(out, nan=0.0, posinf=0.0, neginf=0.0) return out, softmax_lse, S_dmask @@ -145,7 +145,7 @@ def _flash_dmattn_backward( softcap: float, deterministic: bool, ) -> torch.Tensor: - dout, dbias, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, dbias, q, k, v, mask, bias, out)] + dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)] ( dq, dk, @@ -170,7 +170,7 @@ def _flash_dmattn_backward( softcap, deterministic, ) - # _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=torch.finfo(dq.dtype).max, neginf=torch.finfo(dq.dtype).min) + _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=0.0, neginf=0.0) return softmax_d @@ -193,7 +193,7 @@ def _flash_dmattn_backward_fake( softcap: float, deterministic: bool, ) -> torch.Tensor: - dout, dbias, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, dbias, q, k, v, mask, bias, out)] + dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)] if dq is None: dq = torch.empty_like(q) if dk is None: @@ -288,7 +288,7 @@ def backward( ): q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v) - dbias = torch.zeros_like(bias) if bias is not None else None + dbias = torch.zeros_like(bias).contiguous() if bias is not None else None head_size_og = dout.size(3) dout_padded = dout From 20eb46ec67fff1588e4c40184a8decd42769328b Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Mon, 6 Oct 2025 20:44:29 +0800 Subject: [PATCH 03/15] Remove unnecessary blank lines in mha_fwd and mha_bwd functions --- csrc/flash_dmattn/flash_api.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 33058cf..8e68df2 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -361,10 +361,8 @@ mha_fwd( const float softcap, const bool return_softmax ) { - // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; - auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x_min = cc_major >= 8; TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); @@ -816,7 +814,6 @@ mha_bwd( // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; - auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x_min = cc_major >= 8; TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer."); From 641383a9fdeb42a3bfc33ce2a9090e17f4eaa9f9 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Mon, 6 Oct 2025 22:36:07 +0800 Subject: [PATCH 04/15] Require flash-dmattn; deprecate past_key_value Removes the eager attention fallback and unconditionally routes attention through the flash dynamic mask backend, printing an install hint when the dependency is missing. Deprecates the past_key_value argument in favor of past_key_values across attention, layer, and LM APIs, switches types to a cache interface, and initializes the dynamic cache with config for correctness. Disables SDPA/flex/attention-backend support flags to reflect the single supported backend. Adds a buffer annotation to satisfy linting, drops unused decoder accessors, ignores unused attention weights, and updates a paper link to the HF mirror. Improves consistency of the attention API and enforces a single, performant backend. --- examples/modeling/modeling_doge.py | 71 +++++++++--------------------- 1 file changed, 20 insertions(+), 51 deletions(-) diff --git a/examples/modeling/modeling_doge.py b/examples/modeling/modeling_doge.py index 377dfb2..dd58c80 100644 --- a/examples/modeling/modeling_doge.py +++ b/examples/modeling/modeling_doge.py @@ -40,13 +40,14 @@ from transformers.modeling_utils import AttentionInterface, PreTrainedModel from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available +from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.generic import OutputRecorder, check_model_inputs from .configuration_doge import DogeConfig try: from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward except ImportError: - flash_dynamic_mask_attention_forward = None + print("Please install flash_dmattn to use this model: pip install flash-dmattn") if is_torch_flex_attn_available(): from torch.nn.attention.flex_attention import BlockMask @@ -74,6 +75,8 @@ def extra_repr(self): class DogeRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + def __init__(self, config: DogeConfig, device=None): super().__init__() # BC: "rope_type" was originally "type" @@ -153,34 +156,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - attention_bias: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs: Unpack[TransformersKwargs], -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_bias is not None: - attn_weights = attn_weights + attention_bias - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - class DogeAttention(nn.Module): def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None): super().__init__() @@ -213,12 +188,13 @@ def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None): self.q_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -232,10 +208,10 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # sampling dt_states from value_states to generate attention bias dt_states = self.dt_proj( @@ -243,9 +219,7 @@ def forward( ) attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype) - attention_interface: Callable = eager_attention_forward - if flash_dynamic_mask_attention_forward is not None: - attention_interface = flash_dynamic_mask_attention_forward + attention_interface: Callable = flash_dynamic_mask_attention_forward attn_output, attn_weights = attention_interface( self, @@ -348,13 +322,14 @@ def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None): self.mlp = DogeMLP(config) if not config.is_moe else DogeCDMoE(config) self.post_attention_residual = nn.Parameter(torch.ones(config.hidden_size)) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], @@ -362,12 +337,12 @@ def forward( # sequence transformation residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states, self_attn_weights = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, @@ -395,10 +370,10 @@ class DogePreTrainedModel(PreTrainedModel): _no_split_modules = ["DogeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = False - _supports_sdpa = True - _supports_flex_attn = True + _supports_sdpa = False + _supports_flex_attn = False _can_compile_fullgraph = False - _supports_attention_backend = True + _supports_attention_backend = False _can_record_outputs = { "router_logits": OutputRecorder(DogeCDMoE, index=1), "hidden_states": DogeDecoderLayer, @@ -456,7 +431,7 @@ def forward( raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if use_cache and past_key_values is None: - past_key_values = DynamicCache() + past_key_values = DynamicCache(config=self.config) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -490,7 +465,7 @@ def forward( position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, @@ -514,7 +489,7 @@ def load_balancing_loss_func( r""" Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. - See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between experts is too unbalanced. @@ -628,12 +603,6 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - @can_return_tuple @auto_docstring def forward( @@ -641,7 +610,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, From b36949514db9f30c6b7968fe1f41ebec08462116 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Tue, 7 Oct 2025 20:36:21 +0800 Subject: [PATCH 05/15] Add Chinese API reference documentation for Flash Dynamic Mask Attention --- docs/api_reference_zh.md | 409 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 409 insertions(+) create mode 100644 docs/api_reference_zh.md diff --git a/docs/api_reference_zh.md b/docs/api_reference_zh.md new file mode 100644 index 0000000..da640bc --- /dev/null +++ b/docs/api_reference_zh.md @@ -0,0 +1,409 @@ +# Flash Dynamic Mask Attention API 参考文档 + + +## 概述 + +Flash Dynamic Mask Attention 是一个高性能注意力实现,结合了 Flash Attention 的内存效率和 Dynamic Mask Attention 的稀疏计算优势。它支持 CUDA、Triton 和 Flex Attention 后端,并支持超长序列的动态掩码。 + + +## 目录 + +1. [安装](#安装) +2. [快速开始](#快速开始) +3. [后端选择与比较](#后端选择与比较) +4. [接口函数详解](#接口函数详解) + - [CUDA 后端:flash_dmattn_func](#flash_dmattn_func-cuda-后端) + - [Triton 后端:triton_dmattn_func](#triton_dmattn_func-triton-后端) + - [Flex 后端:flex_dmattn_func](#flex_dmattn_func-flex-后端) +5. [集成](#集成) + - [Transformers 集成](#transformers-集成) +6. [常见问题与解决方案](#常见问题与解决方案) + + +## 安装 + +请参考 [README](https://github.com/SmallDoges/flash-dmattn/blob/main/README_zh.md#%E5%AE%89%E8%A3%85-1) 以获取详细的安装说明和依赖项。 + +```bash +# 使用 CUDA 后端 +pip install flash-dmattn + +# 或从源码安装 +pip install -e . + +# 仅使用 Triton/Flex 后端 +FLASH_DMATTN_SKIP_CUDA_BUILD=1 pip install -e . +``` + + +## 快速开始 + +使用 `flash_dmattn_func_auto` 可以自动选择最佳可用后端,无需手动判断。 + +```python +import torch +from flash_dmattn import flash_dmattn_func_auto + +# 准备输入张量 +batch, seqlen, num_heads, head_dim = 2, 1024, 8, 64 +q = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device='cuda') +k = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device='cuda') +v = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device='cuda') + +# 获取注意力函数(自动选择后端,优先级: cuda > triton > flex) +attn_func = flash_dmattn_func_auto() + +# 调用注意力计算 +output = attn_func(q, k, v, is_causal=True) +print(f"输出形状: {output.shape}") # (2, 1024, 8, 64) + +# 也可以强制使用特定后端 +attn_func = flash_dmattn_func_auto(backend="cuda") # 或 "triton", "flex" +output = attn_func(q, k, v, is_causal=True) +``` + +> [!NOTE] +> `flash_dmattn_func_auto` 返回一个可调用的注意力函数,而不是注意力输出。 + + +## 后端选择与比较 + +### 可用后端检查 + +```python +from flash_dmattn import get_available_backends, CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE + +# 查看所有可用后端 +print(get_available_backends()) # 例如:["cuda", "triton", "flex"] + +# 检查特定后端是否可用 +print(f"CUDA: {CUDA_AVAILABLE}, Triton: {TRITON_AVAILABLE}, Flex: {FLEX_AVAILABLE}") +``` + +### 后端特性对比 + +| 特性 | CUDA | Triton | Flex | +|------|------|--------|------| +| **性能** | 最高 | 良好 | 良好 | +| **内存效率** | 最佳 | 良好 | 良好 | +| **构建要求** | 自定义 CUDA 扩展 | triton 包 | transformers 包 | +| **GQA 支持** | ✅ | ✅ | ✅ | +| **注意力掩码** | ✅ | ✅ | ⚠️ | +| **注意力偏置** | ✅ | ✅ | ✅ | +| **因果掩码** | ✅ | ✅ | ✅ | +| **Softcap** | ✅ | ❌ | ❌ | +| **确定性** | ✅ | ❌ | ❌ | +| **返回注意力概率** | ✅ | ❌ | ❌ | +| **反向传播支持** | ✅ | ✅ | ⚠️ | + +> [!NOTE] +> ✅ 完全支持 | ⚠️ 有限支持 | ❌ 不支持 + +### 何时使用各个后端 + +**CUDA 后端** ([详细说明](#flash_dmattn_func-cuda-后端)) +- ✅ 完整梯度支持的训练工作负载 +- ✅ 最大性能生产推理 +- ✅ 需要确定性行为的应用 +- ❌ 避免:无法构建自定义 CUDA 扩展时 + +**Triton 后端** ([详细说明](#triton_dmattn_func-triton-后端)) +- ✅ CUDA 扩展不可用时的训练工作负载 +- ✅ 开发和原型设计 +- ✅ 跨平台兼容性需求 +- ✅ 性能和易安装性的良好平衡 + +**Flex 后端** ([详细说明](#flex_dmattn_func-flex-后端)) +- ✅ 仅推理应用 +- ✅ 使用最新 PyTorch 特性的研究 +- ✅ 无需自定义构建的快速实验 +- ❌ 避免:训练 +- ❌ 避免:需要严格的注意力掩码遵从时 + +### 导入可用函数 + +```python +from flash_dmattn import ( + # 自动后端选择 + get_available_backends, + flash_dmattn_func_auto, + + # 后端特定函数 + flash_dmattn_func, # CUDA 后端 + triton_dmattn_func, # Triton 后端 + flex_dmattn_func, # Flex 后端 + + # 后端可用性标志 + CUDA_AVAILABLE, + TRITON_AVAILABLE, + FLEX_AVAILABLE, +) + +# Transformers 集成 +from flash_dmattn.integrations.flash_dynamic_mask_attention import ( + flash_dynamic_mask_attention_forward +) +``` + + +## 接口函数详解 + +### flash_dmattn_func (CUDA 后端) + +主要的注意力函数。支持多头注意力和分组查询注意力(当 KV 头数少于 Q 头数时)。需要 CUDA 扩展已构建并可用。 + +```python +def flash_dmattn_func( + query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) + key: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim) + value: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim) + attn_mask: Optional[torch.Tensor] = None, # (batch, {num_heads, num_kv_heads, 1}, {seqlen_q, 0}, seqlen_k) + attn_bias: Optional[torch.Tensor] = None, # (batch, {num_heads, num_kv_heads, 1}, {seqlen_q, 0}, seqlen_k) + scale: Optional[float] = None, # 分数缩放,默认为 1/sqrt(head_dim) + is_causal: Optional[bool] = None, # 因果掩码 + softcap: Optional[float] = None, # 仅 CUDA 支持 + deterministic: Optional[bool] = None, # 仅 CUDA 支持 + return_attn_probs: Optional[bool] = None, # 仅 CUDA 支持,用于测试 +) -> torch.Tensor +``` + +#### 参数 + +- query: (B, Q, H, D). CUDA 张量,fp16/bf16,最后一维连续 +- key: (B, K, H_kv, D). 与 query 相同的数据类型/设备;当 H_kv <= H 时为 GQA +- value: (B, K, H_kv, D). 与 query 相同的数据类型/设备;当 H_kv <= H 时为 GQA +- attn_mask: (B, {H, H_kv, 1}, {Q, 0}, K). 1.0 = 可见,0.0 = 被掩码。None 表示禁用 +- attn_bias: (B, {H, H_kv, 1}, {Q, 0}, K). 在 softmax 前加到分数上。None 表示禁用 +- scale: 分数缩放;默认为 1/sqrt(D) +- is_causal: 应用因果掩码 +- softcap, deterministic, return_attn_probs: 仅在 CUDA 后端有效;在其他后端被忽略 + +#### 返回值 + +- output: (B, Q, H, D) + +### triton_dmattn_func (Triton 后端) + +基于 Triton 的实现,无需自定义 CUDA 内核即可提供良好性能。 + +```python +def triton_dmattn_func( + query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) + key: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) + value: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) + attn_mask: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) + attn_bias: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) + is_causal: bool = False, # 因果掩码 + scale: Optional[float] = None, # 分数缩放,默认为 1/sqrt(head_dim) +) -> torch.Tensor +``` + +### flex_dmattn_func (Flex Attention 后端) + +基于 Flex Attention 的实现,使用 PyTorch 原生 flex attention 并支持动态掩码。 + +```python +def flex_dmattn_func( + query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) + key: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) + value: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) + attn_mask: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) + attn_bias: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) + is_causal: Optional[bool] = None, # 因果掩码 + scale: Optional[float] = None, # 分数缩放,默认为 1/sqrt(head_dim) +) -> torch.Tensor +``` + + +## 集成 + +### Transformers 集成 + +为 HuggingFace Transformers 模型提供的集成函数,提供无缝的 flash dynamic mask attention 支持。 + +#### flash_dynamic_mask_attention_forward + +```python +from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward + +def flash_dynamic_mask_attention_forward( + module: torch.nn.Module, # 注意力模块 + query: torch.Tensor, # (batch_size, num_heads, query_len, head_dim) + key: torch.Tensor, # (batch_size, num_kv_heads, key_len, head_dim) + value: torch.Tensor, # (batch_size, num_kv_heads, key_len, head_dim) + attention_mask: Optional[torch.Tensor], # (batch_size, {num_heads, num_kv_heads, 1}, {query_len, 0}, key_len) + attention_bias: Optional[torch.Tensor], # (batch_size, {num_heads, num_kv_heads, 1}, {query_len, 0}, key_len) + scaling: Optional[float] = None, # 分数缩放 + softcap: Optional[float] = None, # softcap 值 + **kwargs, +) -> tuple[torch.Tensor, None] +``` + +#### 参数 + +- module: 注意力模块实例 +- query: 查询张量 (B, H, Q, D) +- key: 键张量 (B, H_kv, K, D) +- value: 值张量 (B, H_kv, K, D) +- attention_mask: 布尔注意力掩码 (B, {H, H_kv, 1}, {Q, 0}, K) +- attention_bias: 加到分数上的注意力偏置 (B, {H, H_kv, 1}, {Q, 0}, K) +- scaling: 分数缩放因子 +- softcap: 注意力分数的 softcap 值 +- **kwargs: 额外参数,包括: + - is_causal: 是否应用因果掩码 + - keep_window_size: 保持的窗口大小 + - layer_idx: 用于日志的层索引 + - implementation: 使用的实现("flash_dmattn" 或 None) + +#### 返回值 + +- tuple[torch.Tensor, None]: 输出张量 (B, Q, H, D) 和 None(用于兼容性) + +#### 使用示例 + +```python +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Callable, tuple +from transformers.cache_utils import Cache +from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward + +class DynamicMaskAttention(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.keep_window_size = config.keep_window_size + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.A = nn.Parameter(torch.zeros(config.num_key_value_heads)) + self.dt_proj = nn.Linear( + config.num_key_value_heads * self.head_dim, config.num_key_value_heads, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin 和 cos 是 RoPE 模型特有的;static cache 需要 cache_position + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # 从 value_states 采样 dt_states 以生成 attention_bias + dt_states = self.dt_proj( + value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1) + ) + attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype) + + # 选择注意力实现 + attention_interface: Callable = flash_dynamic_mask_attention_forward + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=attention_mask, + attention_bias=attn_bias, + scale=self.scaling, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights +``` + +这个示例展示了: +- **动态注意力偏置生成**: 使用可学习参数创建注意力偏置 +- **灵活的后端选择**: 通过 `attention_interface` 轻松切换注意力实现 +- **正确的张量重塑**: 根据需要在不同的张量布局之间转换 +- **与缓存的集成**: 在生成场景中支持键值缓存 + + +## 常见问题与解决方案 + +### 导入错误 + +```python +try: + from flash_dmattn import flash_dmattn_func_auto, get_available_backends + print("✅ 导入成功", get_available_backends()) +except ImportError as e: + print(f"❌ 导入失败: {e}") + print("请使用以下命令安装: pip install -e .") +``` + +### 性能问题 + +1. 执行缓慢:确保所有张量在同一个 GPU 上且最后一维是连续的;使用 8 的倍数的头维度;尽可能使用 CUDA 后端 +2. 高内存:使用梯度检查点;分块长序列;考虑对超长序列使用 Triton 或 Flex 后端 +3. 数值稳定性:优先使用 bfloat16;检查掩码/偏置是否有 NaN/Inf;监控梯度范数 + +### Transformers 集成问题 + +1. 模型兼容性:确保您的模型支持自定义注意力实现 +2. 形状不匹配:检查张量布局是否匹配预期格式 +3. 梯度流:验证梯度是否正确地通过自定义注意力函数流动 + +### 调试 + +```python +import torch +from flash_dmattn import flash_dmattn_func_auto + +torch.autograd.set_detect_anomaly(True) +attn = flash_dmattn_func_auto() +output = attn(q, k, v, attn_mask=attn_mask, attn_bias=attn_bias, is_causal=True) +if torch.isnan(output).any(): + print("⚠️ 注意力输出中检测到 NaN") +``` + +### 内存监控 + +```python +def print_memory_stats(): + if torch.cuda.is_available(): + print(f"已分配: {torch.cuda.memory_allocated() / 1e9:.2f} GB") + print(f"已预留: {torch.cuda.memory_reserved() / 1e9:.2f} GB") + print(f"最大分配: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB") + +print_memory_stats() +attn = flash_dmattn_func_auto() +output = attn(q, k, v) +print_memory_stats() +``` From 7acd05136a600e06b3351a1e84e6e6d479a792a4 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Tue, 7 Oct 2025 20:37:04 +0800 Subject: [PATCH 06/15] Updates API docs with quick start and backend guide Reorganizes documentation for faster onboarding: adds Quick Start, backend selection/comparison, and a clearer API reference. Clarifies attention mask/bias shapes with broadcasting support and updates the integration example to drop manual expansion and eager fallback. Adds concise install instructions, backend availability flags, and unified import guidance; removes redundant backend and summary sections for clarity. --- docs/api_reference.md | 318 ++++++++++++++++++------------------------ 1 file changed, 137 insertions(+), 181 deletions(-) diff --git a/docs/api_reference.md b/docs/api_reference.md index b99eea4..39585ff 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -5,66 +5,148 @@ Flash Dynamic Mask Attention is a high-performance attention implementation that combines the memory efficiency of Flash Attention with the sparse compute benefits of Dynamic Mask Attention. It supports CUDA, Triton, and Flex Attention backends and dynamic masking for very long sequences. -Interfaces provided: -- High-level: simple entry point with automatic backend selection -- Backend-specific: direct access to CUDA, Triton, and Flex implementations -- Transformers Integration: seamless integration with HuggingFace Transformers models - ## Table of Contents 1. [Installation](#installation) -2. [High-Level Interface](#high-level-interface) -3. [Core Functions](#core-functions) -4. [Transformers Integration](#transformers-integration) -5. [Backend Selection](#backend-selection) +2. [Quick Start](#quick-start) +3. [Backend Selection and Comparison](#backend-selection-and-comparison) +4. [API Reference](#api-reference) + - [CUDA Backend: flash_dmattn_func](#flash_dmattn_func-cuda-backend) + - [Triton Backend: triton_dmattn_func](#triton_dmattn_func-triton-backend) + - [Flex Backend: flex_dmattn_func](#flex_dmattn_func-flex-backend) +5. [Integrations](#integrations) + - [Transformers Integration](#transformers-integration) 6. [Common Issues and Solutions](#common-issues-and-solutions) -7. [Summary](#summary) ## Installation -### Prerequisites +Please refer to the [README](https://github.com/SmallDoges/flash-dmattn/blob/main/README.md#install) for detailed installation instructions. -- Python: 3.8+ -- PyTorch: 2.0.0+ with CUDA -- CUDA: 11.8+ for CUDA backend -- NVIDIA GPU: Compute Capability 8.0+ for CUDA backend -- Optional: `triton` for Triton backend, `transformers` for Flex backend and integrations +```bash +# With CUDA backend +pip install flash-dmattn -### Install from Source +# Or install from source +pip install -e . -```bash -git clone https://github.com/SmallDoges/flash-dmattn.git -cd flash-dmattn -MAX_JOBS=4 pip install . --no-build-isolation +# Triton/Flex only +FLASH_DMATTN_SKIP_CUDA_BUILD=1 pip install -e . ``` -## High-Level Interface +## Quick Start + +Use `flash_dmattn_func_auto` to automatically select the best available backend without manual checking. + +```python +import torch +from flash_dmattn import flash_dmattn_func_auto + +# Prepare input tensors +batch, seqlen, num_heads, head_dim = 2, 1024, 8, 64 +q = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device='cuda') +k = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device='cuda') +v = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device='cuda') + +# Get attention function (auto-select backend, priority: cuda > triton > flex) +attn_func = flash_dmattn_func_auto() + +# Compute attention +output = attn_func(q, k, v, is_causal=True) +print(f"Output shape: {output.shape}") # (2, 1024, 8, 64) + +# Or force a specific backend +attn_func = flash_dmattn_func_auto(backend="cuda") # or "triton", "flex" +output = attn_func(q, k, v, is_causal=True) +``` + +> [!NOTE] +> `flash_dmattn_func_auto` returns a callable attention function, not the attention output. -### Automatic Backend Selection -Note: `flash_dmattn_func_auto` returns a callable attention function, not the attention output. +## Backend Selection and Comparison + +### Check Available Backends ```python -from flash_dmattn import get_available_backends, flash_dmattn_func_auto +from flash_dmattn import get_available_backends, CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE -# Check available backends -backends = get_available_backends() -print(f"Available backends: {backends}") +# List all available backends +print(get_available_backends()) # e.g., ["cuda", "triton", "flex"] + +# Check specific backend availability +print(f"CUDA: {CUDA_AVAILABLE}, Triton: {TRITON_AVAILABLE}, Flex: {FLEX_AVAILABLE}") +``` -# Auto-select (priority: cuda > triton > flex) -dmattn_func = flash_dmattn_func_auto() -output = dmattn_func(q, k, v, attn_mask=attention_mask, attn_bias=attention_bias, is_causal=True, scale=None) +### Backend Feature Comparison + +| Feature | CUDA | Triton | Flex | +|---------|------|--------|------| +| **Performance** | Highest | Good | Good | +| **Memory Efficiency** | Best | Good | Good | +| **Build Requirements** | Custom CUDA extension | triton package | transformers package | +| **GQA Support** | ✅ | ✅ | ✅ | +| **Attention Mask** | ✅ | ✅ | ⚠️ | +| **Attention Bias** | ✅ | ✅ | ✅ | +| **Causal Mask** | ✅ | ✅ | ✅ | +| **Softcap** | ✅ | ❌ | ❌ | +| **Deterministic** | ✅ | ❌ | ❌ | +| **Return Attention Probs** | ✅ | ❌ | ❌ | +| **Backward Support** | ✅ | ✅ | ⚠️ | + +> [!NOTE] +> ✅ Fully supported | ⚠️ Limited support | ❌ Not supported + +### When to Use Each Backend + +**CUDA Backend** ([details](#flash_dmattn_func-cuda-backend)) +- ✅ Training workloads requiring full gradient support +- ✅ Production inference requiring maximum performance +- ✅ Applications needing deterministic behavior +- ❌ Avoid: when custom CUDA extensions cannot be built -# Force a specific backend -dmattn_func = flash_dmattn_func_auto(backend="cuda") # or "triton", "flex" -output = dmattn_func(q, k, v, attn_mask=attention_mask, attn_bias=attention_bias, is_causal=True, scale=None) +**Triton Backend** ([details](#triton_dmattn_func-triton-backend)) +- ✅ Training when CUDA extension unavailable +- ✅ Development and prototyping +- ✅ Cross-platform compatibility needs +- ✅ Good balance of performance and ease of installation + +**Flex Backend** ([details](#flex_dmattn_func-flex-backend)) +- ✅ Inference-only applications +- ✅ Research with latest PyTorch features +- ✅ Quick experimentation without custom builds +- ❌ Avoid: training (limited backward support) +- ❌ Avoid: when strict attention mask compliance required + +### Import Available Functions + +```python +from flash_dmattn import ( + # Automatic backend selection + get_available_backends, + flash_dmattn_func_auto, + + # Backend-specific functions + flash_dmattn_func, # CUDA backend + triton_dmattn_func, # Triton backend + flex_dmattn_func, # Flex backend + + # Backend availability flags + CUDA_AVAILABLE, + TRITON_AVAILABLE, + FLEX_AVAILABLE, +) + +# Transformers integration +from flash_dmattn.integrations.flash_dynamic_mask_attention import ( + flash_dynamic_mask_attention_forward +) ``` -## Core Functions +## API Reference ### flash_dmattn_func (CUDA backend) @@ -75,8 +157,8 @@ def flash_dmattn_func( query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) key: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim) value: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim) - attn_mask: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) - attn_bias: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) + attn_mask: Optional[torch.Tensor] = None, # (batch, {num_heads, num_kv_heads, 1}, {seqlen_q, 0}, seqlen_k) + attn_bias: Optional[torch.Tensor] = None, # (batch, {num_heads, num_kv_heads, 1}, {seqlen_q, 0}, seqlen_k) scale: Optional[float] = None, # score scaling, defaults to 1/sqrt(head_dim) is_causal: Optional[bool] = None, # causal mask softcap: Optional[float] = None, # CUDA-only @@ -90,8 +172,8 @@ def flash_dmattn_func( - query: (B, Q, H, D). CUDA tensor, fp16/bf16, last dim contiguous - key: (B, K, H_kv, D). Same dtype/device as query; GQA when H_kv <= H - value: (B, K, H_kv, D). Same dtype/device as query; GQA when H_kv <= H -- attn_mask: (B, H, Q, K). 1.0 = visible, 0.0 = masked. None to disable -- attn_bias: (B, H, Q, K). Added to scores before softmax. None to disable +- attn_mask: (B, {H, H_kv, 1}, {Q, 0}, K). 1.0 = visible, 0.0 = masked. None to disable +- attn_bias: (B, {H, H_kv, 1}, {Q, 0}, K). Added to scores before softmax. None to disable - scale: score scaling; default 1/sqrt(D) - is_causal: apply lower-triangular mask - softcap, deterministic, return_attn_probs: only effective on the CUDA backend; ignored on others @@ -133,11 +215,13 @@ def flex_dmattn_func( ``` -## Transformers Integration +## Integrations + +### Transformers Integration Integration function for HuggingFace Transformers models that provides seamless flash dynamic mask attention support. -### flash_dynamic_mask_attention_forward +#### flash_dynamic_mask_attention_forward ```python @@ -148,8 +232,8 @@ def flash_dynamic_mask_attention_forward( query: torch.Tensor, # (batch_size, num_heads, query_len, head_dim) key: torch.Tensor, # (batch_size, num_kv_heads, key_len, head_dim) value: torch.Tensor, # (batch_size, num_kv_heads, key_len, head_dim) - attention_mask: Optional[torch.Tensor], # (batch_size, num_kv_heads, query_len, key_len) - attention_bias: Optional[torch.Tensor], # (batch_size, num_kv_heads, query_len, key_len) + attention_mask: Optional[torch.Tensor], # (batch_size, {num_heads, num_kv_heads, 1}, {query_len, 0}, key_len) + attention_bias: Optional[torch.Tensor], # (batch_size, {num_heads, num_kv_heads, 1}, {query_len, 0}, key_len) scaling: Optional[float] = None, # score scaling softcap: Optional[float] = None, # softcap value **kwargs, @@ -162,8 +246,8 @@ def flash_dynamic_mask_attention_forward( - query: Query tensor with head-first layout (B, H, Q, D) - key: Key tensor with head-first layout (B, H_kv, K, D) - value: Value tensor with head-first layout (B, H_kv, K, D) -- attention_mask: Boolean attention mask -- attention_bias: Attention bias to add to scores +- attention_mask: Boolean attention mask (B, {H, H_kv, 1}, {Q, 0}, K) +- attention_bias: Attention bias to add to scores (B, {H, H_kv, 1}, {Q, 0}, K) - scaling: Score scaling factor - softcap: Softcap value for attention scores - **kwargs: Additional arguments including: @@ -176,7 +260,7 @@ def flash_dynamic_mask_attention_forward( - tuple[torch.Tensor, None]: Output tensor (B, Q, H, D) and None for compatibility -### Usage with Transformers +#### Usage Example ```python import torch @@ -207,7 +291,6 @@ class DynamicMaskAttention(nn.Module): self.v_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) - # Dynamic mask for the QK^T attention weights matrix self.A = nn.Parameter(torch.zeros(config.num_key_value_heads)) self.dt_proj = nn.Linear( config.num_key_value_heads * self.head_dim, config.num_key_value_heads, bias=config.attention_bias @@ -238,27 +321,18 @@ class DynamicMaskAttention(nn.Module): query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache + # sin and cos are specific to RoPE models; static cache needs cache_position cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # Sampling dt_states from value_states to generate attention bias + # Sample dt_states from value_states to generate attention_bias dt_states = self.dt_proj( value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1) ) - dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) - attn_bias = dt_states[:, :, None, :].expand( - -1, -1, hidden_states.shape[1], -1 - ).to(hidden_states.dtype) # [batch_size, num_heads, query_len, key_len] - - # Choose attention implementation: fallback to eager if flash_dmattn is not available - attention_interface: Callable = eager_attention_forward - if flash_dynamic_mask_attention_forward is not None: - attention_interface = flash_dynamic_mask_attention_forward - - # Expand attention mask to match the expected shape - if attention_mask is not None: - attention_mask = attention_mask.expand(-1, attn_bias.shape[1], -1, -1) + attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype) + + # Choose attention implementation + attention_interface: Callable = flash_dynamic_mask_attention_forward attn_output, attn_weights = attention_interface( self, @@ -277,92 +351,11 @@ class DynamicMaskAttention(nn.Module): This example shows: - **Dynamic attention bias generation**: Using learnable parameters to create attention bias -- **Flexible backend selection**: Graceful fallback to standard attention when flash_dmattn is unavailable +- **Flexible backend selection**: Easily switch attention implementations via `attention_interface` - **Proper tensor reshaping**: Converting between different tensor layouts as needed - **Integration with caching**: Support for key-value caching in generation scenarios -## Backend Selection - -### Available Backends - -```python -from flash_dmattn import get_available_backends, CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE - -print(get_available_backends()) # e.g., ["cuda", "triton", "flex"] -print(CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE) -``` - -### Available Functions - -The library exports the following functions: - -```python -from flash_dmattn import ( - # High-level interface - get_available_backends, # Get list of available backends - flash_dmattn_func_auto, # Automatic backend selection - - - # Backend-specific functions - flash_dmattn_func, # CUDA backend (if available) - triton_dmattn_func, # Triton backend (if available) - flex_dmattn_func, # Flex Attention backend (if available) - - # Backend availability flags - CUDA_AVAILABLE, - TRITON_AVAILABLE, - FLEX_AVAILABLE, -) - -# Transformers integration -from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward -``` - -### Backend-Specific Functions - -```python -# Direct access to specific backends -from flash_dmattn import flash_dmattn_func # CUDA backend -from flash_dmattn import triton_dmattn_func # Triton backend -from flash_dmattn import flex_dmattn_func # Flex Attention backend - -# Unified call signature (public layer) -# query/key/value: (B, L{q/k}, H, D) -# attn_mask/attn_bias: (B, H, Lq, Lk) -# is_causal: bool, scale: Optional[float] -output = flash_dmattn_func(q, k, v, attn_mask=mask, attn_bias=bias, is_causal=True, scale=None) -output = triton_dmattn_func(q, k, v, attn_mask=mask, attn_bias=bias, is_causal=True, scale=None) -output = flex_dmattn_func(q, k, v, attn_mask=mask, attn_bias=bias, is_causal=True, scale=None) -``` - -Notes: -- All backends support the same unified interface for seamless switching -- Flex backend currently uses causal masking and score_mod with bias; provided attn_mask is not applied in the kernel at the moment, subject to change in future versions -- CUDA backend supports additional parameters like softcap, deterministic, and return_attn_probs - -### When to Use Each Backend - -**CUDA Backend:** -- ✅ Training workloads requiring full gradient support -- ✅ Production inference requiring maximum performance -- ✅ Applications needing deterministic behavior -- ❌ Avoid if you cannot build custom CUDA extensions - -**Triton Backend:** -- ✅ Training workloads when CUDA extension is not available -- ✅ Development and prototyping -- ✅ Cross-platform compatibility needs -- ✅ Good balance of performance and ease of installation - -**Flex Backend:** -- ✅ Inference-only applications -- ✅ Research with latest PyTorch features -- ✅ Quick experimentation without custom builds -- ❌ Avoid for training due to limited backward support -- ❌ Avoid when strict attention mask compliance is required - - ## Common Issues and Solutions ### Import Errors @@ -416,40 +409,3 @@ output = attn(q, k, v) print_memory_stats() ``` -## Summary - -Flash Dynamic Mask Attention provides a unified interface for high-performance attention computation with the following key features: - -- **Multiple Backends**: CUDA for best performance, Triton for good compatibility, and Flex Attention for native PyTorch support -- **Automatic Backend Selection**: Seamless fallback between available backends -- **Dynamic Masking**: Efficient sparse attention with arbitrary attention masks -- **GQA Support**: Grouped-query attention for efficient inference -- **Transformers Integration**: Direct integration with HuggingFace models -- **Memory Efficiency**: Optimized memory usage for very long sequences - -Choose the backend that best fits your needs: -- **CUDA**: For maximum performance and full feature support, especially for training -- **Triton**: For good performance without custom CUDA compilation, supports both training and inference -- **Flex**: For inference scenarios and compatibility with latest PyTorch features, but limited backward support for training yet - -### Backend Comparison - -| Feature | CUDA | Triton | Flex | -|---------|------|--------|------| -| Performance | Highest | Good | Good | -| Memory Efficiency | Best | Good | Good | -| Build Requirements | Custom CUDA extension | triton package | transformers package | -| GQA Support | ✅ | ✅ | ✅ | -| Attention Mask | ✅ | ✅ | ⚠️ | -| Attention Bias | ✅ | ✅ | ✅ | -| Causal Mask | ✅ | ✅ | ✅ | -| Softcap | ✅ | ❌ | ❌ | -| Deterministic | ✅ | ❌ | ❌ | -| Return Attention Probs | ✅ | ❌ | ❌ | -| Backward Support | ✅ | ✅ | ⚠️ | - -Notes: -- ✅ = Fully supported -- ⚠️ = Limited support or workarounds needed -- ❌ = Not supported - From 7ad0d42289bea452e7dc6127035d0715427b99da Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 9 Oct 2025 18:11:28 +0800 Subject: [PATCH 07/15] Rename keep_window_size to window_size in DogeConfig and DogeAttention classes for consistency --- examples/modeling/configuration_doge.py | 6 +++--- examples/modeling/modeling_doge.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/modeling/configuration_doge.py b/examples/modeling/configuration_doge.py index 624312b..8a529a3 100644 --- a/examples/modeling/configuration_doge.py +++ b/examples/modeling/configuration_doge.py @@ -108,7 +108,7 @@ class DogeConfig(PretrainedConfig): Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. sliding_window (`int`, *optional*): Sliding window attention window size. If not specified, will default to `None`. - keep_window_size (`int`, *optional*, defaults to 2048): + window_size (`int`, *optional*, defaults to 2048): The window size of tokens that are not dynamically masked, and dynamic masking is only performed when the sequence length exceeds this value. is_moe (`bool`, *optional*, defaults to `False`): Whether to use the Cross Domain Mixture of Experts, if `True`, the MoE will inherit the MLP to initialize. @@ -185,7 +185,7 @@ def __init__( attention_dropout=0.0, mlp_bias=False, sliding_window=None, - keep_window_size=2048, + window_size=2048, is_moe=False, num_experts=16384, num_experts_per_tok=64, @@ -214,7 +214,7 @@ def __init__( self.attention_dropout = attention_dropout self.mlp_bias = mlp_bias self.sliding_window = sliding_window - self.keep_window_size = keep_window_size + self.window_size = window_size self.is_moe = is_moe self.num_experts = num_experts self.num_experts_per_tok = num_experts_per_tok diff --git a/examples/modeling/modeling_doge.py b/examples/modeling/modeling_doge.py index dd58c80..a4de840 100644 --- a/examples/modeling/modeling_doge.py +++ b/examples/modeling/modeling_doge.py @@ -165,7 +165,7 @@ def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None): self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.keep_window_size = config.keep_window_size + self.window_size = config.window_size self.is_causal = True self.q_proj = nn.Linear( From bbb5a7da00beea9f82635cbc86aca8543e9035d4 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 9 Oct 2025 18:11:56 +0800 Subject: [PATCH 08/15] Rename keep_window_size to window_size in flash_dynamic_mask_attention_forward for consistency --- flash_dmattn/integrations/flash_dynamic_mask_attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flash_dmattn/integrations/flash_dynamic_mask_attention.py b/flash_dmattn/integrations/flash_dynamic_mask_attention.py index 16631aa..c37d367 100644 --- a/flash_dmattn/integrations/flash_dynamic_mask_attention.py +++ b/flash_dmattn/integrations/flash_dynamic_mask_attention.py @@ -36,7 +36,7 @@ def flash_dynamic_mask_attention_forward( **kwargs: Additional keyword arguments. Includes: - is_causal (bool): Whether to apply a causal mask. - - keep_window_size (int): The size of the window to keep. + - window_size (int): The size of the window to keep. - layer_idx (int): The index of the layer (for logging purposes). - implementation (str): The implementation to use ("flash_dmattn" or None). @@ -84,7 +84,7 @@ def flash_dynamic_mask_attention_forward( # FDMA always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice kwargs.pop("is_causal", None) - kwargs.pop("keep_window_size", None) + kwargs.pop("window_size", None) attn_output = _flash_dynamic_mask_attention_forward( query, @@ -97,7 +97,7 @@ def flash_dynamic_mask_attention_forward( is_causal=module.is_causal, softmax_scale=scaling, softcap=softcap, - keep_window_size=module.keep_window_size, + window_size=module.window_size, target_dtype=target_dtype, implementation="flash_dmattn", layer_idx=module.layer_idx if hasattr(module, "layer_idx") else None, From 4fe9561b95da50d1620810794887329cf6f5c730 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 9 Oct 2025 18:13:08 +0800 Subject: [PATCH 09/15] Refactor mask and bias parameters to be optional in flash_dmattn functions --- flash_dmattn/flash_dmattn_interface.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 2db8ea4..13385a3 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -4,7 +4,7 @@ from packaging import version import torch -import flash_dmattn_cuda as flash_dmattn_gpu # type: ignore +import flash_dmattn_cuda as flash_dmattn_gpu # type: ignore def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: @@ -75,8 +75,8 @@ def _flash_dmattn_forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - mask: torch.Tensor, - bias: torch.Tensor, + mask: Optional[torch.Tensor], + bias: Optional[torch.Tensor], softmax_scale: float, is_causal: bool, softcap: float, @@ -104,8 +104,8 @@ def _flash_dmattn_forward_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - mask: torch.Tensor, - bias: torch.Tensor, + mask: Optional[torch.Tensor], + bias: Optional[torch.Tensor], softmax_scale: float, is_causal: bool, softcap: float, @@ -132,8 +132,8 @@ def _flash_dmattn_backward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - mask: torch.Tensor, - bias: torch.Tensor, + mask: Optional[torch.Tensor], + bias: Optional[torch.Tensor], out: torch.Tensor, softmax_lse: torch.Tensor, dq: Optional[torch.Tensor], @@ -180,8 +180,8 @@ def _flash_dmattn_backward_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - mask: torch.Tensor, - bias: torch.Tensor, + mask: Optional[torch.Tensor], + bias: Optional[torch.Tensor], out: torch.Tensor, softmax_lse: torch.Tensor, dq: Optional[torch.Tensor], @@ -314,7 +314,8 @@ def backward( ctx.deterministic, ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + # We could have padded the head dimension + dq = dq[..., : dout.shape[-1]] dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] From 060471fb33b9769a05dda6344dbc14b31dcd81e8 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 9 Oct 2025 18:15:09 +0800 Subject: [PATCH 10/15] Rename scale parameter to softmax_scale in flash_dmattn_func for clarity and consistency --- flash_dmattn/flash_dmattn_interface.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 13385a3..330dfb2 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -334,7 +334,7 @@ def flash_dmattn_func( value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, attn_bias: Optional[torch.Tensor] = None, - scale: Optional[float] = None, + softmax_scale: Optional[float] = None, is_causal: Optional[bool] = None, softcap: Optional[float] = None, deterministic: Optional[bool] = None, @@ -368,18 +368,15 @@ def flash_dmattn_func( key: torch.Tensor. The key tensor of shape (batch_size, seqlen, nheads_k, headdim) value: torch.Tensor. The value tensor of shape (batch_size, seqlen, nheads_k, headdim) attn_mask: torch.Tensor, optional. The attention mask boolean tensor of - shape (batch_size, nheads, seqlen_q, seqlen_k) to apply to the attention scores. - Also supports shape (batch_size, nheads_k, seqlen_q, seqlen_k) or - (batch_size, 1, seqlen_q, seqlen_k) for MQA/GQA. + shape (batch_size, {nheads|nheads_k|1}, {seqlen_q|0}, seqlen_k) to apply to the attention scores. If None, no mask is applied. attn_bias: torch.Tensor, optional. The attention bias float tensor of - shape (batch_size, nheads, seqlen_q, seqlen_k) to add to the attention scores. - Also supports shape (batch_size, nheads_k, seqlen_q, seqlen_k) or - (batch_size, 1, seqlen_q, seqlen_k) for MQA/GQA. + shape (batch_size, {nheads|nheads_k|1}, {seqlen_q|0}, seqlen_k) to add to the attention scores. If None, no bias is applied. - is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - scale: float. The scaling of QK^T before applying softmax. + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). + is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + softcap: float. Anything > 0 activates softcapping attention. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for @@ -399,7 +396,7 @@ def flash_dmattn_func( value, attn_mask, attn_bias, - scale, + softmax_scale, is_causal, softcap, deterministic, From ddc8cd5ff1274f1751242dbfca388b776fd4739e Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 9 Oct 2025 18:16:50 +0800 Subject: [PATCH 11/15] Adds varlen Flash-DM attention (CUDA) + autograd MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces CUDA custom ops and fake/meta registrations for variable-length attention forward/backward using cumulative sequence lengths, enabling efficient packed ragged batches. Adds an autograd wrapper and public API exposing varlen attention with support for MQA/GQA, optional mask/bias, causal mode, softcap, deterministic backward, and optional attention probs/LSE for testing. Pads head dim and key seqlen to multiples of 8 for 16‑bit–friendly allocations, rounds workspace shapes, and sanitizes outputs; also supports paged KV via a block table. Improves performance and memory by avoiding per-sequence padding and aligning allocations to hardware-friendly sizes. --- flash_dmattn/flash_dmattn_interface.py | 407 +++++++++++++++++++++++++ 1 file changed, 407 insertions(+) diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 330dfb2..861184b 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -126,6 +126,89 @@ def _flash_dmattn_forward_fake( _wrapped_flash_dmattn_forward = _flash_dmattn_forward +@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_varlen_forward", mutates_args=(), device_types="cuda") +def _flash_dmattn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: float, + is_causal: bool, + softcap: float = 0.0, + return_softmax: bool = False, + block_table: Optional[torch.Tensor] = None, + leftpad_k: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v, mask, bias = [maybe_contiguous(x) for x in (q, k, v, mask, bias)] + out, softmax_lse, S_dmask = flash_dmattn_gpu.varlen_fwd( + q, + k, + v, + mask, + bias, + None, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + leftpad_k, + block_table, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + zero_tensors, + is_causal, + softcap, + return_softmax, + ) + _sanitize_tensors(out, nan=0.0, posinf=0.0, neginf=0.0) + return out, softmax_lse, S_dmask + + +@_torch_register_fake_wrapper("flash_dmattn::_flash_dmattn_varlen_forward") +def _flash_dmattn_varlen_forward_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: float, + is_causal: bool, + softcap: float = 0.0, + return_softmax: bool = False, + block_table: Optional[torch.Tensor] = None, + leftpad_k: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v, mask, bias = [maybe_contiguous(x) for x in (q, k, v, mask, bias)] + paged_kv = block_table is not None + batch_size = cu_seqlens_q.numel() - 1 + total_q, num_heads, _ = q.shape + + out = torch.empty_like(q) + softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout) + p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) + seqlen_q_rounded = round_multiple(max_seqlen_q, 128) + seqlen_k_rounded = round_multiple(max_seqlen_k, 128) + if return_softmax: + p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout) + return out, softmax_lse, p + + +_wrapped_flash_dmattn_varlen_forward = _flash_dmattn_varlen_forward + + @_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_backward", mutates_args=("dq", "dk", "dv", "dbias"), device_types="cuda") def _flash_dmattn_backward( dout: torch.Tensor, @@ -211,7 +294,109 @@ def _flash_dmattn_backward_fake( _wrapped_flash_dmattn_backward = _flash_dmattn_backward +@_torch_custom_op_wrapper("flash_dmattn::_flash_dmattn_varlen_backward", mutates_args=("dq", "dk", "dv", "dbias"), device_types="cuda") +def _flash_dmattn_varlen_backward( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + dbias: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: float, + is_causal: bool, + softcap: float, + deterministic: bool, + zero_tensors: bool = False, +) -> torch.Tensor: + dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)] + ( + dq, + dk, + dv, + dbias, + softmax_d, + ) = flash_dmattn_gpu.varlen_bwd( + dout, + q, + k, + v, + mask, + bias, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + zero_tensors, + is_causal, + softcap, + deterministic, + ) + _sanitize_tensors(dq, dk, dv, dbias, nan=0.0, posinf=0.0, neginf=0.0) + return softmax_d + + +@_torch_register_fake_wrapper("flash_dmattn::_flash_dmattn_varlen_backward") +def _flash_dmattn_varlen_backward_fake( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + dbias: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: float, + is_causal: bool, + softcap: float, + deterministic: bool, + zero_tensors: bool = False, +) -> torch.Tensor: + dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)] + batch_size = cu_seqlens_q.numel() - 1 + total_q, num_heads, _ = q.shape + + if dq is None: + dq = torch.empty_like(q) + if dk is None: + dk = torch.empty_like(k) + if dv is None: + dv = torch.empty_like(v) + if dbias is None and bias is not None: + dbias = torch.empty_like(bias) + softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) + + return softmax_d + + +_wrapped_flash_dmattn_varlen_backward = _flash_dmattn_varlen_backward + + class FlashDMAttnFunc(torch.autograd.Function): + @staticmethod def forward( ctx: torch.autograd.function.FunctionCtx, @@ -328,6 +513,137 @@ def backward( return dq, dk, dv, None, dbias, None, None, None, None, None, None +class FlashAttnVarlenFunc(torch.autograd.Function): + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: float, + is_causal: Optional[bool], + softcap: Optional[float], + deterministic: Optional[bool], + return_softmax: Optional[bool], + block_table: Optional[torch.Tensor], + is_grad_enabled: bool = True, + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q, k, v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + if is_causal is None: + is_causal = False + if softcap is None: + softcap = 0.0 + if deterministic is None: + deterministic = True + if return_softmax is None: + return_softmax = False + + # Padding to multiple of 8 for 16-bit memory allocations + head_size_og = q.size(2) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + seqlen_k_og = k.shape[1] + if seqlen_k_og % 8 != 0: + k = torch.nn.functional.pad(k, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8]) + v = torch.nn.functional.pad(v, [0, 0, 0, 0, 0, 8 - seqlen_k_og % 8]) + if mask is not None: + mask = torch.nn.functional.pad(mask, [0, 8 - seqlen_k_og % 8], value=False) + if bias is not None: + bias = torch.nn.functional.pad(bias, [0, 8 - seqlen_k_og % 8], value=0.0) + + out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_varlen_forward( + q, + k, + v, + mask, + bias, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + is_causal=is_causal, + softcap=softcap, + return_softmax=return_softmax, + block_table=block_table, + ) + + if is_grad: + ctx.save_for_backward( + q, k, v, mask, bias, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k + ) + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.is_causal = is_causal + ctx.softcap = softcap + ctx.deterministic = deterministic + + out = out_padded[..., :head_size_og] + + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, mask, bias, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + dbias = torch.zeros_like(bias).contiguous() if bias is not None else None + + head_size_og = dout.size(2) + dout_padded = dout + if head_size_og % 8 != 0: + dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) + + _wrapped_flash_dmattn_varlen_backward( + dout_padded, + q, + k, + v, + mask, + bias, + out, + softmax_lse, + dq, + dk, + dv, + dbias, + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.softmax_scale, + ctx.is_causal, + ctx.softcap, + ctx.deterministic, + ) + + # We could have padded the head dimension + dq = dq[..., : dout.shape[-1]] + dk = dk[..., : dout.shape[-1]] + dv = dv[..., : dout.shape[-1]] + + if ctx.seqlen_k_og % 8 != 0: + dk = dk[:, : ctx.seqlen_k_og, :, :] + dv = dv[:, : ctx.seqlen_k_og, :, :] + if dbias is not None: + dbias = dbias[..., : ctx.seqlen_k_og] + + return dq, dk, dv, None, dbias, None, None, None, None, None, None, None, None, None, None, None + + def flash_dmattn_func( query: torch.Tensor, key: torch.Tensor, @@ -403,3 +719,94 @@ def flash_dmattn_func( return_attn_probs, torch.is_grad_enabled(), ) + + +def flash_dmattn_varlen_func( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: Optional[float] = None, + is_causal: Optional[bool] = None, + softcap: Optional[float] = None, + deterministic: Optional[bool] = None, + return_attn_probs: Optional[bool] = None, + block_table: Optional[torch.Tensor] = None, +): + """ + Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + Similarity, also supports attn_mask and attn_bias with head dimension of 1, nheads_k or nheads for MQA/GQA. + For example, if Q has 6 heads, K, V have 2 heads, then attn_mask and attn_bias can have head dimension + of 1, 2 or 6. If it is 1, all heads use the same mask/bias; if it is 2, head 0, 1, 2 of Q use head 0 + of mask/bias, head 3, 4, 5 of Q use head 1 of mask/bias. If it is 6, each head uses its own mask/bias. + + If is_causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + Arguments: + query: torch.Tensor. The query tensor of shape (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + key: torch.Tensor. The key tensor of shape (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + value: torch.Tensor. The value tensor of shape (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + attn_mask: torch.Tensor, optional. The attention mask boolean tensor of + shape (total_q, {nheads|nheads_k|1}, max_seqlen_k) or (total_k, {nheads|nheads_k|1}) to apply to the attention scores. + If None, no mask is applied. + attn_bias: torch.Tensor, optional. The attention bias float tensor of + shape (total_q, {nheads|nheads_k|1}, max_seqlen_k) or (total_k, {nheads|nheads_k|1}) to add to the attention scores. + If None, no bias is applied. + cu_seqlens_q: torch.Tensor. The cumulative sequence lengths of the sequences in the batch, used to index into q. + cu_seqlens_k: torch.Tensor. The cumulative sequence lengths of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + softcap: float. Anything > 0 activates softcapping attention. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). + """ + return FlashAttnVarlenFunc.apply( + query, + key, + value, + attn_mask, + attn_bias, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + is_causal, + softcap, + deterministic, + return_attn_probs, + block_table, + torch.is_grad_enabled(), + ) From f696fbd8940f9d6d27c00bff648cd2dd38bff4e0 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 9 Oct 2025 18:18:32 +0800 Subject: [PATCH 12/15] Import flash_dmattn_varlen_func alongside flash_dmattn_func for backend availability --- flash_dmattn/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/flash_dmattn/__init__.py b/flash_dmattn/__init__.py index bfe6b8d..8269cb0 100644 --- a/flash_dmattn/__init__.py +++ b/flash_dmattn/__init__.py @@ -7,11 +7,11 @@ # Import CUDA functions when available try: - from flash_dmattn.flash_dmattn_interface import flash_dmattn_func + from flash_dmattn.flash_dmattn_interface import flash_dmattn_func, flash_dmattn_varlen_func CUDA_AVAILABLE = True except ImportError: CUDA_AVAILABLE = False - flash_dmattn_func = None + flash_dmattn_func, flash_dmattn_varlen_func = None, None # Import Triton functions when available try: @@ -89,6 +89,7 @@ def flash_dmattn_func_auto(backend: Optional[str] = None, **kwargs): "TRITON_AVAILABLE", "FLEX_AVAILABLE", "flash_dmattn_func", + "flash_dmattn_varlen_func", "triton_dmattn_func", "flex_dmattn_func", "get_available_backends", From 923c9172487b7c15bea3491b249645cbca49419a Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 9 Oct 2025 18:18:54 +0800 Subject: [PATCH 13/15] Rename keep_window_size to window_size for consistency in README files --- README.md | 6 +++--- README_zh.md | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 3f95838..13c80f7 100644 --- a/README.md +++ b/README.md @@ -157,7 +157,7 @@ import math # Setup batch_size, seq_len, num_heads, num_kv_heads, head_dim = 1, 256, 2, 1, 64 -keep_window_size = 128 +window_size = 128 device = torch.device('cuda') dtype = torch.bfloat16 min_dtype = torch.finfo(dtype).min # dtype minimum value @@ -172,10 +172,10 @@ attention_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=d attention_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype) # Generate sparse mask based on bias -if seq_len > keep_window_size: +if seq_len > window_size: # Select top-k most important keys for each query topk_values, topk_indices = torch.topk( - attention_bias, keep_window_size, dim=-1, + attention_bias, window_size, dim=-1, largest=True, sorted=False ) # Generate valid top-k mask diff --git a/README_zh.md b/README_zh.md index 9060a6f..484a1a8 100644 --- a/README_zh.md +++ b/README_zh.md @@ -157,7 +157,7 @@ import math # 设置 batch_size, seq_len, num_heads, num_kv_heads, head_dim = 1, 256, 2, 1, 64 -keep_window_size = 128 +window_size = 128 device = torch.device('cuda') dtype = torch.bfloat16 min_dtype = torch.finfo(dtype).min # dtype 的最小值 @@ -172,10 +172,10 @@ attention_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=d attention_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype) # 基于 bias 生成稀疏 mask -if seq_len > keep_window_size: +if seq_len > window_size: # 为每个查询选择 top-k 最重要的键 topk_values, topk_indices = torch.topk( - attention_bias, keep_window_size, dim=-1, + attention_bias, window_size, dim=-1, largest=True, sorted=False ) # 生成有效的 top-k mask From 32074fac4048e528752d92612858c44e6073fb83 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 9 Oct 2025 18:19:44 +0800 Subject: [PATCH 14/15] Rename keep_window_size to window_size for consistency in API documentation --- docs/api_reference.md | 4 ++-- docs/api_reference_zh.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/api_reference.md b/docs/api_reference.md index 39585ff..e73eb1a 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -252,7 +252,7 @@ def flash_dynamic_mask_attention_forward( - softcap: Softcap value for attention scores - **kwargs: Additional arguments including: - is_causal: Whether to apply causal mask - - keep_window_size: Size of window to keep + - window_size: Size of window to keep - layer_idx: Layer index for logging - implementation: Implementation to use ("flash_dmattn" or None) @@ -279,7 +279,7 @@ class DynamicMaskAttention(nn.Module): self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.keep_window_size = config.keep_window_size + self.window_size = config.window_size self.is_causal = True self.q_proj = nn.Linear( diff --git a/docs/api_reference_zh.md b/docs/api_reference_zh.md index da640bc..efd308a 100644 --- a/docs/api_reference_zh.md +++ b/docs/api_reference_zh.md @@ -251,7 +251,7 @@ def flash_dynamic_mask_attention_forward( - softcap: 注意力分数的 softcap 值 - **kwargs: 额外参数,包括: - is_causal: 是否应用因果掩码 - - keep_window_size: 保持的窗口大小 + - window_size: 保持的窗口大小 - layer_idx: 用于日志的层索引 - implementation: 使用的实现("flash_dmattn" 或 None) @@ -278,7 +278,7 @@ class DynamicMaskAttention(nn.Module): self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.keep_window_size = config.keep_window_size + self.window_size = config.window_size self.is_causal = True self.q_proj = nn.Linear( From e3ff84c5f861ba0f6368d8198a19d174dee0ba6b Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 10 Oct 2025 07:46:40 +0800 Subject: [PATCH 15/15] Refactor fdma_peft_integration_check and _flash_dynamic_mask_attention_forward for clarity and consistency; rename keep_window_size to window_size and enhance FlashDynamicMaskAttentionKwargs documentation. --- ...ling_flash_dynamic_mask_attention_utils.py | 48 ++++++++++++++----- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py index cae5b3e..eac04e3 100644 --- a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py +++ b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py @@ -17,16 +17,29 @@ from .import_utils import is_flash_dmattn_available from transformers.utils import logging -from transformers.integrations import flash_attention logger = logging.get_logger(__name__) -def fdma_peft_integration_check(q, k, v, bias, target_dtype: Optional[torch.dtype] = None): +def fdma_peft_integration_check( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + bias: Optional[torch.Tensor], + target_dtype: Optional[torch.dtype] = None +): + """ + PEFT usually casts the layer norms in float32 for training stability reasons + therefore the input hidden states gets silently casted in float32. Hence, we need + cast them back in float16 / bfloat16 just to be sure everything works as expected. + This might slowdown training & inference so it is recommended to not cast the LayerNorms! + """ if target_dtype and q.dtype == torch.float32: logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-dmattn compatibility.") - q, k, v, bias = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype), bias.to(target_dtype) + q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype) + if bias is not None: + bias = bias.to(target_dtype) return q, k, v, bias @@ -43,8 +56,24 @@ def _lazy_imports(impl: Optional[str]): class FlashDynamicMaskAttentionKwargs(TypedDict, total=False): - cumulative_seqlens_q: Optional[torch.LongTensor] - cumulative_seqlens_k: Optional[torch.LongTensor] + """ + Keyword arguments for Flash Dynamic Mask Attention with Compile. + + Attributes: + cu_seq_lens_q (`torch.LongTensor`, *optional*) + Gets cumulative sequence length for query state. + cu_seq_lens_k (`torch.LongTensor`, *optional*) + Gets cumulative sequence length for key state. + max_length_q (`int`, *optional*): + Maximum sequence length for query state. + max_length_k (`int`, *optional*): + Maximum sequence length for key state. + """ + + cu_seq_lens_q: Optional[torch.LongTensor] + cu_seq_lens_k: Optional[torch.LongTensor] + max_length_q: Optional[int] + max_length_k: Optional[int] def _flash_dynamic_mask_attention_forward( @@ -58,7 +87,7 @@ def _flash_dynamic_mask_attention_forward( is_causal: bool, softmax_scale: Optional[float] = None, softcap: Optional[float] = None, - keep_window_size: Optional[int] = None, + window_size: Optional[int] = None, deterministic: Optional[bool] = None, target_dtype: Optional[torch.dtype] = None, implementation: Optional[str] = None, @@ -66,7 +95,6 @@ def _flash_dynamic_mask_attention_forward( ): dtype = query_states.dtype min_dtype = torch.finfo(dtype).min - batch_size, _, num_kv_heads, _ = key_states.shape if not all(k in globals() for k in ("_flash_fn")): flash_fn = _lazy_imports(implementation) @@ -93,14 +121,12 @@ def _flash_dynamic_mask_attention_forward( min_dtype ) - if keep_window_size is not None and key_length > keep_window_size: + if window_size is not None and key_length > window_size: topk_values, topk_indices = torch.topk( - attention_bias, keep_window_size, dim=-1, largest=True, sorted=False + attention_bias, window_size, dim=-1, largest=True, sorted=False ) attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device) attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype) - else: - attention_mask = None out = flash_fn( query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal