Skip to content

Commit 8ed228d

Browse files
committed
Adds mask and bias template parameters to attention kernels
Extends kernel templates with Has_mask and Has_bias boolean parameters to support attention masking and bias operations. Updates all affected function signatures and call sites to maintain consistency across the attention computation pipeline.
1 parent da8d7ed commit 8ed228d

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

csrc/flash_dmattn/src/flash_fwd_kernel.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ __forceinline__ __device__ auto get_lse_tile(
5050
return local_tile(mLSE_slice, Shape<Int<kBlockM>>{}, make_coord(m_block));
5151
}
5252

53-
template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
53+
template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
5454
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
5555

5656
using Element = typename Kernel_traits::Element;
@@ -762,7 +762,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
762762

763763
////////////////////////////////////////////////////////////////////////////////////////////////////
764764

765-
template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, typename Params>
765+
template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, typename Params>
766766
inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
767767

768768
using Element = typename Kernel_traits::Element;
@@ -1492,20 +1492,20 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
14921492

14931493
////////////////////////////////////////////////////////////////////////////////////////////////////
14941494

1495-
template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
1495+
template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
14961496
inline __device__ void compute_attn(const Params &params) {
14971497
const int m_block = blockIdx.x;
14981498
// The block index for the batch.
14991499
const int bidb = blockIdx.y;
15001500
// The block index for the head.
15011501
const int bidh = blockIdx.z;
15021502

1503-
FLASH_NAMESPACE::compute_attn_1rowblock<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
1503+
FLASH_NAMESPACE::compute_attn_1rowblock<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
15041504
}
15051505

15061506
////////////////////////////////////////////////////////////////////////////////////////////////////
15071507

1508-
template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, typename Params>
1508+
template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, typename Params>
15091509
inline __device__ void compute_attn_splitkv(const Params &params) {
15101510
const int m_block = blockIdx.x;
15111511
// The block index for the batch.
@@ -1514,7 +1514,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
15141514
const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
15151515
const int n_split_idx = Split ? blockIdx.y : 0;
15161516
const int num_n_splits = Split ? gridDim.y : 1;
1517-
FLASH_NAMESPACE::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Split>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
1517+
FLASH_NAMESPACE::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_MN, Is_even_K, Is_softcap, Split>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
15181518
}
15191519

15201520
////////////////////////////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)