Commit da8d7ed

Adds mask and bias template parameters to backward kernels
Extends the backward-kernel template parameter lists with Has_mask and Has_bias flags so that attention variants with masks and biases can be handled more flexibly. Updates all call sites to pass the new template parameters through, preserving existing behavior.
1 parent c8be594 commit da8d7ed
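
For context on what the new flags buy: because Has_mask and Has_bias are template parameters rather than runtime booleans, the kernel can discard the mask/bias code paths at compile time when they are not needed. A minimal sketch of the pattern, using a hypothetical helper with a raw-pointer signature (not code from flash_bwd_kernel.h):

    // Illustrative sketch only; add_bias_to_scores and its signature are hypothetical.
    template<bool Has_bias>
    __device__ void add_bias_to_scores(float *scores, const float *bias, int n) {
        if constexpr (Has_bias) {
            // Hypothetical element-wise add of the attention bias onto the scores.
            for (int i = 0; i < n; ++i) { scores[i] += bias[i]; }
        }
        // When Has_bias == false, the branch above is discarded at compile time,
        // so no bias loads or extra registers are emitted. Has_mask works the same way.
    }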

File tree

csrc/flash_dmattn/src/flash_bwd_kernel.h (1 file changed: 8 additions, 8 deletions)
@@ -76,7 +76,7 @@ CUTE_HOST_DEVICE auto make_tiled_copy_C_warpcontiguousN(
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
 inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
 
     using Element = typename Kernel_traits::Element;
@@ -1069,7 +1069,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_causal, bool Is_even_M, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_M, bool Is_even_K, typename Params>
 inline __device__ void compute_dq_dk_dv(const Params &params) {
 
     // The block index for the batch.
@@ -1083,20 +1083,20 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {
 
     const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
     if (n_block_max == 1) {
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false, true, true>(params, bidb, bidh, 0);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_M, Is_even_K, false, true, true>(params, bidb, bidh, 0);
     } else {
         // Iterating backward from n_block_max - 1 to 0 might save 1 register
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false, true, false>(params, bidb, bidh, n_block_max - 1);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_M, Is_even_K, false, true, false>(params, bidb, bidh, n_block_max - 1);
         for (int n_block = n_block_max - 2; n_block > 0; n_block--) {
-            compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false, false, false>(params, bidb, bidh, n_block);
+            compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_M, Is_even_K, false, false, false>(params, bidb, bidh, n_block);
         }
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_M, Is_even_K, false, false, true>(params, bidb, bidh, 0);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_M, Is_even_K, false, false, true>(params, bidb, bidh, 0);
     }
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
 inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
 
     // The block index for the batch.
@@ -1106,7 +1106,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
 
     // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
     for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_MN, Is_even_K, Is_softcap, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
     }
 }
 
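
This diff only threads Has_mask and Has_bias through the backward kernels; selecting their values happens in host-side launch code outside this file. A minimal sketch of how that runtime-to-compile-time dispatch is commonly done, assuming a BOOL_SWITCH-style helper and mask_ptr/bias_ptr fields on the params struct (these names are assumptions for illustration, not part of this commit):

    // Sketch only: BOOL_SWITCH, run_bwd_dispatch, mask_ptr and bias_ptr are assumed
    // names for illustration; the actual dispatch code lives outside this file.
    #define BOOL_SWITCH(COND, CONST_NAME, ...)            \
        [&] {                                             \
            if (COND) {                                   \
                constexpr static bool CONST_NAME = true;  \
                return __VA_ARGS__();                     \
            } else {                                      \
                constexpr static bool CONST_NAME = false; \
                return __VA_ARGS__();                     \
            }                                             \
        }()

    template<typename Kernel_traits, bool Is_causal, typename Params>
    void run_bwd_dispatch(const Params &params) {
        // Fold the runtime presence of the mask/bias tensors into compile-time
        // template flags once, at the outermost level of the launch path.
        BOOL_SWITCH(params.mask_ptr != nullptr, Has_mask, [&] {
            BOOL_SWITCH(params.bias_ptr != nullptr, Has_bias, [&] {
                // A __global__ wrapper that calls
                // compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_causal, Has_mask, Has_bias, ...>(params)
                // would be launched here with the selected flags.
            });
        });
    }

Because the flags are compile-time constants, each (Has_mask, Has_bias) combination becomes a separate kernel instantiation, so the unused paths cost nothing at run time.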
