@@ -31,17 +31,17 @@ namespace FLASH_NAMESPACE {
 template <typename Kernel_traits, __VA_ARGS__> \
 __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
 
-DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_causal, bool Is_even_M, bool Is_even_K) {
+DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_M, bool Is_even_K) {
     #if defined(ARCH_SUPPORTS_FLASH)
-        FLASH_NAMESPACE::compute_dq_dk_dv<Kernel_traits, Is_causal, Is_even_M, Is_even_K>(params);
+        FLASH_NAMESPACE::compute_dq_dk_dv<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_M, Is_even_K>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
 }
 
-DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
+DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
     #if defined(ARCH_SUPPORTS_FLASH)
-        FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap>(params);
+        FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_MN, Is_even_K, Is_softcap>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
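
Note: the two unchanged lines at the top of this hunk are the tail of the DEFINE_FLASH_BACKWARD_KERNEL macro; its head sits just above the hunk. A sketch of the presumed full macro, following the upstream FlashAttention launch templates (KERNEL_PARAM_MODIFIER is assumed to expand to __grid_constant__ or to nothing depending on the toolchain):

    // Assumed macro head, not part of this diff.
    #define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
    template <typename Kernel_traits, __VA_ARGS__> \
    __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)

Each invocation above therefore declares a __global__ kernel template, and the newly added bool Has_mask / bool Has_bias become extra template parameters of those kernels.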
@@ -68,7 +68,7 @@ __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
     FLASH_NAMESPACE::convert_dKV<Kernel_traits>(params);
 }
 
-template <typename Kernel_traits, bool Is_causal>
+template <typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias>
 void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
     const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
     dim3 grid_m(num_m_block, params.b, params.h);
@@ -98,11 +98,9 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
         SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
             // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
             // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-            auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_causal, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
-            // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
-            if (smem_size_dq_dk_dv >= 48 * 1024) {
-                C10_CUDA_CHECK(cudaFuncSetAttribute(
-                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+            auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_causal, Has_mask, Has_bias, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
+            if (smem_size_dq_dk_dv >= 48 * 1024) {
+                C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
             }
             kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
             C10_CUDA_KERNEL_LAUNCH_CHECK();
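
The reflowed cudaFuncSetAttribute call above keeps the same logic: any launch that requests more than 48 KB of dynamic shared memory must first opt in through cudaFuncAttributeMaxDynamicSharedMemorySize. A minimal standalone sketch of that pattern, with placeholder names (the real code uses the flash_bwd_* kernels and smem_size_dq_dk_dv derived from Kernel_traits, and assumes the same headers as this file):

    // Placeholder kernel; not part of this diff.
    extern __global__ void some_flash_bwd_kernel(const Flash_bwd_params params);

    void launch_with_large_smem(const Flash_bwd_params &params, dim3 grid, dim3 block,
                                int smem_bytes, cudaStream_t stream) {
        if (smem_bytes >= 48 * 1024) {
            // Without this opt-in, a launch asking for more than 48 KB of dynamic
            // shared memory fails even on GPUs whose hardware limit is higher.
            C10_CUDA_CHECK(cudaFuncSetAttribute(
                some_flash_bwd_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes));
        }
        some_flash_bwd_kernel<<<grid, block, smem_bytes, stream>>>(params);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    }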
@@ -112,146 +110,151 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
 
     auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
     if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
-        C10_CUDA_CHECK(cudaFuncSetAttribute(
-            kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
+        C10_CUDA_CHECK(cudaFuncSetAttribute(kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
     }
     kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
-template <typename Kernel_traits, bool Is_causal>
+template <typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias>
 void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
 #ifndef FLASHATTENTION_DISABLE_BACKWARD
-    run_flash_bwd_seqk_parallel<Kernel_traits, Is_causal>(params, stream);
+    run_flash_bwd_seqk_parallel<Kernel_traits, Is_causal, Has_mask, Has_bias>(params, stream);
 #endif
 }
 
-template <typename T, bool Is_causal>
+template <typename T, bool Is_causal, bool Has_mask, bool Has_bias>
 void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
     cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
     if (max_smem_per_block >= 104 * 1024) {  // H100 and A100
         // 104KB, 1 CTAs in A100, 2 CTAs in H100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     } else {  // sm86 and sm89
         // 96KB, 1 CTAs in sm86 and sm 89.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     }
 }
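
For readers decoding the numeric trait arguments in these per-head-dim launchers, the sketch below gives the assumed mapping; it follows the upstream FlashAttention kernel_traits.h and is not confirmed by this diff:

    // Assumed parameter order for Flash_bwd_kernel_traits (upstream convention; treat as an assumption):
    //   Flash_bwd_kernel_traits<kHeadDim, kBlockM, kBlockN, kNWarps,
    //                           AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ,
    //                           Is_V_in_regs, No_double_buffer, elem_type>
    // e.g. <Headdim, 128, 128, 8, 4, 4, 4, false, false, T> above would be a 128 x 128 tile
    // processed by 8 warps, with V kept in shared memory and double buffering enabled.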
 
-template <typename T, bool Is_causal>
+template <typename T, bool Is_causal, bool Has_mask, bool Has_bias>
 void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
     cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
     if (max_smem_per_block >= 144 * 1024) {  // H100 and A100
         // In fwd, multi-CTA configurations are faster, but in bwd, their speeds are very close.
         // 56KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 4 CTAs in H100.
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
         // 72KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100.
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
         // 144KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     } else {  // sm86 and sm89
         // 88KB, 1 CTAs in sm86 and sm 89.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     }
     // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
 }
 
-template <typename T, bool Is_causal>
+template <typename T, bool Is_causal, bool Has_mask, bool Has_bias>
 void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
     cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
     if (max_smem_per_block >= 116 * 1024) {  // H100 and A100
         // 116KB, 1 CTAs in A100, 1 CTAs in H100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     } else {  // sm86 and sm89
         // 76KB, 1 CTAs in sm86 and sm 89.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     }
 }
 
-template <typename T, bool Is_causal>
+template <typename T, bool Is_causal, bool Has_mask, bool Has_bias>
 void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
     cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
     if (max_smem_per_block >= 144 * 1024) {  // H100 and A100
         // 144KB, 1 CTAs in A100, 1 CTAs in H100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     } else {  // sm86 and sm89
         // 80KB, 1 CTAs in sm86 and sm 89.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     }
 }
 
-template <typename T, bool Is_causal>
+template <typename T, bool Is_causal, bool Has_mask, bool Has_bias>
 void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
     cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
     if (max_smem_per_block >= 136 * 1024) {  // H100 and A100
         // 136KB, 1 CTAs in A100, 1 CTAs in H100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     } else {  // sm86 and sm89
         // 96KB, 1 CTAs in sm86 and sm 89.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     }
 }
 
-template <typename T, bool Is_causal>
+template <typename T, bool Is_causal, bool Has_mask, bool Has_bias>
 void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 256;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
     cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
     if (max_smem_per_block >= 176 * 1024) {  // H100
         // 176KB, 1 CTAs in H100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     } else if (max_smem_per_block >= 144 * 1024) {  // A100
         // 144KB, 1 CTAs in A100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     } else {  // sm86 and sm89
         // 96KB, 1 CTAs in sm86 and sm 89.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
     }
 }
 
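Not shown in this diff: the new Has_mask / Has_bias booleans still have to be bound from runtime state by whatever calls these per-head-dim launchers. A plausible caller-side sketch using the BOOL_SWITCH helper from the upstream static_switch.h; the params field names (mask_ptr, bias_ptr) and the half_t element type are assumptions for illustration:

    void run_mha_bwd_dispatch_example(Flash_bwd_params &params, cudaStream_t stream) {
        // Hypothetical dispatch; the real tree presumably also switches on head dim
        // and dtype (e.g. via HEADDIM_SWITCH / FP16_SWITCH).
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            BOOL_SWITCH(params.mask_ptr != nullptr, Has_mask, [&] {
                BOOL_SWITCH(params.bias_ptr != nullptr, Has_bias, [&] {
                    run_mha_bwd_hdim128<cutlass::half_t, Is_causal, Has_mask, Has_bias>(params, stream);
                });
            });
        });
    }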