
Commit c8be594

Adds mask and bias support to flash attention backward kernels
Extends backward kernel templates with Has_mask and Has_bias parameters to enable attention masking and bias functionality during gradient computation. Updates all kernel instantiations and function signatures to propagate the new template parameters through the call chain, maintaining consistency across different head dimensions and device configurations. Includes minor code formatting improvements for better readability.
1 parent fdda8b5 commit c8be594
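
The new Has_mask and Has_bias template parameters are compile-time booleans, so a caller has to map the runtime presence of a mask or bias tensor onto a specific template instantiation. Below is a minimal sketch of that dispatch pattern, assuming a flash-attention-style BOOL_SWITCH macro and illustrative mask_ptr / bias_ptr fields; the actual switch macros and parameter-struct field names in this repository may differ.

#include <cstdio>

// Stand-in for the real Flash_bwd_params; only the fields used below.
// The mask_ptr / bias_ptr names are assumptions for illustration.
struct Params {
    const void* mask_ptr = nullptr;
    const void* bias_ptr = nullptr;
};

// BOOL_SWITCH in the style of flash-attention's static_switch.h: turns a
// runtime condition into a constexpr bool visible inside the lambda body.
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

// Stub standing in for run_mha_bwd_hdim*<T, Is_causal, Has_mask, Has_bias>.
template<bool Is_causal, bool Has_mask, bool Has_bias>
void run_bwd_stub(const Params&) {
    std::printf("Is_causal=%d Has_mask=%d Has_bias=%d\n", Is_causal, Has_mask, Has_bias);
}

// Maps "is a mask/bias tensor present at runtime?" onto the compile-time flags.
template<bool Is_causal>
void dispatch_mask_bias(const Params& params) {
    BOOL_SWITCH(params.mask_ptr != nullptr, Has_mask, [&] {
        BOOL_SWITCH(params.bias_ptr != nullptr, Has_bias, [&] {
            run_bwd_stub<Is_causal, Has_mask, Has_bias>(params);
        });
    });
}

int main() {
    Params p;
    p.mask_ptr = &p;              // pretend a mask tensor is present
    dispatch_mask_bias<true>(p);  // prints Is_causal=1 Has_mask=1 Has_bias=0
    return 0;
}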

File tree

1 file changed: +44 -41 lines changed


csrc/flash_dmattn/src/flash_bwd_launch_template.h

Lines changed: 44 additions & 41 deletions
@@ -31,17 +31,17 @@ namespace FLASH_NAMESPACE {
     template<typename Kernel_traits, __VA_ARGS__> \
     __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)

-DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_causal, bool Is_even_M, bool Is_even_K) {
+DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_M, bool Is_even_K) {
     #if defined(ARCH_SUPPORTS_FLASH)
-        FLASH_NAMESPACE::compute_dq_dk_dv<Kernel_traits, Is_causal, Is_even_M, Is_even_K>(params);
+        FLASH_NAMESPACE::compute_dq_dk_dv<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_M, Is_even_K>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
 }

-DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
+DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
     #if defined(ARCH_SUPPORTS_FLASH)
-        FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap>(params);
+        FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_MN, Is_even_K, Is_softcap>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
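
Passing Has_mask and Has_bias as template parameters rather than runtime booleans lets the compiler drop the mask and bias code paths entirely from instantiations that do not need them. A self-contained toy illustration of that pattern follows (not the real compute_dq_dk_dv logic; build with nvcc -std=c++17).

#include <cuda_runtime.h>
#include <cmath>
#include <cstdio>

// Toy kernel: Has_mask / Has_bias are compile-time flags, so the branches
// below are resolved at compile time and dead paths generate no code.
template<bool Has_mask, bool Has_bias>
__global__ void toy_scores_kernel(float* scores, const float* mask,
                                  const float* bias, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    float s = scores[i];
    if constexpr (Has_bias) { s += bias[i]; }                     // removed when Has_bias == false
    if constexpr (Has_mask) { if (mask[i] == 0.f) s = -INFINITY; } // removed when Has_mask == false
    scores[i] = s;
}

int main() {
    const int n = 8;
    float *scores, *mask, *bias;
    cudaMallocManaged(&scores, n * sizeof(float));
    cudaMallocManaged(&mask, n * sizeof(float));
    cudaMallocManaged(&bias, n * sizeof(float));
    for (int i = 0; i < n; ++i) { scores[i] = 1.f; mask[i] = (float)(i % 2); bias[i] = 0.5f; }
    toy_scores_kernel<true, true><<<1, 32>>>(scores, mask, bias, n);
    cudaDeviceSynchronize();
    for (int i = 0; i < n; ++i) printf("%f\n", scores[i]);
    cudaFree(scores); cudaFree(mask); cudaFree(bias);
    return 0;
}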
@@ -68,7 +68,7 @@ __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
     FLASH_NAMESPACE::convert_dKV<Kernel_traits>(params);
 }

-template<typename Kernel_traits, bool Is_causal>
+template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias>
 void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
     const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
     dim3 grid_m(num_m_block, params.b, params.h);
@@ -98,11 +98,9 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
             SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                 // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                 // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_causal, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
-                // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
-                if (smem_size_dq_dk_dv >= 48 * 1024) {
-                    C10_CUDA_CHECK(cudaFuncSetAttribute(
-                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_causal, Has_mask, Has_bias, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
+                if (smem_size_dq_dk_dv >= 48 * 1024) {
+                    C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                 }
                 kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
                 C10_CUDA_KERNEL_LAUNCH_CHECK();
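
The cudaFuncSetAttribute call above is the standard CUDA opt-in for launching a kernel with more than 48 KB of dynamic shared memory; without it the launch fails even on devices whose hardware supports a larger carve-out. A stand-alone sketch of the same pattern, using an illustrative toy kernel:

#include <cuda_runtime.h>
#include <cstdio>

__global__ void big_smem_kernel(int* out) {
    extern __shared__ char smem[];  // dynamic shared memory, sized at launch
    smem[threadIdx.x] = (char)threadIdx.x;
    __syncthreads();
    if (threadIdx.x == 0) out[blockIdx.x] = smem[0];
}

int main() {
    const int smem_bytes = 64 * 1024;  // above the 48 KB default limit
    int* out;
    cudaMalloc(&out, sizeof(int));
    if (smem_bytes >= 48 * 1024) {
        // Without this attribute, a launch requesting >48 KB of dynamic
        // shared memory is rejected at launch time.
        cudaFuncSetAttribute(big_smem_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             smem_bytes);
    }
    big_smem_kernel<<<1, 128, smem_bytes, 0>>>(out);
    printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaDeviceSynchronize();
    cudaFree(out);
    return 0;
}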
@@ -112,146 +110,151 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)

     auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
     if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
-        C10_CUDA_CHECK(cudaFuncSetAttribute(
-            kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
+        C10_CUDA_CHECK(cudaFuncSetAttribute(kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
     }
     kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 }

-template<typename Kernel_traits, bool Is_causal>
+template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias>
 void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
     #ifndef FLASHATTENTION_DISABLE_BACKWARD
-        run_flash_bwd_seqk_parallel<Kernel_traits, Is_causal>(params, stream);
+        run_flash_bwd_seqk_parallel<Kernel_traits, Is_causal, Has_mask, Has_bias>(params, stream);
     #endif
 }

-template<typename T, bool Is_causal>
+template<typename T, bool Is_causal, bool Has_mask, bool Has_bias>
 void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
     cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
     if (max_smem_per_block >= 104 * 1024) { // H100 and A100
         // 104KB, 1 CTAs in A100, 2 CTAs in H100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     } else { // sm86 and sm89
         // 96KB, 1 CTAs in sm86 and sm 89.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     }
 }

-template<typename T, bool Is_causal>
+template<typename T, bool Is_causal, bool Has_mask, bool Has_bias>
 void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
     cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
     if (max_smem_per_block >= 144 * 1024) { // H100 and A100
         // In fwd, multi-CTA configurations are faster, but in bwd, their speeds are very close.
         // 56KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 4 CTAs in H100.
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
         // 72KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100.
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
         // 144KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     } else { // sm86 and sm89
         // 88KB, 1 CTAs in sm86 and sm 89.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     }
     // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
 }

-template<typename T, bool Is_causal>
+template<typename T, bool Is_causal, bool Has_mask, bool Has_bias>
 void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
     cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
     if (max_smem_per_block >= 116 * 1024) { // H100 and A100
         // 116KB, 1 CTAs in A100, 1 CTAs in H100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     } else { // sm86 and sm89
         // 76KB, 1 CTAs in sm86 and sm 89.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     }
 }

-template<typename T, bool Is_causal>
+template<typename T, bool Is_causal, bool Has_mask, bool Has_bias>
 void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
     cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
     if (max_smem_per_block >= 144 * 1024) { // H100 and A100
         // 144KB, 1 CTAs in A100, 1 CTAs in H100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     } else { // sm86 and sm89
         // 80KB, 1 CTAs in sm86 and sm 89.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     }
 }

-template<typename T, bool Is_causal>
+template<typename T, bool Is_causal, bool Has_mask, bool Has_bias>
 void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
     cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
     if (max_smem_per_block >= 136 * 1024) { // H100 and A100
         // 136KB, 1 CTAs in A100, 1 CTAs in H100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     } else { // sm86 and sm89
         // 96KB, 1 CTAs in sm86 and sm 89.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     }
 }

-template<typename T, bool Is_causal>
+template<typename T, bool Is_causal, bool Has_mask, bool Has_bias>
 void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 256;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
     cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
     if (max_smem_per_block >= 176 * 1024) { // H100
         // 176KB, 1 CTAs in H100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     } else if (max_smem_per_block >= 144 * 1024) { // A100
         // 144KB, 1 CTAs in A100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     } else { // sm86 and sm89
         // 96KB, 1 CTAs in sm86 and sm 89.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     }
 }
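
Each run_mha_bwd_hdim* launcher above selects its Flash_bwd_kernel_traits configuration by querying the device's opt-in shared-memory limit. A stand-alone sketch of that capability check, using the thresholds from the hdim128 case (the printed configuration labels are illustrative only):

#include <cuda_runtime.h>
#include <cstdio>

int main() {
    int device = 0;
    cudaGetDevice(&device);
    int max_smem_per_block = 0;
    // Maximum dynamic shared memory per block that a kernel may opt into.
    cudaDeviceGetAttribute(&max_smem_per_block,
                           cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (max_smem_per_block >= 144 * 1024) {
        printf("H100/A100-class device: 64x128 block configuration\n");
    } else {
        printf("sm86/sm89-class device: fall back to 64x64 block configuration\n");
    }
    return 0;
}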
