Commit 2239437

Improves mask and bias head indexing flexibility
Introduces dynamic head index calculation for mask and bias tensors to support different head configurations. Previously the head index was always derived from the fixed head ratio; the kernel now handles three scenarios:

- Single-head broadcasting (h_mask/h_bias == 1)
- Multi-head with ratio-based indexing (h_mask/h_bias == h_k)
- Direct head indexing (fallback case)

Enables more flexible attention masking and bias application across different multi-head attention configurations.
1 parent d18cdac · commit 2239437
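For illustration, the ternary this commit adds can be read as a small head-index selection helper. The sketch below is hypothetical (the kernel inlines the expression directly; select_head_index does not exist in the source), with h_target standing in for params.h_mask or params.h_bias:

// Sketch of the head-index selection this commit inlines as a ternary.
// bidh is the query head index, h_k the number of key/value heads, and
// h_h_k_ratio == h / h_k (query heads per KV head).
__device__ __forceinline__ int select_head_index(int h_target, int h_k, int h_h_k_ratio, int bidh) {
    if (h_target == 1)   { return 0; }                   // single head: broadcast to every query head
    if (h_target == h_k) { return bidh / h_h_k_ratio; }  // one head per KV head: shared within each query group
    return bidh;                                         // one head per query head: index directly
}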

1 file changed: 12 additions, 8 deletions

csrc/flash_dmattn/src/flash_fwd_kernel.h

@@ -136,6 +136,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // might save us 1 register (we just need n_block instead of both n_block and n_block_max).
 
     const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
+    const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
+    const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
 
     // Global memory tensor configuration
     Tensor mQ = make_tensor(
@@ -170,21 +172,21 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     ); // (kBlockN, kHeadDim, nblocksN)
     Tensor mMask = make_tensor(
         make_gmem_ptr(reinterpret_cast<const bool*>(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
-        make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
+        make_shape(params.h_mask, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
         make_stride(params.mask_head_stride, params.mask_row_stride, _1{})
     );
     Tensor gMask = local_tile(
-        mMask(bidh / params.h_h_k_ratio, _, _),
+        mMask(h_idx_mask, _, _),
         Shape<Int<kBlockM>, Int<kBlockN>>{},
         make_coord(m_block, _)
     ); // (kBlockM, kBlockN, nblocksN)
     Tensor mBias = make_tensor(
         make_gmem_ptr(reinterpret_cast<Element*>(params.bias_ptr) + binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb)),
-        make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
+        make_shape(params.h_bias, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
         make_stride(params.bias_head_stride, params.bias_row_stride, _1{})
     );
     Tensor gBias = local_tile(
-        mBias(bidh / params.h_h_k_ratio, _, _),
+        mBias(h_idx_bias, _, _),
         Shape<Int<kBlockM>, Int<kBlockN>>{},
         make_coord(m_block, _)
     ); // (kBlockM, kBlockN, nblocksN)
@@ -840,16 +842,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
             + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
         : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+    const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
     const index_t col_offset_mask = (block_table == nullptr)
         ? binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb_cache)
-            + (bidh / params.h_h_k_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1) * kBlockN
+            + h_idx_mask * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1) * kBlockN
         : binfo.q_offset(/*batch_stride=*/index_t(0), params.mask_row_stride, bidb_cache)
-            + (bidh / params.h_h_k_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset;
+            + h_idx_mask * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset;
+    const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
     const index_t col_offset_bias = (block_table == nullptr)
         ? binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb_cache)
-            + (bidh / params.h_h_k_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1) * kBlockN
+            + h_idx_bias * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1) * kBlockN
         : binfo.q_offset(/*batch_stride=*/index_t(0), params.bias_row_stride, bidb_cache)
-            + (bidh / params.h_h_k_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset;
+            + h_idx_bias * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset;
 
     // Global memory tensor configuration
     Tensor mQ = make_tensor(
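To make the three configurations concrete, here is a small host-side check that mirrors the ternary. It is a hypothetical standalone program, not part of the repository, assuming h = 8 query heads and h_k = 2 key/value heads (so h_h_k_ratio = 4):

#include <cstdio>

// Mirrors the kernel's ternary: picks the mask/bias head used for query head bidh.
static int head_index(int h_target, int h_k, int h_h_k_ratio, int bidh) {
    return (h_target == 1) ? 0 : ((h_target == h_k) ? (bidh / h_h_k_ratio) : bidh);
}

int main() {
    const int h = 8, h_k = 2, ratio = h / h_k;  // 8 query heads, 2 KV heads
    const int configs[] = {1, h_k, h};          // broadcast, per-KV-head, per-query-head
    for (int h_target : configs) {
        printf("h_mask/h_bias = %d:", h_target);
        for (int bidh = 0; bidh < h; ++bidh) {
            printf(" %d", head_index(h_target, h_k, ratio, bidh));
        }
        printf("\n");
    }
    return 0;
}

// Prints:
// h_mask/h_bias = 1: 0 0 0 0 0 0 0 0
// h_mask/h_bias = 2: 0 0 0 0 1 1 1 1
// h_mask/h_bias = 8: 0 1 2 3 4 5 6 7

The first row is the single-head broadcast case, the second shares one mask/bias head across each group of four query heads (the ratio-based GQA-style case), and the third indexes heads directly.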
