@@ -50,7 +50,7 @@ __forceinline__ __device__ auto get_lse_tile(
5050 return local_tile (mLSE_slice , Shape<Int<kBlockM >>{}, make_coord (m_block));
5151}
5252
53- template <typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
53+ template <typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
5454inline __device__ void compute_attn_1rowblock (const Params &params, const int bidb, const int bidh, const int m_block) {
5555
5656 using Element = typename Kernel_traits::Element;
@@ -762,7 +762,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
762762
763763// //////////////////////////////////////////////////////////////////////////////////////////////////
764764
765- template <typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, typename Params>
765+ template <typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, typename Params>
766766inline __device__ void compute_attn_1rowblock_splitkv (const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
767767
768768 using Element = typename Kernel_traits::Element;
@@ -1492,20 +1492,20 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
14921492
14931493// //////////////////////////////////////////////////////////////////////////////////////////////////
14941494
1495- template <typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
1495+ template <typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
14961496inline __device__ void compute_attn (const Params &params) {
14971497 const int m_block = blockIdx.x ;
14981498 // The block index for the batch.
14991499 const int bidb = blockIdx.y ;
15001500 // The block index for the head.
15011501 const int bidh = blockIdx.z ;
15021502
1503- FLASH_NAMESPACE::compute_attn_1rowblock<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
1503+ FLASH_NAMESPACE::compute_attn_1rowblock<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
15041504}
15051505
15061506// //////////////////////////////////////////////////////////////////////////////////////////////////
15071507
1508- template <typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, typename Params>
1508+ template <typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, typename Params>
15091509inline __device__ void compute_attn_splitkv (const Params &params) {
15101510 const int m_block = blockIdx.x ;
15111511 // The block index for the batch.
@@ -1514,7 +1514,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
15141514 const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z ;
15151515 const int n_split_idx = Split ? blockIdx.y : 0 ;
15161516 const int num_n_splits = Split ? gridDim.y : 1 ;
1517- FLASH_NAMESPACE::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Split>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
1517+ FLASH_NAMESPACE::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_MN, Is_even_K, Is_softcap, Split>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
15181518}
15191519
15201520// //////////////////////////////////////////////////////////////////////////////////////////////////
0 commit comments