@@ -136,6 +136,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
136136 // might save us 1 register (we just need n_block instead of both n_block and n_block_max).
137137
138138 const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM ) * params.seqlen_k_rounded + (n_block_max - 1 ) * kBlockN ;
139+ const int h_idx_mask = (params.h_mask == 1 ) ? 0 : ((params.h_mask == params.h_k ) ? (bidh / params.h_h_k_ratio ) : bidh);
140+ const int h_idx_bias = (params.h_bias == 1 ) ? 0 : ((params.h_bias == params.h_k ) ? (bidh / params.h_h_k_ratio ) : bidh);
139141
140142 // Global memory tensor configuration
141143 Tensor mQ = make_tensor (
@@ -170,21 +172,21 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
170172 ); // (kBlockN, kHeadDim, nblocksN)
171173 Tensor mMask = make_tensor (
172174 make_gmem_ptr (reinterpret_cast <const bool *>(params.mask_ptr ) + binfo.mask_offset (params.mask_batch_stride , params.mask_row_stride , bidb)),
173- make_shape (params.h_k , binfo.actual_seqlen_q , binfo.actual_seqlen_k ),
175+ make_shape (params.h_mask , binfo.actual_seqlen_q , binfo.actual_seqlen_k ),
174176 make_stride (params.mask_head_stride , params.mask_row_stride , _1{})
175177 );
176178 Tensor gMask = local_tile (
177- mMask (bidh / params. h_h_k_ratio , _, _),
179+ mMask (h_idx_mask , _, _),
178180 Shape<Int<kBlockM >, Int<kBlockN >>{},
179181 make_coord (m_block, _)
180182 ); // (kBlockM, kBlockN, nblocksN)
181183 Tensor mBias = make_tensor (
182184 make_gmem_ptr (reinterpret_cast <Element*>(params.bias_ptr ) + binfo.bias_offset (params.bias_batch_stride , params.bias_row_stride , bidb)),
183- make_shape (params.h_k , binfo.actual_seqlen_q , binfo.actual_seqlen_k ),
185+ make_shape (params.h_bias , binfo.actual_seqlen_q , binfo.actual_seqlen_k ),
184186 make_stride (params.bias_head_stride , params.bias_row_stride , _1{})
185187 );
186188 Tensor gBias = local_tile (
187- mBias (bidh / params. h_h_k_ratio , _, _),
189+ mBias (h_idx_bias , _, _),
188190 Shape<Int<kBlockM >, Int<kBlockN >>{},
189191 make_coord (m_block, _)
190192 ); // (kBlockM, kBlockN, nblocksN)
@@ -840,16 +842,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
840842 ? binfo.k_offset (params.v_batch_stride , params.v_row_stride , bidb_cache)
841843 + (n_block_max - 1 ) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio ) * params.v_head_stride
842844 : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio ) * params.v_head_stride ;
845+ const int h_idx_mask = (params.h_mask == 1 ) ? 0 : ((params.h_mask == params.h_k ) ? (bidh / params.h_h_k_ratio ) : bidh);
843846 const index_t col_offset_mask = (block_table == nullptr )
844847 ? binfo.mask_offset (params.mask_batch_stride , params.mask_row_stride , bidb_cache)
845- + (bidh / params. h_h_k_ratio ) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1 ) * kBlockN
848+ + h_idx_mask * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1 ) * kBlockN
846849 : binfo.q_offset (/* batch_stride=*/ index_t (0 ), params.mask_row_stride , bidb_cache)
847- + (bidh / params.h_h_k_ratio ) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset;
850+ + h_idx_mask * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset;
851+ const int h_idx_bias = (params.h_bias == 1 ) ? 0 : ((params.h_bias == params.h_k ) ? (bidh / params.h_h_k_ratio ) : bidh);
848852 const index_t col_offset_bias = (block_table == nullptr )
849853 ? binfo.bias_offset (params.bias_batch_stride , params.bias_row_stride , bidb_cache)
850- + (bidh / params. h_h_k_ratio ) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1 ) * kBlockN
854+ + h_idx_bias * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1 ) * kBlockN
851855 : binfo.q_offset (/* batch_stride=*/ index_t (0 ), params.bias_row_stride , bidb_cache)
852- + (bidh / params. h_h_k_ratio ) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset;
856+ + h_idx_bias * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset;
853857
854858 // Global memory tensor configuration
855859 Tensor mQ = make_tensor (
0 commit comments