Merged

32 commits
c04b8e9
Adds precomputed ratio fields for mask and bias heads
algo-home Sep 17, 2025
f613ff2
Reorganizes parameter assignments and adds ratio calculations
algo-home Sep 17, 2025
b54aeb7
Simplifies head index calculations for mask and bias
algo-home Sep 17, 2025
7b0e395
Simplifies head index calculation for mask and bias
algo-home Sep 17, 2025
851c643
Adds boolean flags to track mask and bias presence
algo-home Sep 17, 2025
26d8699
Updates mask and bias parameter comments to reflect correct head counts
algo-home Sep 18, 2025
58c9cf0
Adds conditional mask and bias support to flash attention
algo-home Sep 18, 2025
a1c52b9
Enhances dgrad parameter handling by adding has_mask and has_bias fla…
algo-home Sep 18, 2025
2dc07f1
Makes mask and bias parameters optional in MHA functions
algo-home Sep 18, 2025
8c9e21f
Makes mask and bias parameters optional in flash attention
algo-home Sep 18, 2025
443792e
Extends Flash attention to support mask and bias parameters
algo-home Sep 18, 2025
8a6e78c
Adds mask and bias support to kernel generation
algo-home Sep 18, 2025
b3b6f80
Adds mask and bias support to MHA function templates
algo-home Sep 18, 2025
aa36f57
Adds mask and bias support to attention kernels
algo-home Sep 18, 2025
1e51290
Adds compile-time masking and bias optimizations
algo-home Sep 18, 2025
fdda8b5
Adds mask and bias support to flash attention kernels
algo-home Sep 18, 2025
c8be594
Adds mask and bias support to flash attention backward kernels
algo-home Sep 18, 2025
da8d7ed
Adds mask and bias template parameters to backward kernels
algo-home Sep 18, 2025
8ed228d
Adds mask and bias template parameters to attention kernels
algo-home Sep 18, 2025
f391432
Adds support for 3D mask and bias tensors in attention
algo-home Sep 18, 2025
b919c43
Refactors tensor partitioning for mask and bias in attention computat…
algo-home Sep 19, 2025
266f5e6
Adds optional mask and bias support to kernel traits
algo-home Sep 19, 2025
e0b3d30
Optimizes block size based on attention parameters
algo-home Sep 19, 2025
7860c26
Optimizes kernel dispatch based on mask and bias flags
algo-home Sep 19, 2025
f00d1ff
Adds conditional compilation for mask and bias support
algo-home Sep 19, 2025
6ef3de8
Optimizes backward kernel with compile-time checks for mask and bias
algo-home Sep 19, 2025
3884cfe
Adds const qualifiers to mask and bias parameters
algo-home Sep 19, 2025
ca66b6c
Fixes bias gradient handling for 3D bias tensors
algo-home Sep 19, 2025
a61db8c
Refines documentation for attention_mask and attention_bias parameter…
algo-home Sep 19, 2025
87ce7cc
Removes default initialization of attention_bias in _flash_dynamic_ma…
algo-home Sep 19, 2025
77edcb0
Fixes bias tensor initialization in FlashDMAttnFunc to handle None case
algo-home Sep 19, 2025
96d6da0
Refactors CUDA extension sources in setup.py to use glob for dynamic …
algo-home Sep 19, 2025
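
Several commits above ("Adds precomputed ratio fields for mask and bias heads", "Simplifies head index calculations for mask and bias") center on broadcasting a smaller number of mask/bias heads across the query heads, using the same integer-division trick that h_h_k_ratio already uses for MQA/GQA. A minimal sketch of that mapping follows; the helper names and the bidh argument are assumptions for illustration, not code from this PR.

// Sketch only: map a query head to its mask/bias head via the precomputed ratios.
// `bidh` (the query head handled by the current block) is an assumed name.
__device__ __forceinline__ int mask_head_index(const Mask_params &params, const int bidh) {
    // h_h_mask_ratio == h / h_mask, so consecutive query heads share one mask head.
    return bidh / params.h_h_mask_ratio;
}

__device__ __forceinline__ int bias_head_index(const Bias_params &params, const int bidh) {
    // h_h_bias_ratio == h / h_bias, mirroring the MQA/GQA h_h_k_ratio pattern.
    return bidh / params.h_h_bias_ratio;
}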
402 changes: 249 additions & 153 deletions csrc/flash_dmattn/flash_api.cpp

Large diffs are not rendered by default.
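
Since the flash_api.cpp diff is collapsed here, the following is only a rough sketch of the pattern the commits describe ("Makes mask and bias parameters optional in MHA functions", "Adds boolean flags to track mask and bias presence"): accept optional tensors, record has_mask/has_bias, and precompute the head ratios. The helper name and exact signature are assumptions, not the PR's actual code.

#include <optional>
#include <ATen/ATen.h>
#include "flash.h"   // defines Flash_fwd_params (see the flash.h diff below)

// Hypothetical helper illustrating how optional mask/bias tensors could populate the
// params struct; not the actual flash_api.cpp implementation.
void set_mask_bias_params(
    Flash_fwd_params &params,
    const std::optional<at::Tensor> &mask,   // [batch, h_mask, seqlen_q, seqlen_k] when present
    const std::optional<at::Tensor> &bias    // [batch, h_bias, seqlen_q, seqlen_k] when present
) {
    params.has_mask = mask.has_value();
    params.has_bias = bias.has_value();
    params.mask_ptr = params.has_mask ? mask->data_ptr() : nullptr;
    params.bias_ptr = params.has_bias ? bias->data_ptr() : nullptr;
    // Treat a missing tensor as a single broadcast head so the ratios stay well defined.
    params.h_mask = params.has_mask ? mask->size(1) : 1;
    params.h_bias = params.has_bias ? bias->size(1) : 1;
    params.h_h_mask_ratio = params.h / params.h_mask;   // assumes h % h_mask == 0
    params.h_h_bias_ratio = params.h / params.h_bias;   // assumes h % h_bias == 0
}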

23 changes: 15 additions & 8 deletions csrc/flash_dmattn/src/flash.h
@@ -38,13 +38,13 @@ struct QKV_params {
int h, h_k;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precompute h / h_k,
int h_h_k_ratio; // precompute h / h_k,
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Mask_params {
void * __restrict__ mask_ptr; // Attention mask tensor [batch_size, num_kv_heads, query_len, key_len]
void * __restrict__ mask_ptr; // Attention mask tensor [batch_size, num_mask_heads, query_len, key_len]

// The stride of the attention mask tensors.
index_t mask_batch_stride; // Stride between batches of attention mask
@@ -53,12 +53,15 @@ struct Mask_params {

// The number of heads in the mask.
int h_mask;
int h_h_mask_ratio; // precompute h / h_mask

bool has_mask;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Bias_params {
void *__restrict__ bias_ptr; // Attention bias tensor [batch_size, num_kv_heads, query_len, key_len]
void *__restrict__ bias_ptr; // Attention bias tensor [batch_size, num_bias_heads, query_len, key_len]

// The stride of the attention bias tensor.
index_t bias_batch_stride; // Stride between batches of attention bias
@@ -67,13 +70,16 @@ struct Bias_params {

// The number of heads in the bias.
int h_bias;
int h_h_bias_ratio; // precompute h / h_bias

bool has_bias;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_params {

// The O matrix (output).
// The O matrix.
void * __restrict__ o_ptr;
void * __restrict__ oaccum_ptr;

@@ -90,7 +96,7 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par
void * __restrict__ softmax_lseaccum_ptr;

// The dimensions.
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, total_q, total_k;

// The scaling factors for the kernel.
float scale_softmax;
@@ -105,6 +111,7 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par
// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;

// TODO: block mask for less memory usage
int *__restrict__ blockmask;

// The K_new and V_new matrices.
@@ -192,9 +199,9 @@ struct Flash_bwd_params : public Flash_fwd_params {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal, bool Has_mask, bool Has_bias> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal, bool Has_mask, bool Has_bias> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal, bool Has_mask, bool Has_bias> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE
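
The forward and backward launchers declared above now carry Has_mask and Has_bias as template parameters, so the runtime has_mask/has_bias flags have to be lifted into compile-time booleans somewhere in the dispatch path. The expanded if/else below is only an illustration of that lifting; the actual code presumably uses the project's *_SWITCH macros (as SOFTCAP_SWITCH is used in the launch template further down), and the wrapper name is hypothetical.

// Illustrative only: lift runtime has_mask/has_bias into the Has_mask/Has_bias
// template parameters declared above. `run_mha_fwd_dispatch` is a hypothetical name.
template<typename T, int Headdim, bool Is_causal>
void run_mha_fwd_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
    if (params.has_mask) {
        if (params.has_bias) { run_mha_fwd_<T, Headdim, Is_causal, true,  true >(params, stream); }
        else                 { run_mha_fwd_<T, Headdim, Is_causal, true,  false>(params, stream); }
    } else {
        if (params.has_bias) { run_mha_fwd_<T, Headdim, Is_causal, false, true >(params, stream); }
        else                 { run_mha_fwd_<T, Headdim, Is_causal, false, false>(params, stream); }
    }
}
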
284 changes: 171 additions & 113 deletions csrc/flash_dmattn/src/flash_bwd_kernel.h

Large diffs are not rendered by default.

85 changes: 44 additions & 41 deletions csrc/flash_dmattn/src/flash_bwd_launch_template.h
@@ -31,17 +31,17 @@ namespace FLASH_NAMESPACE {
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)

DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_causal, bool Is_even_M, bool Is_even_K) {
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_M, bool Is_even_K) {
#if defined(ARCH_SUPPORTS_FLASH)
FLASH_NAMESPACE::compute_dq_dk_dv<Kernel_traits, Is_causal, Is_even_M, Is_even_K>(params);
FLASH_NAMESPACE::compute_dq_dk_dv<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_M, Is_even_K>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}

DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if defined(ARCH_SUPPORTS_FLASH)
FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap>(params);
FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_causal, Has_mask, Has_bias, Is_even_MN, Is_even_K, Is_softcap>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
@@ -68,7 +68,7 @@ __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
FLASH_NAMESPACE::convert_dKV<Kernel_traits>(params);
}

template<typename Kernel_traits, bool Is_causal>
template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias>
void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid_m(num_m_block, params.b, params.h);
@@ -98,11 +98,9 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_causal, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_causal, Has_mask, Has_bias, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -112,146 +110,151 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)

auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
C10_CUDA_CHECK(cudaFuncSetAttribute(kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
}
kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template<typename Kernel_traits, bool Is_causal>
template<typename Kernel_traits, bool Is_causal, bool Has_mask, bool Has_bias>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
run_flash_bwd_seqk_parallel<Kernel_traits, Is_causal>(params, stream);
run_flash_bwd_seqk_parallel<Kernel_traits, Is_causal, Has_mask, Has_bias>(params, stream);
#endif
}

template<typename T, bool Is_causal>
template<typename T, bool Is_causal, bool Has_mask, bool Has_bias>
void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 32;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
if (max_smem_per_block >= 104 * 1024) { // H100 and A100
// 104KB, 1 CTAs in A100, 2 CTAs in H100.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
} else { // sm86 and sm89
// 96KB, 1 CTAs in sm86 and sm 89.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_causal>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
}
}

template<typename T, bool Is_causal>
template<typename T, bool Is_causal, bool Has_mask, bool Has_bias>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
if (max_smem_per_block >= 144 * 1024) { // H100 and A100
// In fwd, multi-CTA configurations are faster, but in bwd, their speeds are very close.
// 56KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 4 CTAs in H100.
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
// 72KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100.
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
// 144KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
} else { // sm86 and sm89
// 88KB, 1 CTAs in sm86 and sm 89.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
}
// M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
}

template<typename T, bool Is_causal>
template<typename T, bool Is_causal, bool Has_mask, bool Has_bias>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
if (max_smem_per_block >= 116 * 1024) { // H100 and A100
// 116KB, 1 CTAs in A100, 1 CTAs in H100.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
} else { // sm86 and sm89
// 76KB, 1 CTAs in sm86 and sm 89.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 4, 4, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
}
}

template<typename T, bool Is_causal>
template<typename T, bool Is_causal, bool Has_mask, bool Has_bias>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
if (max_smem_per_block >= 144 * 1024) { // H100 and A100
// 144KB, 1 CTAs in A100, 1 CTAs in H100.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_causal>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
} else { // sm86 and sm89
// 80KB, 1 CTAs in sm86 and sm 89.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
}
}

template<typename T, bool Is_causal>
template<typename T, bool Is_causal, bool Has_mask, bool Has_bias>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 192;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
if (max_smem_per_block >= 136 * 1024) { // H100 and A100
// 136KB, 1 CTAs in A100, 1 CTAs in H100.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
} else { // sm86 and sm89
// 96KB, 1 CTAs in sm86 and sm 89.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_causal>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
}
}

template<typename T, bool Is_causal>
template<typename T, bool Is_causal, bool Has_mask, bool Has_bias>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device
);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
if (max_smem_per_block >= 176 * 1024) { // H100
// 176KB, 1 CTAs in H100.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal, Has_mask, Has_bias>(params, stream);
} else if (max_smem_per_block >= 144 * 1024) { // A100
// 144KB, 1 CTAs in A100.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
} else { // sm86 and sm89
// 96KB, 1 CTAs in sm86 and sm 89.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal, Has_mask, Has_bias>(params, stream);
}
}
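
What the new Has_mask/Has_bias kernel parameters buy is that mask loads and bias adds can be eliminated at compile time rather than branched on per tile ("Adds compile-time masking and bias optimizations", "Optimizes backward kernel with compile-time checks for mask and bias"). Below is a simplified sketch of that pattern, with plain arrays standing in for the cute tensor plumbing in the real kernels; the function name and layout are assumptions for illustration.

// Sketch only: compile-time guards as enabled by the Has_mask / Has_bias parameters.
template<bool Has_mask, bool Has_bias>
__device__ __forceinline__ void apply_mask_and_bias(
    float *scores, const bool *mask, const float *bias, const int n
) {
    if constexpr (Has_bias) {
        for (int i = 0; i < n; ++i) { scores[i] += bias[i]; }     // add bias before softmax
    }
    if constexpr (Has_mask) {
        for (int i = 0; i < n; ++i) {
            if (!mask[i]) { scores[i] = -INFINITY; }              // masked positions get -inf
        }
    }
    // With Has_mask/Has_bias false, both branches compile away entirely, which is the
    // point of moving the flags from runtime checks into template parameters.
}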
