Skip to content

Commit 3bc89c8

Browse files
committed
Adds flexible head indexing for mask and bias tensors
Introduces conditional head index calculation for mask and bias operations based on tensor dimensions. Supports scenarios where mask/bias tensors can have single head (h=1), match key heads (h=h_k), or match query heads (h=h_q). Replaces hardcoded head index division with dynamic selection logic that adapts to different tensor head configurations in flash attention backward kernel.
1 parent 2239437 commit 3bc89c8

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

csrc/flash_dmattn/src/flash_bwd_kernel.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
107107
+ n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
108108
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
109109
+ n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
110+
const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
110111
const index_t row_offset_mask = binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)
111-
+ (bidh / params.h_h_k_ratio) * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN;
112+
+ h_idx_mask * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN;
113+
const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
112114
const index_t row_offset_bias = binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb)
113-
+ (bidh / params.h_h_k_ratio) * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN;
115+
+ h_idx_bias * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN;
114116
const index_t row_offset_dbias = binfo.bias_offset(params.dbias_batch_stride, params.dbias_row_stride, bidb)
115117
+ bidh * params.dbias_head_stride + (m_block_max - 1) * kBlockM * params.dbias_row_stride + n_block * kBlockN;
116118
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)

0 commit comments

Comments
 (0)