Merged

32 commits
c04b8e9
Adds precomputed ratio fields for mask and bias heads
algo-home Sep 17, 2025
f613ff2
Reorganizes parameter assignments and adds ratio calculations
algo-home Sep 17, 2025
b54aeb7
Simplifies head index calculations for mask and bias
algo-home Sep 17, 2025
7b0e395
Simplifies head index calculation for mask and bias
algo-home Sep 17, 2025
851c643
Adds boolean flags to track mask and bias presence
algo-home Sep 17, 2025
26d8699
Updates mask and bias parameter comments to reflect correct head counts
algo-home Sep 18, 2025
58c9cf0
Adds conditional mask and bias support to flash attention
algo-home Sep 18, 2025
a1c52b9
Enhances dgrad parameter handling by adding has_mask and has_bias fla…
algo-home Sep 18, 2025
2dc07f1
Makes mask and bias parameters optional in MHA functions
algo-home Sep 18, 2025
8c9e21f
Makes mask and bias parameters optional in flash attention
algo-home Sep 18, 2025
443792e
Extends Flash attention to support mask and bias parameters
algo-home Sep 18, 2025
8a6e78c
Adds mask and bias support to kernel generation
algo-home Sep 18, 2025
b3b6f80
Adds mask and bias support to MHA function templates
algo-home Sep 18, 2025
aa36f57
Adds mask and bias support to attention kernels
algo-home Sep 18, 2025
1e51290
Adds compile-time masking and bias optimizations
algo-home Sep 18, 2025
fdda8b5
Adds mask and bias support to flash attention kernels
algo-home Sep 18, 2025
c8be594
Adds mask and bias support to flash attention backward kernels
algo-home Sep 18, 2025
da8d7ed
Adds mask and bias template parameters to backward kernels
algo-home Sep 18, 2025
8ed228d
Adds mask and bias template parameters to attention kernels
algo-home Sep 18, 2025
f391432
Adds support for 3D mask and bias tensors in attention
algo-home Sep 18, 2025
b919c43
Refactors tensor partitioning for mask and bias in attention computat…
algo-home Sep 19, 2025
266f5e6
Adds optional mask and bias support to kernel traits
algo-home Sep 19, 2025
e0b3d30
Optimizes block size based on attention parameters
algo-home Sep 19, 2025
7860c26
Optimizes kernel dispatch based on mask and bias flags
algo-home Sep 19, 2025
f00d1ff
Adds conditional compilation for mask and bias support
algo-home Sep 19, 2025
6ef3de8
Optimizes backward kernel with compile-time checks for mask and bias
algo-home Sep 19, 2025
3884cfe
Adds const qualifiers to mask and bias parameters
algo-home Sep 19, 2025
ca66b6c
Fixes bias gradient handling for 3D bias tensors
algo-home Sep 19, 2025
a61db8c
Refines documentation for attention_mask and attention_bias parameter…
algo-home Sep 19, 2025
87ce7cc
Removes default initialization of attention_bias in _flash_dynamic_ma…
algo-home Sep 19, 2025
77edcb0
Fixes bias tensor initialization in FlashDMAttnFunc to handle None case
algo-home Sep 19, 2025
96d6da0
Refactors CUDA extension sources in setup.py to use glob for dynamic …
algo-home Sep 19, 2025
382 changes: 235 additions & 147 deletions csrc/flash_dmattn/flash_api.cpp

Large diffs are not rendered by default.
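The flash_api.cpp diff is collapsed here, but per the commit messages above it makes attention_mask and attention_bias optional, records the new has_mask/has_bias flags, and precomputes the head ratios. A minimal sketch of that unpacking pattern, assuming standard torch-extension types; the helper name set_mask_bias_params and its exact signature are hypothetical, not the PR's verbatim code:

    // Hypothetical helper: unpack optional mask/bias tensors into the params
    // struct. The field names match the flash.h diff below; everything else
    // is illustrative.
    #include <torch/extension.h>

    void set_mask_bias_params(Flash_fwd_params &params,
                              const c10::optional<at::Tensor> &mask,
                              const c10::optional<at::Tensor> &bias,
                              const int num_heads) {
        params.has_mask = mask.has_value();
        params.has_bias = bias.has_value();
        params.mask_ptr = params.has_mask ? mask->data_ptr() : nullptr;
        params.bias_ptr = params.has_bias ? bias->data_ptr() : nullptr;
        // Mask/bias may broadcast over heads: size(1) can be 1, h_k, or h.
        params.h_mask = params.has_mask ? mask->size(1) : 1;
        params.h_bias = params.has_bias ? bias->size(1) : 1;
        params.h_h_mask_ratio = num_heads / params.h_mask;
        params.h_h_bias_ratio = num_heads / params.h_bias;
    }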

23 changes: 15 additions & 8 deletions csrc/flash_dmattn/src/flash.h
@@ -38,13 +38,13 @@ struct QKV_params {
     int h, h_k;
     // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
     // different from nheads (query).
     int h_h_k_ratio; // precompute h / h_k,
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 struct Mask_params {
-    void * __restrict__ mask_ptr; // Attention mask tensor [batch_size, num_kv_heads, query_len, key_len]
+    void * __restrict__ mask_ptr; // Attention mask tensor [batch_size, num_mask_heads, query_len, key_len]
 
     // The stride of the attention mask tensors.
     index_t mask_batch_stride; // Stride between batches of attention mask
@@ -53,12 +53,15 @@ struct Mask_params {
 
     // The number of heads in the mask.
     int h_mask;
+    int h_h_mask_ratio; // precompute h / h_mask
+
+    bool has_mask;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 struct Bias_params {
-    void *__restrict__ bias_ptr; // Attention bias tensor [batch_size, num_kv_heads, query_len, key_len]
+    void *__restrict__ bias_ptr; // Attention bias tensor [batch_size, num_bias_heads, query_len, key_len]
 
     // The stride of the attention bias tensor.
     index_t bias_batch_stride; // Stride between batches of attention bias
@@ -67,13 +70,16 @@ struct Bias_params {
 
     // The number of heads in the bias.
     int h_bias;
+    int h_h_bias_ratio; // precompute h / h_bias
+
+    bool has_bias;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_params {
 
-    // The O matrix (output).
+    // The O matrix.
     void * __restrict__ o_ptr;
     void * __restrict__ oaccum_ptr;
 
@@ -90,7 +96,7 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_params {
     void * __restrict__ softmax_lseaccum_ptr;
 
     // The dimensions.
-    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;
+    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, total_q, total_k;
 
     // The scaling factors for the kernel.
     float scale_softmax;
@@ -105,6 +111,7 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_params {
     // If provided, the actual length of each k sequence.
     int * __restrict__ seqused_k;
 
+    // TODO: block mask for less memory usage
     int *__restrict__ blockmask;
 
     // The K_new and V_new matrices.
@@ -192,9 +199,9 @@ struct Flash_bwd_params : public Flash_fwd_params {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal, bool Has_mask, bool Has_bias> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal, bool Has_mask, bool Has_bias> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
 
-template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal, bool Has_mask, bool Has_bias> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
 
 } // namespace FLASH_NAMESPACE
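The new h_h_mask_ratio and h_h_bias_ratio fields let a kernel map its query-head index to the matching mask or bias head with a single integer division, replacing the per-head ternary that the flash_bwd_kernel.h diff below removes. A minimal standalone sketch of that mapping, assuming h is divisible by h_mask and that h_mask is one of 1, h_k, or h, as the kernels require:

    // Illustrative only; the names mirror the struct fields above.
    #include <cassert>

    int mask_head_index(const int bidh, const int h, const int h_mask) {
        assert(h % h_mask == 0);               // divisibility is assumed by the kernels
        const int h_h_mask_ratio = h / h_mask; // precomputed once on the host in the PR
        // h_mask == 1   -> ratio == h,     result == 0       (one mask broadcast to all heads)
        // h_mask == h_k -> ratio == h/h_k, result == KV head (MQA/GQA grouping)
        // h_mask == h   -> ratio == 1,     result == bidh    (per-head mask)
        return bidh / h_h_mask_ratio;
    }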
22 changes: 10 additions & 12 deletions csrc/flash_dmattn/src/flash_bwd_kernel.h
@@ -76,7 +76,7 @@ CUTE_HOST_DEVICE auto make_tiled_copy_C_warpcontiguousN(
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
 inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
 
     using Element = typename Kernel_traits::Element;
@@ -107,12 +107,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
         + n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
     const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
         + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
-    const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
     const index_t row_offset_mask = binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)
-        + h_idx_mask * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN;
-    const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
+        + (bidh / params.h_h_mask_ratio) * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN;
     const index_t row_offset_bias = binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb)
-        + h_idx_bias * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN;
+        + (bidh / params.h_h_bias_ratio) * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN;
     const index_t row_offset_dbias = binfo.bias_offset(params.dbias_batch_stride, params.dbias_row_stride, bidb)
         + bidh * params.dbias_head_stride + (m_block_max - 1) * kBlockM * params.dbias_row_stride + n_block * kBlockN;
     const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
@@ -1071,7 +1069,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_causal, bool Is_even_M, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_M, bool Is_even_K, typename Params>
 inline __device__ void compute_dq_dk_dv(const Params &params) {
 
     // The block index for the batch.
@@ -1085,20 +1083,20 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {
 
     const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
     if (n_block_max == 1) {
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false, true, true>(params, bidb, bidh, 0);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_M, Is_even_K, false, true, true>(params, bidb, bidh, 0);
     } else {
         // Iterating backward from n_block_max - 1 to 0 might save 1 register
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false, true, false>(params, bidb, bidh, n_block_max - 1);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_M, Is_even_K, false, true, false>(params, bidb, bidh, n_block_max - 1);
         for (int n_block = n_block_max - 2; n_block > 0; n_block--) {
-            compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false, false, false>(params, bidb, bidh, n_block);
+            compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_M, Is_even_K, false, false, false>(params, bidb, bidh, n_block);
         }
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false, false, true>(params, bidb, bidh, 0);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_M, Is_even_K, false, false, true>(params, bidb, bidh, 0);
     }
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
 inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
 
     // The block index for the batch.
@@ -1108,7 +1106,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
 
     // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
     for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_MN, Is_even_K, Is_softcap, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
     }
 }
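Promoting Has_mask and Has_bias to template parameters means each kernel instantiation can drop the mask/bias loads and adds at compile time instead of branching per tile. A sketch of the runtime-to-compile-time dispatch this implies, in the spirit of upstream FlashAttention's BOOL_SWITCH; run_bwd_instance is a hypothetical stand-in for the PR's actual launchers in flash_api.cpp:

    // Illustrative dispatch: lift the runtime has_mask/has_bias flags into
    // the compile-time Has_mask/Has_bias parameters of the backward kernels.
    #include <cuda_runtime.h>

    template <bool Has_mask, bool Has_bias, typename Params>
    void run_bwd_instance(const Params &params, cudaStream_t stream);  // assumed launcher

    template <typename Params>
    void run_bwd(const Params &params, cudaStream_t stream) {
        if (params.has_mask) {
            if (params.has_bias) { run_bwd_instance<true, true>(params, stream); }
            else                 { run_bwd_instance<true, false>(params, stream); }
        } else {
            if (params.has_bias) { run_bwd_instance<false, true>(params, stream); }
            else                 { run_bwd_instance<false, false>(params, stream); }
        }
    }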
