
Commit 483b726

Merge pull request #170 from SmallDoges/make-mask/bias-optional
[FEATURE SUPPORT] Optional mask/bias (3D & 4D)
2 parents 1cca349 + 96d6da0 · commit 483b726

File tree

301 files changed: +4069 / -1066 lines changed


csrc/flash_dmattn/flash_api.cpp

Lines changed: 249 additions & 153 deletions
Large diffs are not rendered by default.
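
The full flash_api.cpp diff is collapsed here, but together with the header changes below it carries the commit's main point: the attention mask and bias become optional inputs that may be passed as 4D [batch, heads, query_len, key_len] tensors or as 3D tensors that broadcast over heads. A minimal host-side sketch of that pattern, assuming a hypothetical helper name and the c10::optional convention common in PyTorch C++ extensions (not the literal code of this commit):

// Hypothetical helper, for illustration only; the field names follow flash.h below.
#include <torch/extension.h>

static void set_optional_mask(Flash_fwd_params &params,
                              const c10::optional<at::Tensor> &mask_,  // 3D, 4D, or absent
                              const int num_heads) {
    if (!mask_.has_value()) {
        params.mask_ptr = nullptr;
        params.h_mask = 1;
        params.h_h_mask_ratio = num_heads;  // every query head maps to mask head 0
        params.has_mask = false;
        return;
    }
    at::Tensor mask = mask_.value();
    if (mask.dim() == 3) { mask = mask.unsqueeze(1); }  // [B, Q, K] -> [B, 1, Q, K], broadcast over heads
    TORCH_CHECK(mask.dim() == 4, "attention mask must be 3D or 4D");
    const int h_mask = mask.size(1);
    TORCH_CHECK(num_heads % h_mask == 0, "query heads must be divisible by mask heads");
    params.mask_ptr = mask.data_ptr();
    params.mask_batch_stride = mask.stride(0);
    params.h_mask = h_mask;
    params.h_h_mask_ratio = num_heads / h_mask;  // precompute h / h_mask
    params.has_mask = true;
}

A matching helper for the bias would populate bias_ptr, h_bias, h_h_bias_ratio, and has_bias in the same way.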

csrc/flash_dmattn/src/flash.h

Lines changed: 15 additions & 8 deletions
@@ -38,13 +38,13 @@ struct QKV_params {
     int h, h_k;
     // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
     // different from nheads (query).
-    int h_h_k_ratio; // precompute h / h_k,
+    int h_h_k_ratio; // precompute h / h_k,
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 struct Mask_params {
-    void * __restrict__ mask_ptr; // Attention mask tensor [batch_size, num_kv_heads, query_len, key_len]
+    void * __restrict__ mask_ptr; // Attention mask tensor [batch_size, num_mask_heads, query_len, key_len]
 
     // The stride of the attention mask tensors.
     index_t mask_batch_stride; // Stride between batches of attention mask
@@ -53,12 +53,15 @@ struct Mask_params {
 
     // The number of heads in the mask.
     int h_mask;
+    int h_h_mask_ratio; // precompute h / h_mask
+
+    bool has_mask;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 struct Bias_params {
-    void *__restrict__ bias_ptr; // Attention bias tensor [batch_size, num_kv_heads, query_len, key_len]
+    void *__restrict__ bias_ptr; // Attention bias tensor [batch_size, num_bias_heads, query_len, key_len]
 
     // The stride of the attention bias tensor.
     index_t bias_batch_stride; // Stride between batches of attention bias
@@ -67,13 +70,16 @@ struct Bias_params {
 
     // The number of heads in the bias.
     int h_bias;
+    int h_h_bias_ratio; // precompute h / h_bias
+
+    bool has_bias;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_params {
 
-    // The O matrix (output).
+    // The O matrix.
     void * __restrict__ o_ptr;
     void * __restrict__ oaccum_ptr;
 
@@ -90,7 +96,7 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par
     void * __restrict__ softmax_lseaccum_ptr;
 
     // The dimensions.
-    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;
+    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, total_q, total_k;
 
     // The scaling factors for the kernel.
     float scale_softmax;
@@ -105,6 +111,7 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par
     // If provided, the actual length of each k sequence.
     int * __restrict__ seqused_k;
 
+    // TODO: block mask for less memory usage
    int *__restrict__ blockmask;
 
    // The K_new and V_new matrices.
@@ -192,9 +199,9 @@ struct Flash_bwd_params : public Flash_fwd_params {
 
////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal, bool Has_mask, bool Has_bias> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal, bool Has_mask, bool Has_bias> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
 
-template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal, bool Has_mask, bool Has_bias> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
 
} // namespace FLASH_NAMESPACE
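
The new h_h_mask_ratio and h_h_bias_ratio fields mirror the existing h_h_k_ratio used for MQA/GQA: they let a kernel map each query-head index onto a possibly smaller number of mask or bias heads, so a head-broadcast mask (h_mask == 1) and a full per-head mask (h_mask == h) are handled by the same arithmetic. A sketch of that mapping, with assumed variable names rather than the actual kernel code:

// Illustrative device-side head mapping; bidh is the query-head index of the block.
__device__ inline void mask_bias_head_indices(const Flash_fwd_params &params,
                                              const int bidh,
                                              int &mask_head, int &bias_head) {
    // h_h_mask_ratio == h / h_mask, so integer division collapses groups of query heads
    // onto one mask head (identity mapping when h_mask == h, head 0 when h_mask == 1).
    mask_head = params.has_mask ? bidh / params.h_h_mask_ratio : 0;
    bias_head = params.has_bias ? bidh / params.h_h_bias_ratio : 0;
    // Byte offsets into mask_ptr / bias_ptr would then combine the batch stride with the
    // per-head stride of the tensor, exactly as h_h_k_ratio is used to locate K/V heads.
}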

csrc/flash_dmattn/src/flash_bwd_kernel.h

Lines changed: 171 additions & 113 deletions
Large diffs are not rendered by default.

csrc/flash_dmattn/src/flash_bwd_launch_template.h

Lines changed: 44 additions & 41 deletions
@@ -31,17 +31,17 @@ namespace FLASH_NAMESPACE {
    template<typename Kernel_traits, __VA_ARGS__> \
    __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
 
-DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_causal, bool Is_even_M, bool Is_even_K) {
+DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_M, bool Is_even_K) {
    #if defined(ARCH_SUPPORTS_FLASH)
-        FLASH_NAMESPACE::compute_dq_dk_dv<Kernel_traits, Is_causal, Is_even_M, Is_even_K>(params);
+        FLASH_NAMESPACE::compute_dq_dk_dv<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_M, Is_even_K>(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}
 
-DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
+DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
    #if defined(ARCH_SUPPORTS_FLASH)
-        FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap>(params);
+        FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_MN, Is_even_K, Is_softcap>(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
@@ -68,7 +68,7 @@ __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
    FLASH_NAMESPACE::convert_dKV<Kernel_traits>(params);
}
 
-template<typename Kernel_traits, bool Is_causal>
+template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias>
void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid_m(num_m_block, params.b, params.h);
@@ -98,11 +98,9 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
            SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_causal, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
-                // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
-                if (smem_size_dq_dk_dv >= 48 * 1024) {
-                    C10_CUDA_CHECK(cudaFuncSetAttribute(
-                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_causal, Has_mask, Has_bias, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
+                if (smem_size_dq_dk_dv >= 48 * 1024) {
+                    C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                }
                kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -112,146 +110,151 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
 
    auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
    if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
-        C10_CUDA_CHECK(cudaFuncSetAttribute(
-            kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
+        C10_CUDA_CHECK(cudaFuncSetAttribute(kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
    }
    kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}
 
-template<typename Kernel_traits, bool Is_causal>
+template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
-    run_flash_bwd_seqk_parallel<Kernel_traits, Is_causal>(params, stream);
+    run_flash_bwd_seqk_parallel<Kernel_traits, Is_causal, Has_mask, Has_bias>(params, stream);
#endif
}
 
-template<typename T, bool Is_causal>
+template<typename T, bool Is_causal, bool Has_mask, bool Has_bias>
void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 32;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    if (max_smem_per_block >= 104 * 1024) { // H100 and A100
        // 104KB, 1 CTAs in A100, 2 CTAs in H100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
    } else { // sm86 and sm89
        // 96KB, 1 CTAs in sm86 and sm 89.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
    }
}
 
-template<typename T, bool Is_causal>
+template<typename T, bool Is_causal, bool Has_mask, bool Has_bias>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    if (max_smem_per_block >= 144 * 1024) { // H100 and A100
        // In fwd, multi-CTA configurations are faster, but in bwd, their speeds are very close.
        // 56KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 4 CTAs in H100.
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
        // 72KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100.
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
        // 144KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
    } else { // sm86 and sm89
        // 88KB, 1 CTAs in sm86 and sm 89.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
    }
    // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
}
 
-template<typename T, bool Is_causal>
+template<typename T, bool Is_causal, bool Has_mask, bool Has_bias>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 96;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    if (max_smem_per_block >= 116 * 1024) { // H100 and A100
        // 116KB, 1 CTAs in A100, 1 CTAs in H100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
    } else { // sm86 and sm89
        // 76KB, 1 CTAs in sm86 and sm 89.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
    }
}
 
-template<typename T, bool Is_causal>
+template<typename T, bool Is_causal, bool Has_mask, bool Has_bias>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    if (max_smem_per_block >= 144 * 1024) { // H100 and A100
        // 144KB, 1 CTAs in A100, 1 CTAs in H100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
    } else { // sm86 and sm89
        // 80KB, 1 CTAs in sm86 and sm 89.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
    }
}
 
-template<typename T, bool Is_causal>
+template<typename T, bool Is_causal, bool Has_mask, bool Has_bias>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 192;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    if (max_smem_per_block >= 136 * 1024) { // H100 and A100
        // 136KB, 1 CTAs in A100, 1 CTAs in H100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
    } else { // sm86 and sm89
        // 96KB, 1 CTAs in sm86 and sm 89.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
    }
}
 
-template<typename T, bool Is_causal>
+template<typename T, bool Is_causal, bool Has_mask, bool Has_bias>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
+    );
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    if (max_smem_per_block >= 176 * 1024) { // H100
        // 176KB, 1 CTAs in H100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
    } else if (max_smem_per_block >= 144 * 1024) { // A100
        // 144KB, 1 CTAs in A100.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
    } else { // sm86 and sm89
        // 96KB, 1 CTAs in sm86 and sm 89.
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
    }
}
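
Because Has_mask and Has_bias are compile-time template parameters, the runtime has_mask / has_bias flags must be converted into template arguments somewhere in the dispatch chain, in the same style as the SOFTCAP_SWITCH above. A hedged sketch of such a boolean switch (the macro and function names here are assumptions; this commit's own dispatch code is not rendered on this page):

// Illustrative compile-time boolean dispatch; the repository's actual macros may differ.
#define BOOL_SWITCH(COND, CONST_NAME, ...)                \
    [&] {                                                 \
        if (COND) {                                       \
            constexpr static bool CONST_NAME = true;      \
            return __VA_ARGS__();                         \
        } else {                                          \
            constexpr static bool CONST_NAME = false;     \
            return __VA_ARGS__();                         \
        }                                                 \
    }()

template<typename T, int Headdim, bool Is_causal>
void run_mha_bwd_dispatch(Flash_bwd_params &params, cudaStream_t stream) {
    BOOL_SWITCH(params.has_mask, Has_mask, [&] {
        BOOL_SWITCH(params.has_bias, Has_bias, [&] {
            // Each (Has_mask, Has_bias) combination instantiates its own kernel, so the
            // mask/bias code paths can be compiled out entirely when those inputs are absent.
            run_mha_bwd_<T, Headdim, Is_causal, Has_mask, Has_bias>(params, stream);
        });
    });
}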
