
Commit e854654

Merge pull request #167 from SmallDoges/support-all-shape-of-mask/bias
[FEATURE SUPPORT] Flexible head dims for mask/bias with in-kernel conversion path
2 parents 1cf1385 + ccfd3ec commit e854654

7 files changed (+573, -475 lines)


benchmarks/backward_equivalence.py

Lines changed: 1 addition & 1 deletion
@@ -191,7 +191,7 @@ def dynamic_mask_attention_python(
     value_states = repeat_kv(value_states, num_queries_per_kv)
     attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
     attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv)
-
+
     attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1))
     attn_weights = attn_weights * scaling + attn_bias  # Apply scaling and zoh
     attn_weights = F.softmax(attn_weights, dim=-1)  # Softmax normalization
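For reference, the benchmark's Python path (sketched below under assumed shapes) expands KV heads with repeat_kv and then applies scaled QK^T, an additive bias, and softmax; repeat_kv is approximated here with repeat_interleave, and the boolean-mask step of the original is omitted.

import torch
import torch.nn.functional as F

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand KV heads so each query head sees its own copy (HF-style helper, approximated).
    return x if n_rep == 1 else x.repeat_interleave(n_rep, dim=1)

batch, nheads, nheads_k, seqlen_q, seqlen_k, head_dim = 2, 8, 2, 16, 32, 64
num_queries_per_kv = nheads // nheads_k
scaling = head_dim ** -0.5

query_states = torch.randn(batch, nheads, seqlen_q, head_dim)
key_states = repeat_kv(torch.randn(batch, nheads_k, seqlen_k, head_dim), num_queries_per_kv)
value_states = repeat_kv(torch.randn(batch, nheads_k, seqlen_k, head_dim), num_queries_per_kv)
attn_bias = repeat_kv(torch.randn(batch, nheads_k, seqlen_q, seqlen_k), num_queries_per_kv)

attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1))
attn_weights = attn_weights * scaling + attn_bias      # Apply scaling and bias
attn_weights = F.softmax(attn_weights, dim=-1)         # Softmax normalization
attn_output = torch.matmul(attn_weights, value_states)
print(attn_output.shape)                               # torch.Size([2, 8, 16, 64])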

csrc/flash_dmattn/flash_api.cpp

Lines changed: 542 additions & 460 deletions
Large diffs are not rendered by default.

csrc/flash_dmattn/src/flash.h

Lines changed: 6 additions & 0 deletions
@@ -50,6 +50,9 @@ struct Mask_params {
     index_t mask_batch_stride;          // Stride between batches of attention mask
     index_t mask_head_stride;           // Stride between heads of attention mask
     index_t mask_row_stride;            // Stride between rows of attention mask
+
+    // The number of heads in the mask.
+    int h_mask;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -61,6 +64,9 @@ struct Bias_params {
     index_t bias_batch_stride;          // Stride between batches of attention bias
     index_t bias_head_stride;           // Stride between heads of attention bias
     index_t bias_row_stride;            // Stride between rows of attention bias
+
+    // The number of heads in the bias.
+    int h_bias;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
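The host-side wiring lives in flash_api.cpp, whose diff is not rendered above; as a rough, hypothetical Python sketch of what these fields encode, the mask/bias head dimension is expected to be one of 1, the number of KV heads, or the number of query heads:

def resolve_head_count(x_heads: int, nheads: int, nheads_k: int) -> int:
    # Hypothetical check mirroring what params.h_mask / params.h_bias represent:
    # 1 (broadcast), nheads_k (per KV head), or nheads (per query head).
    if x_heads not in (1, nheads_k, nheads):
        raise ValueError(f"head dim must be 1, {nheads_k}, or {nheads}; got {x_heads}")
    return x_heads

print(resolve_head_count(1, nheads=8, nheads_k=2))  # 1 -> shared by every query head
print(resolve_head_count(2, nheads=8, nheads_k=2))  # 2 -> one slice per KV head
print(resolve_head_count(8, nheads=8, nheads_k=2))  # 8 -> one slice per query head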

csrc/flash_dmattn/src/flash_bwd_kernel.h

Lines changed: 4 additions & 2 deletions
@@ -107,10 +107,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         + n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
     const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
         + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+    const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
     const index_t row_offset_mask = binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)
-        + (bidh / params.h_h_k_ratio) * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN;
+        + h_idx_mask * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN;
+    const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
     const index_t row_offset_bias = binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb)
-        + (bidh / params.h_h_k_ratio) * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN;
+        + h_idx_bias * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN;
     const index_t row_offset_dbias = binfo.bias_offset(params.dbias_batch_stride, params.dbias_row_stride, bidb)
         + bidh * params.dbias_head_stride + (m_block_max - 1) * kBlockM * params.dbias_row_stride + n_block * kBlockN;
     const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
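In Python terms, the new head-index selection above reduces to the mapping below, where bidh is the query-head index and h_h_k_ratio = nheads / nheads_k (a sketch of the ternary, not kernel code):

def mask_or_bias_head_index(bidh: int, h_x: int, h_k: int, h_h_k_ratio: int) -> int:
    # Mirrors: (h_x == 1) ? 0 : ((h_x == h_k) ? (bidh / h_h_k_ratio) : bidh)
    if h_x == 1:
        return 0                    # single head, broadcast to all query heads
    if h_x == h_k:
        return bidh // h_h_k_ratio  # per-KV-head, shared within each GQA group
    return bidh                     # per-query-head

# 8 query heads, 2 KV heads (ratio 4), mask stored per KV head:
print([mask_or_bias_head_index(h, h_x=2, h_k=2, h_h_k_ratio=4) for h in range(8)])
# [0, 0, 0, 0, 1, 1, 1, 1]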

csrc/flash_dmattn/src/flash_fwd_kernel.h

Lines changed: 12 additions & 8 deletions
@@ -136,6 +136,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // might save us 1 register (we just need n_block instead of both n_block and n_block_max).

     const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
+    const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
+    const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);

     // Global memory tensor configuration
     Tensor mQ = make_tensor(
@@ -170,21 +172,21 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     ); // (kBlockN, kHeadDim, nblocksN)
     Tensor mMask = make_tensor(
         make_gmem_ptr(reinterpret_cast<const bool*>(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
-        make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
+        make_shape(params.h_mask, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
         make_stride(params.mask_head_stride, params.mask_row_stride, _1{})
     );
     Tensor gMask = local_tile(
-        mMask(bidh / params.h_h_k_ratio, _, _),
+        mMask(h_idx_mask, _, _),
         Shape<Int<kBlockM>, Int<kBlockN>>{},
         make_coord(m_block, _)
     ); // (kBlockM, kBlockN, nblocksN)
     Tensor mBias = make_tensor(
         make_gmem_ptr(reinterpret_cast<Element*>(params.bias_ptr) + binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb)),
-        make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
+        make_shape(params.h_bias, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
         make_stride(params.bias_head_stride, params.bias_row_stride, _1{})
     );
     Tensor gBias = local_tile(
-        mBias(bidh / params.h_h_k_ratio, _, _),
+        mBias(h_idx_bias, _, _),
         Shape<Int<kBlockM>, Int<kBlockN>>{},
         make_coord(m_block, _)
     ); // (kBlockM, kBlockN, nblocksN)
@@ -840,16 +842,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
             + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
         : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+    const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
     const index_t col_offset_mask = (block_table == nullptr)
         ? binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb_cache)
-            + (bidh / params.h_h_k_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1) * kBlockN
+            + h_idx_mask * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1) * kBlockN
         : binfo.q_offset(/*batch_stride=*/index_t(0), params.mask_row_stride, bidb_cache)
-            + (bidh / params.h_h_k_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset;
+            + h_idx_mask * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset;
+    const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
     const index_t col_offset_bias = (block_table == nullptr)
         ? binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb_cache)
-            + (bidh / params.h_h_k_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1) * kBlockN
+            + h_idx_bias * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1) * kBlockN
         : binfo.q_offset(/*batch_stride=*/index_t(0), params.bias_row_stride, bidb_cache)
-            + (bidh / params.h_h_k_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset;
+            + h_idx_bias * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset;

     // Global memory tensor configuration
     Tensor mQ = make_tensor(
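As a quick sanity check of the forward-kernel indexing above, selecting the mapped head from a per-KV-head mask should match the old path of first repeating the mask across query heads; a small PyTorch sketch under assumed shapes:

import torch

nheads, nheads_k = 8, 2
ratio = nheads // nheads_k
mask_kv = torch.rand(1, nheads_k, 4, 6) > 0.5        # h_mask == nheads_k case
mask_full = mask_kv.repeat_interleave(ratio, dim=1)  # old path: expand to all query heads

for bidh in range(nheads):
    h_idx = bidh // ratio                             # in-kernel mapping for h_mask == nheads_k
    assert torch.equal(mask_kv[:, h_idx], mask_full[:, bidh])
print("indexing by the mapped head matches the expanded mask")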

flash_dmattn/flash_dmattn_interface.py

Lines changed: 6 additions & 2 deletions
@@ -361,10 +361,14 @@ def flash_dmattn_func(
         key: torch.Tensor. The key tensor of shape (batch_size, seqlen, nheads_k, headdim)
         value: torch.Tensor. The value tensor of shape (batch_size, seqlen, nheads_k, headdim)
         attn_mask: torch.Tensor, optional. The attention mask boolean tensor of
-            shape (batch_size, nheads_k, seqlen_q, seqlen_k) to apply to the attention scores.
+            shape (batch_size, nheads, seqlen_q, seqlen_k) to apply to the attention scores.
+            Also supports shape (batch_size, nheads_k, seqlen_q, seqlen_k) or
+            (batch_size, 1, seqlen_q, seqlen_k) for MQA/GQA.
             If None, no mask is applied.
         attn_bias: torch.Tensor, optional. The attention bias float tensor of
-            shape (batch_size, nheads_k, seqlen_q, seqlen_k) to add to the attention scores.
+            shape (batch_size, nheads, seqlen_q, seqlen_k) to add to the attention scores.
+            Also supports shape (batch_size, nheads_k, seqlen_q, seqlen_k) or
+            (batch_size, 1, seqlen_q, seqlen_k) for MQA/GQA.
             If None, no bias is applied.
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         scale: float. The scaling of QK^T before applying softmax.
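A hedged usage sketch of the updated contract (assuming the flash_dmattn package exports flash_dmattn_func and a CUDA build is available): any of the three mask/bias head dimensions should be accepted without manually repeating heads.

import torch
from flash_dmattn import flash_dmattn_func  # assumed import path

batch, seqlen_q, seqlen_k, nheads, nheads_k, headdim = 2, 128, 128, 8, 2, 64
device, dtype = "cuda", torch.bfloat16

q = torch.randn(batch, seqlen_q, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen_k, nheads_k, headdim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen_k, nheads_k, headdim, device=device, dtype=dtype)

# Head dim of the mask/bias may be 1, nheads_k, or nheads.
for h in (1, nheads_k, nheads):
    mask = torch.ones(batch, h, seqlen_q, seqlen_k, device=device, dtype=torch.bool)
    bias = torch.zeros(batch, h, seqlen_q, seqlen_k, device=device, dtype=dtype)
    out = flash_dmattn_func(q, k, v, attn_mask=mask, attn_bias=bias, is_causal=True)
    print(h, out.shape)  # (batch, seqlen_q, nheads, headdim)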

flash_dmattn/integrations/flash_dynamic_mask_attention.py

Lines changed: 2 additions & 2 deletions
@@ -29,8 +29,8 @@ def flash_dynamic_mask_attention_forward(
         query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim).
         key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
         value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
-        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, num_kv_heads, query_len, key_len).
-        attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, num_kv_heads, query_len, key_len).
+        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, num_heads, query_len, key_len), also supports (batch_size, num_kv_heads, query_len, key_len) or (batch_size, 1, query_len, key_len) for MQA/GQA.
+        attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, num_heads, query_len, key_len), also supports (batch_size, num_kv_heads, query_len, key_len) or (batch_size, 1, query_len, key_len) for MQA/GQA.
         scaling (Optional[float]): The scaling factor for the attention scores.
         softcap (Optional[float]): The softcap value for the attention scores.
         **kwargs: Additional keyword arguments.
