@@ -76,7 +76,7 @@ CUTE_HOST_DEVICE auto make_tiled_copy_C_warpcontiguousN(
7676
7777// //////////////////////////////////////////////////////////////////////////////////////////////////
7878
79- template <typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_first, bool Is_last, bool Seq_parallel=false , typename Params>
79+ template <typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_first, bool Is_last, bool Seq_parallel=false , typename Params>
8080inline __device__ void compute_dq_dk_dv_1colblock (const Params ¶ms, const int bidb, const int bidh, const int n_block) {
8181
8282 using Element = typename Kernel_traits::Element;
@@ -1069,7 +1069,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
10691069
10701070// //////////////////////////////////////////////////////////////////////////////////////////////////
10711071
1072- template <typename Kernel_traits, bool Is_causal, bool Is_even_M, bool Is_even_K, typename Params>
1072+ template <typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_M, bool Is_even_K, typename Params>
10731073inline __device__ void compute_dq_dk_dv (const Params ¶ms) {
10741074
10751075 // The block index for the batch.
@@ -1083,20 +1083,20 @@ inline __device__ void compute_dq_dk_dv(const Params ¶ms) {
10831083
10841084 const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1 ) / Kernel_traits::kBlockN ;
10851085 if (n_block_max == 1 ) {
1086- compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false , true , true >(params, bidb, bidh, 0 );
1086+ compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_M, Is_even_K, false , true , true >(params, bidb, bidh, 0 );
10871087 } else {
10881088 // Iterating backward from n_block_max - 1 to 0 might save 1 register
1089- compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false , true , false >(params, bidb, bidh, n_block_max - 1 );
1089+ compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_M, Is_even_K, false , true , false >(params, bidb, bidh, n_block_max - 1 );
10901090 for (int n_block = n_block_max - 2 ; n_block > 0 ; n_block--) {
1091- compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false , false , false >(params, bidb, bidh, n_block);
1091+ compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_M, Is_even_K, false , false , false >(params, bidb, bidh, n_block);
10921092 }
1093- compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false , false , true >(params, bidb, bidh, 0 );
1093+ compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_M, Is_even_K, false , false , true >(params, bidb, bidh, 0 );
10941094 }
10951095}
10961096
10971097// //////////////////////////////////////////////////////////////////////////////////////////////////
10981098
1099- template <typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
1099+ template <typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
11001100inline __device__ void compute_dq_dk_dv_seqk_parallel (const Params ¶ms) {
11011101
11021102 // The block index for the batch.
@@ -1106,7 +1106,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) {
11061106
11071107 // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
11081108 for (int n_block = blockIdx.x ; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1 ) / Kernel_traits::kBlockN ; n_block += gridDim.x ) {
1109- compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, false , false , /* Seq_parallel=*/ true >(params, bidb, bidh, n_block);
1109+ compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_MN, Is_even_K, Is_softcap, false , false , /* Seq_parallel=*/ true >(params, bidb, bidh, n_block);
11101110 }
11111111}
11121112
0 commit comments