From c04b8e9bcfaf3836fc3ff5b4a5d6f54cbf8cdc67 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Wed, 17 Sep 2025 19:13:58 +0800 Subject: [PATCH 01/32] Adds precomputed ratio fields for mask and bias heads Introduces h_h_mask_ratio and h_h_bias_ratio fields to precompute head ratios for mask and bias parameters, following the existing pattern used for query/key head ratios. Also adds total_k dimension field and includes TODO comment for potential block mask memory optimization. --- csrc/flash_dmattn/src/flash.h | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/csrc/flash_dmattn/src/flash.h b/csrc/flash_dmattn/src/flash.h index a533cc3..f42b21a 100644 --- a/csrc/flash_dmattn/src/flash.h +++ b/csrc/flash_dmattn/src/flash.h @@ -38,7 +38,7 @@ struct QKV_params { int h, h_k; // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be // different from nheads (query). - int h_h_k_ratio; // precompute h / h_k, + int h_h_k_ratio; // precompute h / h_k, }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -53,6 +53,7 @@ struct Mask_params { // The number of heads in the mask. int h_mask; + int h_h_mask_ratio; // precompute h / h_mask }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -67,13 +68,14 @@ struct Bias_params { // The number of heads in the bias. int h_bias; + int h_h_bias_ratio; // precompute h / h_bias }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_params { - // The O matrix (output). + // The O matrix. void * __restrict__ o_ptr; void * __restrict__ oaccum_ptr; @@ -90,7 +92,7 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par void * __restrict__ softmax_lseaccum_ptr; // The dimensions. - int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q; + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, total_q, total_k; // The scaling factors for the kernel. float scale_softmax; @@ -105,6 +107,7 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par // If provided, the actual length of each k sequence. int * __restrict__ seqused_k; + // TODO: block mask for less memory usage int *__restrict__ blockmask; // The K_new and V_new matrices. From f613ff261ecf8118b83b308f1fce49b3bca61f6c Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Wed, 17 Sep 2025 19:25:58 +0800 Subject: [PATCH 02/32] Reorganizes parameter assignments and adds ratio calculations Reorders stride and head parameter assignments for better logical grouping and consistency between forward and backward pass implementations. Adds missing ratio calculations for mask and bias heads to complement existing head-to-key ratio computation. Fixes indentation inconsistency in batch stride calculations. --- csrc/flash_dmattn/flash_api.cpp | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 4108cb4..eb94947 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -69,16 +69,16 @@ void set_params_fprop( // All stride are in elements, not bytes. params.q_row_stride = q.stride(-3); - params.k_row_stride = k.stride(-3); - params.v_row_stride = v.stride(-3); - params.mask_row_stride = mask.stride(-2); - params.bias_row_stride = bias.stride(-2); - params.o_row_stride = out.stride(-3); params.q_head_stride = q.stride(-2); + params.k_row_stride = k.stride(-3); params.k_head_stride = k.stride(-2); + params.v_row_stride = v.stride(-3); params.v_head_stride = v.stride(-2); params.mask_head_stride = mask.stride(-3); + params.mask_row_stride = mask.stride(-2); params.bias_head_stride = bias.stride(-3); + params.bias_row_stride = bias.stride(-2); + params.o_row_stride = out.stride(-3); params.o_head_stride = out.stride(-2); if (cu_seqlens_q_d == nullptr) { @@ -89,8 +89,8 @@ void set_params_fprop( params.bias_batch_stride = bias.stride(0); params.o_batch_stride = out.stride(0); if (seqlenq_ngroups_swapped) { - params.q_batch_stride *= seqlen_q; - params.o_batch_stride *= seqlen_q; + params.q_batch_stride *= seqlen_q; + params.o_batch_stride *= seqlen_q; } } @@ -108,9 +108,11 @@ void set_params_fprop( params.b = b; params.h = h; params.h_k = h_k; - params.h_h_k_ratio = h / h_k; params.h_mask = h_mask; params.h_bias = h_bias; + params.h_h_k_ratio = h / h_k; + params.h_h_mask_ratio = h / h_mask; + params.h_h_bias_ratio = h / h_bias; params.seqlen_q = seqlen_q; params.seqlen_k = seqlen_k; params.seqlen_q_rounded = seqlen_q_rounded; @@ -201,20 +203,22 @@ void set_params_dgrad( // Set the pointers and strides. params.do_ptr = dout.data_ptr(); - params.do_row_stride = dout.stride(-3); - params.do_head_stride = dout.stride(-2); params.dq_ptr = dq.data_ptr(); params.dk_ptr = dk.data_ptr(); params.dv_ptr = dv.data_ptr(); params.dbias_ptr = dbias.data_ptr(); + + // All stride are in elements, not bytes. + params.do_row_stride = dout.stride(-3); + params.do_head_stride = dout.stride(-2); params.dq_row_stride = dq.stride(-3); - params.dk_row_stride = dk.stride(-3); - params.dv_row_stride = dv.stride(-3); - params.dbias_row_stride = dbias.stride(-2); params.dq_head_stride = dq.stride(-2); + params.dk_row_stride = dk.stride(-3); params.dk_head_stride = dk.stride(-2); + params.dv_row_stride = dv.stride(-3); params.dv_head_stride = dv.stride(-2); params.dbias_head_stride = dbias.stride(-3); + params.dbias_row_stride = dbias.stride(-2); if (cu_seqlens_q_d == nullptr) { params.do_batch_stride = dout.stride(0); From b54aeb73e8952c7925e9ade5232655f95280b895 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Wed, 17 Sep 2025 19:52:03 +0800 Subject: [PATCH 03/32] Simplifies head index calculations for mask and bias Removes redundant conditional logic for computing head indices by directly using the ratio-based calculations inline. Previously used intermediate variables with complex conditional expressions that duplicated the same logic pattern. Now directly computes head indices using the division by ratio parameters, making the code more readable and eliminating unnecessary variables. --- csrc/flash_dmattn/src/flash_fwd_kernel.h | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index 5b701bf..09e031d 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -136,8 +136,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // might save us 1 register (we just need n_block instead of both n_block and n_block_max). const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; - const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); - const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); // Global memory tensor configuration Tensor mQ = make_tensor( @@ -176,7 +174,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi make_stride(params.mask_head_stride, params.mask_row_stride, _1{}) ); Tensor gMask = local_tile( - mMask(h_idx_mask, _, _), + mMask(bidh / params.h_h_mask_ratio, _, _), Shape, Int>{}, make_coord(m_block, _) ); // (kBlockM, kBlockN, nblocksN) @@ -186,7 +184,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi make_stride(params.bias_head_stride, params.bias_row_stride, _1{}) ); Tensor gBias = local_tile( - mBias(h_idx_bias, _, _), + mBias(bidh / params.h_h_bias_ratio, _, _), Shape, Int>{}, make_coord(m_block, _) ); // (kBlockM, kBlockN, nblocksN) @@ -871,18 +869,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); const index_t col_offset_mask = (block_table == nullptr) ? binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb_cache) - + h_idx_mask * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1) * kBlockN + + (bidh / params.h_h_mask_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1) * kBlockN : binfo.q_offset(/*batch_stride=*/index_t(0), params.mask_row_stride, bidb_cache) - + h_idx_mask * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset; - const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); + + (bidh / params.h_h_mask_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset; const index_t col_offset_bias = (block_table == nullptr) ? binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb_cache) - + h_idx_bias * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1) * kBlockN + + (bidh / params.h_h_bias_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1) * kBlockN : binfo.q_offset(/*batch_stride=*/index_t(0), params.bias_row_stride, bidb_cache) - + h_idx_bias * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset; + + (bidh / params.h_h_bias_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset; // Global memory tensor configuration Tensor mQ = make_tensor( From 7b0e395a1c66fa92d595c27a246f4f864f002de2 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Wed, 17 Sep 2025 19:52:26 +0800 Subject: [PATCH 04/32] Simplifies head index calculation for mask and bias Removes conditional logic for computing head indices and replaces it with direct ratio-based calculations using h_h_mask_ratio and h_h_bias_ratio parameters. This eliminates the need for intermediate variables and conditional branches, making the code more straightforward and potentially improving performance. --- csrc/flash_dmattn/src/flash_bwd_kernel.h | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h index dc93b76..05ec18f 100644 --- a/csrc/flash_dmattn/src/flash_bwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_bwd_kernel.h @@ -107,12 +107,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in + n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); const index_t row_offset_mask = binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb) - + h_idx_mask * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN; - const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); + + (bidh / params.h_h_mask_ratio) * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN; const index_t row_offset_bias = binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb) - + h_idx_bias * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN; + + (bidh / params.h_h_bias_ratio) * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN; const index_t row_offset_dbias = binfo.bias_offset(params.dbias_batch_stride, params.dbias_row_stride, bidb) + bidh * params.dbias_head_stride + (m_block_max - 1) * kBlockM * params.dbias_row_stride + n_block * kBlockN; const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) From 851c643fec7734b5b2d5e6563ead7c2fabcc9b02 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 18 Sep 2025 00:22:07 +0800 Subject: [PATCH 05/32] Adds boolean flags to track mask and bias presence Introduces has_mask and has_bias boolean fields to Mask_params and Bias_params structures respectively. These flags enable runtime detection of whether mask or bias parameters are present, improving conditional logic handling and potentially optimizing performance by avoiding unnecessary processing when these optional components are not used. --- csrc/flash_dmattn/src/flash.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/flash_dmattn/src/flash.h b/csrc/flash_dmattn/src/flash.h index f42b21a..ac28695 100644 --- a/csrc/flash_dmattn/src/flash.h +++ b/csrc/flash_dmattn/src/flash.h @@ -54,6 +54,8 @@ struct Mask_params { // The number of heads in the mask. int h_mask; int h_h_mask_ratio; // precompute h / h_mask + + bool has_mask; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -69,6 +71,8 @@ struct Bias_params { // The number of heads in the bias. int h_bias; int h_h_bias_ratio; // precompute h / h_bias + + bool has_bias; }; //////////////////////////////////////////////////////////////////////////////////////////////////// From 26d8699a0d35b7214e31c350db8f7b4e34c377a1 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 18 Sep 2025 13:40:40 +0800 Subject: [PATCH 06/32] Updates mask and bias parameter comments to reflect correct head counts --- csrc/flash_dmattn/src/flash.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/flash_dmattn/src/flash.h b/csrc/flash_dmattn/src/flash.h index ac28695..d456de1 100644 --- a/csrc/flash_dmattn/src/flash.h +++ b/csrc/flash_dmattn/src/flash.h @@ -44,7 +44,7 @@ struct QKV_params { //////////////////////////////////////////////////////////////////////////////////////////////////// struct Mask_params { - void * __restrict__ mask_ptr; // Attention mask tensor [batch_size, num_kv_heads, query_len, key_len] + void * __restrict__ mask_ptr; // Attention mask tensor [batch_size, num_mask_heads, query_len, key_len] // The stride of the attention mask tensors. index_t mask_batch_stride; // Stride between batches of attention mask @@ -61,7 +61,7 @@ struct Mask_params { //////////////////////////////////////////////////////////////////////////////////////////////////// struct Bias_params { - void *__restrict__ bias_ptr; // Attention bias tensor [batch_size, num_kv_heads, query_len, key_len] + void *__restrict__ bias_ptr; // Attention bias tensor [batch_size, num_bias_heads, query_len, key_len] // The stride of the attention bias tensor. index_t bias_batch_stride; // Stride between batches of attention bias From 58c9cf0122130ca070a2aadeae111baa4a5dbec1 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 18 Sep 2025 13:42:21 +0800 Subject: [PATCH 07/32] Adds conditional mask and bias support to flash attention Introduces optional mask and bias parameters to prevent accessing null tensors when these features are disabled. Previously, mask and bias tensors were always accessed regardless of whether they contained valid data, which could cause errors or undefined behavior. Now uses conditional checks to only access tensor data and stride information when the corresponding features are actually enabled, improving robustness and allowing for optional mask/bias functionality. --- csrc/flash_dmattn/flash_api.cpp | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index eb94947..4a9b6f5 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -50,6 +50,8 @@ void set_params_fprop( float softmax_scale, bool is_causal, const float softcap, + bool has_mask, + bool has_bias, bool seqlenq_ngroups_swapped=false, const bool unpadded_lse=false ) { @@ -63,8 +65,8 @@ void set_params_fprop( params.q_ptr = q.data_ptr(); params.k_ptr = k.data_ptr(); params.v_ptr = v.data_ptr(); - params.mask_ptr = mask.data_ptr(); - params.bias_ptr = bias.data_ptr(); + params.mask_ptr = has_mask ? mask.data_ptr() : nullptr; + params.bias_ptr = has_bias ? bias.data_ptr() : nullptr; params.o_ptr = out.data_ptr(); // All stride are in elements, not bytes. @@ -74,10 +76,10 @@ void set_params_fprop( params.k_head_stride = k.stride(-2); params.v_row_stride = v.stride(-3); params.v_head_stride = v.stride(-2); - params.mask_head_stride = mask.stride(-3); - params.mask_row_stride = mask.stride(-2); - params.bias_head_stride = bias.stride(-3); - params.bias_row_stride = bias.stride(-2); + params.mask_head_stride = has_mask ? mask.stride(-3) : 0; + params.mask_row_stride = has_mask ? mask.stride(-2) : 0; + params.bias_head_stride = has_bias ? bias.stride(-3) : 0; + params.bias_row_stride = has_bias ? bias.stride(-2) : 0; params.o_row_stride = out.stride(-3); params.o_head_stride = out.stride(-2); @@ -85,11 +87,13 @@ void set_params_fprop( params.q_batch_stride = q.stride(0); params.k_batch_stride = k.stride(0); params.v_batch_stride = v.stride(0); - params.mask_batch_stride = mask.stride(0); - params.bias_batch_stride = bias.stride(0); + params.mask_batch_stride = has_mask ? mask.stride(0) : 0; + params.bias_batch_stride = has_bias ? bias.stride(0) : 0; params.o_batch_stride = out.stride(0); if (seqlenq_ngroups_swapped) { params.q_batch_stride *= seqlen_q; + params.mask_batch_stride *= seqlen_q; + params.bias_batch_stride *= seqlen_q; params.o_batch_stride *= seqlen_q; } } @@ -136,6 +140,8 @@ void set_params_fprop( } params.is_causal = is_causal; + params.has_mask = has_mask; + params.has_bias = has_bias; params.is_seqlens_k_cumulative = true; #ifdef FLASHATTENTION_DISABLE_UNEVEN_K From a1c52b99bf550721b73cbfeaf5cd97c41118d58b Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 18 Sep 2025 13:42:49 +0800 Subject: [PATCH 08/32] Enhances dgrad parameter handling by adding has_mask and has_bias flags for conditional bias management --- csrc/flash_dmattn/flash_api.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 4a9b6f5..2e58d1e 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -188,6 +188,8 @@ void set_params_dgrad( float softmax_scale, bool is_causal, const float softcap, + bool has_mask, + bool has_bias, bool deterministic, const bool unpadded_lse ) { @@ -203,6 +205,8 @@ void set_params_dgrad( softmax_scale, is_causal, softcap, + has_mask, + has_bias, false, // seqlenq_ngroups_swapped unpadded_lse ); @@ -212,7 +216,7 @@ void set_params_dgrad( params.dq_ptr = dq.data_ptr(); params.dk_ptr = dk.data_ptr(); params.dv_ptr = dv.data_ptr(); - params.dbias_ptr = dbias.data_ptr(); + params.dbias_ptr = has_bias ? dbias.data_ptr() : nullptr; // All stride are in elements, not bytes. params.do_row_stride = dout.stride(-3); @@ -223,15 +227,15 @@ void set_params_dgrad( params.dk_head_stride = dk.stride(-2); params.dv_row_stride = dv.stride(-3); params.dv_head_stride = dv.stride(-2); - params.dbias_head_stride = dbias.stride(-3); - params.dbias_row_stride = dbias.stride(-2); + params.dbias_head_stride = has_bias ? dbias.stride(-3) : 0; + params.dbias_row_stride = has_bias ? dbias.stride(-2) : 0; if (cu_seqlens_q_d == nullptr) { params.do_batch_stride = dout.stride(0); params.dq_batch_stride = dq.stride(0); params.dk_batch_stride = dk.stride(0); params.dv_batch_stride = dv.stride(0); - params.dbias_batch_stride = dbias.stride(0); + params.dbias_batch_stride = has_bias ? dbias.stride(0) : 0; } params.dq_accum_ptr = dq_accum_d; From 2dc07f16736d15b519bd1c152baf821288b951f0 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 18 Sep 2025 13:43:42 +0800 Subject: [PATCH 09/32] Makes mask and bias parameters optional in MHA functions Changes mask and bias parameters from required tensor references to optional tensor references in both forward and backward multi-head attention functions. Improves API flexibility by allowing these attention modifiers to be omitted when not needed, reducing memory overhead and simplifying function calls for basic attention operations. Updates parameter comments to use consistent formatting with curly brace notation for dimension alternatives. --- csrc/flash_dmattn/flash_api.cpp | 36 ++++++++++++++++----------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 2e58d1e..9804738 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -358,12 +358,12 @@ std::tuple set_params_splitkv( std::vector mha_fwd( - at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - at::Tensor &mask, // batch_size x num_heads x seqlen_q x seqlen_k or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k - at::Tensor &bias, // batch_size x num_heads x seqlen_q x seqlen_k, or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k - std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + std::optional &mask_, // batch_size x {1|num_heads_k|num_heads} x seqlen_q x seqlen_k + std::optional &bias_, // batch_size x {1|num_heads_k|num_heads} x seqlen_q x seqlen_k + std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const float softmax_scale, bool is_causal, const float softcap, @@ -775,18 +775,18 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { std::vector mha_bwd( - const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8) - const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &mask, // batch_size x num_heads x seqlen_q x seqlen_k or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k - const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x num_heads x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k - const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &softmax_lse, // b x h x seqlen_q - std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size - std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size - std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size - std::optional &dbias_, // batch_size x num_heads_k x seqlen_q x seqlen_k + const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8) + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const std::optional &mask_, // batch_size x {1|num_heads_k|num_heads} x seqlen_q x seqlen_k + const std::optional &bias_, // batch_size x {1|num_heads_k|num_heads} x seqlen_q x seqlen_k + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dbias_, // batch_size x {1|num_heads_k|num_heads} x seqlen_q x seqlen_k const float softmax_scale, const bool is_causal, const float softcap, From 8c9e21f10f8c37c474be112e8505d1786b888093 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 18 Sep 2025 13:45:45 +0800 Subject: [PATCH 10/32] Makes mask and bias parameters optional in flash attention Converts mandatory mask and bias parameters to optional parameters by wrapping them in std::optional. Adds proper validation and initialization logic to handle cases where mask or bias are not provided, creating empty tensors as placeholders when needed. Updates all related tensor operations, shape checking, and kernel parameter passing to conditionally process these optional inputs throughout both forward and backward passes. --- csrc/flash_dmattn/flash_api.cpp | 262 ++++++++++++++++++++------------ 1 file changed, 164 insertions(+), 98 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 9804738..b375268 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -381,16 +381,35 @@ mha_fwd( TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); - TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + auto opts = q.options(); + + bool has_mask = mask_.has_value(); + at::Tensor mask; + if (has_mask) { + mask = mask_.value(); + TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); + CHECK_DEVICE(mask); + TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + } else { + mask = torch::empty({0}, opts); + } + bool has_bias = bias_.has_value(); + at::Tensor bias; + if (has_bias) { + bias = bias_.value(); + TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); + CHECK_DEVICE(bias); + TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + } else { + bias = torch::empty({0}, opts); + } const auto sizes = q.sizes(); @@ -400,14 +419,19 @@ mha_fwd( const int head_size = sizes[3]; const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); - int num_heads_mask = mask.size(1); - int num_heads_bias = bias.size(1); + int num_heads_mask = has_mask ? mask.size(1) : 1; + int num_heads_bias = has_bias ? bias.size(1) : 1; + TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, "Number of heads in mask must be 1, h_k or h"); - TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, "Number of heads in bias must be 1, h_k or h"); + if (has_mask) { + TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, "Number of heads in mask must be 1, h_k or h"); + } + if (has_bias) { + TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, "Number of heads in bias must be 1, h_k or h"); + } // causal=true is the same as causal=false in this case if (seqlen_q == 1) { is_causal = false; } @@ -420,22 +444,26 @@ mha_fwd( const int orig_num_heads_bias = num_heads_bias; if (seqlenq_ngroups_swapped) { q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); - if (num_heads_mask == 1) { - mask = mask.expand({batch_size, 1, ngroups, seqlen_k}); - } else if (num_heads_mask == num_heads_k) { - mask = mask.expand({batch_size, num_heads_k, ngroups, seqlen_k}); - } else { // num_heads_mask == num_heads - mask = mask.reshape({batch_size, num_heads_k, ngroups, seqlen_k}); + if (has_mask) { + mask = num_heads_mask == 1 + ? mask.expand({batch_size, 1, ngroups, seqlen_k}) + : ( + num_heads_mask == num_heads_k + ? mask.expand({batch_size, num_heads_k, ngroups, seqlen_k}) + : mask.reshape({batch_size, num_heads_k, ngroups, seqlen_k}) + ); } - if (num_heads_bias == 1) { - bias = bias.expand({batch_size, 1, ngroups, seqlen_k}); - } else if (num_heads_bias == num_heads_k) { - bias = bias.expand({batch_size, num_heads_k, ngroups, seqlen_k}); - } else { // num_heads_bias == num_heads - bias = bias.reshape({batch_size, num_heads_k, ngroups, seqlen_k}); + if (has_bias) { + bias = num_heads_bias == 1 + ? bias.expand({batch_size, 1, ngroups, seqlen_k}) + : ( + num_heads_bias == num_heads_k + ? bias.expand({batch_size, num_heads_k, ngroups, seqlen_k}) + : bias.reshape({batch_size, num_heads_k, ngroups, seqlen_k}) + ); } - num_heads_mask = (num_heads_mask == num_heads) ? num_heads_k : num_heads_mask; - num_heads_bias = (num_heads_bias == num_heads) ? num_heads_k : num_heads_bias; + num_heads_mask = has_mask ? ((num_heads_mask == num_heads) ? num_heads_k : num_heads_mask) : 1; + num_heads_bias = has_bias ? ((num_heads_bias == num_heads) ? num_heads_k : num_heads_bias) : 1; seqlen_q = ngroups; num_heads = num_heads_k; } @@ -443,19 +471,23 @@ mha_fwd( CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - if (num_heads_mask == 1) { - CHECK_SHAPE(mask, batch_size, 1, seqlen_q, seqlen_k); - } else if (num_heads_mask == num_heads_k) { - CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k); - } else { - CHECK_SHAPE(mask, batch_size, num_heads, seqlen_q, seqlen_k); + if (has_mask) { + if (num_heads_mask == 1) { + CHECK_SHAPE(mask, batch_size, 1, seqlen_q, seqlen_k); + } else if (num_heads_mask == num_heads_k) { + CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k); + } else { + CHECK_SHAPE(mask, batch_size, num_heads, seqlen_q, seqlen_k); + } } - if (num_heads_bias == 1) { - CHECK_SHAPE(bias, batch_size, 1, seqlen_q, seqlen_k); - } else if (num_heads_bias == num_heads_k) { - CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k); - } else { - CHECK_SHAPE(bias, batch_size, num_heads, seqlen_q, seqlen_k); + if (has_bias) { + if (num_heads_bias == 1) { + CHECK_SHAPE(bias, batch_size, 1, seqlen_q, seqlen_k); + } else if (num_heads_bias == num_heads_k) { + CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k); + } else { + CHECK_SHAPE(bias, batch_size, num_heads, seqlen_q, seqlen_k); + } } at::Tensor out; @@ -477,8 +509,6 @@ mha_fwd( const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - auto opts = q.options(); - auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); at::Tensor p; @@ -504,7 +534,9 @@ mha_fwd( softmax_lse.data_ptr(), softmax_scale, is_causal, - softcap + softcap, + has_mask, + has_bias ); // Keep references to these tensors to extend their lifetime @@ -527,15 +559,15 @@ mha_fwd( out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); - if (orig_num_heads_mask == 1 || orig_num_heads_mask == num_heads_k) { - mask = mask.narrow(2, 0, 1); - } else { // orig_num_heads_mask == num_heads - mask = mask.reshape({batch_size, num_heads_k * seqlen_q, 1, seqlen_k}); + if (has_mask) { + mask = (orig_num_heads_mask == 1 || orig_num_heads_mask == num_heads_k) + ? mask.narrow(2, 0, 1) + : mask.reshape({batch_size, num_heads_k * seqlen_q, 1, seqlen_k}); } - if (orig_num_heads_bias == 1 || orig_num_heads_bias == num_heads_k) { - bias = bias.narrow(2, 0, 1); - } else { // orig_num_heads_bias == num_heads - bias = bias.reshape({batch_size, num_heads_k * seqlen_q, 1, seqlen_k}); + if (has_bias) { + bias = (orig_num_heads_bias == 1 || orig_num_heads_bias == num_heads_k) + ? bias.narrow(2, 0, 1) + : bias.reshape({batch_size, num_heads_k * seqlen_q, 1, seqlen_k}); } } return {out, softmax_lse, p}; @@ -810,39 +842,62 @@ mha_bwd( TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); - TORCH_CHECK(bias.dtype() == q_dtype, "query and bias must have the same dtype"); TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); - const auto sizes = q.sizes(); auto opts = q.options(); + bool has_mask = mask_.has_value(); + at::Tensor mask; + if (has_mask) { + mask = mask_.value(); + TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); + CHECK_DEVICE(mask); + TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + } else { + mask = torch::empty({0}, opts); + } + bool has_bias = bias_.has_value(); + at::Tensor bias; + if (has_bias) { + bias = bias_.value(); + TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); + CHECK_DEVICE(bias); + TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + } else { + bias = torch::empty({0}, opts); + } + + const auto sizes = q.sizes(); + const int batch_size = sizes[0]; const int seqlen_q = sizes[1]; const int num_heads = sizes[2]; const int head_size = sizes[3]; const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); - const int num_heads_mask = mask.size(1); - const int num_heads_bias = bias.size(1); + int num_heads_mask = has_mask ? mask.size(1) : 1; + int num_heads_bias = has_bias ? bias.size(1) : 1; + TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, "Number of heads in mask must be 1, h_k or h"); - TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, "Number of heads in bias must be 1, h_k or h"); + if (has_mask) { + TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, "Number of heads in mask must be 1, h_k or h"); + } + if (has_bias) { + TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, "Number of heads in bias must be 1, h_k or h"); + } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); @@ -852,19 +907,23 @@ mha_bwd( CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - if (num_heads_mask == 1) { - CHECK_SHAPE(mask, batch_size, 1, seqlen_q, seqlen_k); - } else if (num_heads_mask == num_heads_k) { - CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k); - } else { - CHECK_SHAPE(mask, batch_size, num_heads, seqlen_q, seqlen_k); + if (has_mask) { + if (num_heads_mask == 1) { + CHECK_SHAPE(mask, batch_size, 1, seqlen_q, seqlen_k); + } else if (num_heads_mask == num_heads_k) { + CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k); + } else { + CHECK_SHAPE(mask, batch_size, num_heads, seqlen_q, seqlen_k); + } } - if (num_heads_bias == 1) { - CHECK_SHAPE(bias, batch_size, 1, seqlen_q, seqlen_k); - } else if (num_heads_bias == num_heads_k) { - CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k); - } else { - CHECK_SHAPE(bias, batch_size, num_heads, seqlen_q, seqlen_k); + if (has_bias) { + if (num_heads_bias == 1) { + CHECK_SHAPE(bias, batch_size, 1, seqlen_q, seqlen_k); + } else if (num_heads_bias == num_heads_k) { + CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k); + } else { + CHECK_SHAPE(bias, batch_size, num_heads, seqlen_q, seqlen_k); + } } CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size); @@ -897,26 +956,30 @@ mha_bwd( } else { dv = torch::empty_like(v); } - if (dbias_.has_value()) { - dbias = dbias_.value(); - TORCH_CHECK(dbias.dtype() == q_dtype, "dbias must have the same dtype as q"); - CHECK_DEVICE(dbias); - TORCH_CHECK(dbias.stride(-1) == 1, "dbias must have contiguous last dimension"); - if (num_heads_bias == 1) { - CHECK_SHAPE(dbias, batch_size, 1, seqlen_q, seqlen_k); - } else if (num_heads_bias == num_heads_k) { - CHECK_SHAPE(dbias, batch_size, num_heads_k, seqlen_q, seqlen_k); + if (has_bias) { + if (dbias_.has_value()) { + dbias = dbias_.value(); + TORCH_CHECK(dbias.dtype() == q_dtype, "dbias must have the same dtype as q"); + CHECK_DEVICE(dbias); + TORCH_CHECK(dbias.stride(-1) == 1, "dbias must have contiguous last dimension"); + if (num_heads_bias == 1) { + CHECK_SHAPE(dbias, batch_size, 1, seqlen_q, seqlen_k); + } else if (num_heads_bias == num_heads_k) { + CHECK_SHAPE(dbias, batch_size, num_heads_k, seqlen_q, seqlen_k); + } else { + CHECK_SHAPE(dbias, batch_size, num_heads, seqlen_q, seqlen_k); + } } else { - CHECK_SHAPE(dbias, batch_size, num_heads, seqlen_q, seqlen_k); + if (num_heads_bias == 1) { + dbias = torch::empty({batch_size, 1, seqlen_q, seqlen_k}, opts); + } else if (num_heads_bias == num_heads_k) { + dbias = torch::empty({batch_size, num_heads_k, seqlen_q, seqlen_k}, opts); + } else { + dbias = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts); + } } } else { - if (num_heads_bias == 1) { - dbias = torch::empty({batch_size, 1, seqlen_q, seqlen_k}, opts); - } else if (num_heads_bias == num_heads_k) { - dbias = torch::empty({batch_size, num_heads_k, seqlen_q, seqlen_k}, opts); - } else { - dbias = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts); - } + dbias = torch::empty({0}, opts); } // bool loop = seqlen_k > blocksize_c; @@ -938,18 +1001,19 @@ mha_bwd( } at::Tensor dk_expanded, dv_expanded, dbias_expanded; - if (num_heads_k != num_heads) { // MQA / GQA - dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - } else { - dk_expanded = dk; - dv_expanded = dv; - } - if (num_heads_bias != num_heads) { - dbias_expanded = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts); - } else { - dbias_expanded = dbias; - } + dk_expanded = num_heads_k != num_heads // MQA / GQA + ? torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts) + : dk; + dv_expanded = num_heads_k != num_heads // MQA / GQA + ? torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts) + : dv; + dbias_expanded = has_bias + ? ( + num_heads_bias != num_heads // MQA / GQA + ? torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts) + : dbias + ) + : torch::empty({0}, opts); Flash_bwd_params params; @@ -974,6 +1038,8 @@ mha_bwd( softmax_scale, is_causal, softcap, + has_mask, + has_bias, deterministic, /*unpadded_lse*/false ); @@ -997,7 +1063,7 @@ mha_bwd( at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); } // For MQA/GQA or num_heads_bias != num_heads, we also need to sum dbias across the heads - if (num_heads_bias != num_heads) { + if (has_bias && num_heads_bias != num_heads) { at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2}); } From 443792eed41a1226c5d45cc82bef36e48a567e21 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 18 Sep 2025 16:54:11 +0800 Subject: [PATCH 11/32] Extends Flash attention to support mask and bias parameters Refactors template signatures across forward, backward, and split-KV kernels to include additional boolean parameters for mask and bias support. Updates all kernel instantiations to use expanded template parameters, maintaining compatibility with existing causal-only configurations while enabling new combinations of mask and bias features. Removes formatting inconsistencies in include statements and standardizes template parameter ordering across all instantiation files. --- .../flash_bwd_hdim128_bf16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...wd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim128_bf16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim128_bf16_causal_sm80.cu | 7 +++---- .../flash_bwd_hdim128_bf16_has_bias_sm80.cu | 13 +++++++++++++ ...flash_bwd_hdim128_bf16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim128_bf16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_bwd_hdim128_bf16_sm80.cu | 7 +++---- .../flash_bwd_hdim128_fp16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...wd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim128_fp16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim128_fp16_causal_sm80.cu | 7 +++---- .../flash_bwd_hdim128_fp16_has_bias_sm80.cu | 13 +++++++++++++ ...flash_bwd_hdim128_fp16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim128_fp16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_bwd_hdim128_fp16_sm80.cu | 7 +++---- .../flash_bwd_hdim192_bf16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...wd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim192_bf16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim192_bf16_causal_sm80.cu | 7 +++---- .../flash_bwd_hdim192_bf16_has_bias_sm80.cu | 13 +++++++++++++ ...flash_bwd_hdim192_bf16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim192_bf16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_bwd_hdim192_bf16_sm80.cu | 7 +++---- .../flash_bwd_hdim192_fp16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...wd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim192_fp16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim192_fp16_causal_sm80.cu | 7 +++---- .../flash_bwd_hdim192_fp16_has_bias_sm80.cu | 13 +++++++++++++ ...flash_bwd_hdim192_fp16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim192_fp16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_bwd_hdim192_fp16_sm80.cu | 7 +++---- .../flash_bwd_hdim256_bf16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...wd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim256_bf16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim256_bf16_causal_sm80.cu | 7 +++---- .../flash_bwd_hdim256_bf16_has_bias_sm80.cu | 13 +++++++++++++ ...flash_bwd_hdim256_bf16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim256_bf16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_bwd_hdim256_bf16_sm80.cu | 7 +++---- .../flash_bwd_hdim256_fp16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...wd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim256_fp16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim256_fp16_causal_sm80.cu | 7 +++---- .../flash_bwd_hdim256_fp16_has_bias_sm80.cu | 13 +++++++++++++ ...flash_bwd_hdim256_fp16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim256_fp16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_bwd_hdim256_fp16_sm80.cu | 7 +++---- .../flash_bwd_hdim32_bf16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...bwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim32_bf16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim32_bf16_causal_sm80.cu | 7 +++---- .../flash_bwd_hdim32_bf16_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim32_bf16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim32_bf16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_bwd_hdim32_bf16_sm80.cu | 7 +++---- .../flash_bwd_hdim32_fp16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...bwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim32_fp16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim32_fp16_causal_sm80.cu | 7 +++---- .../flash_bwd_hdim32_fp16_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim32_fp16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim32_fp16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_bwd_hdim32_fp16_sm80.cu | 7 +++---- .../flash_bwd_hdim64_bf16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...bwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim64_bf16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim64_bf16_causal_sm80.cu | 7 +++---- .../flash_bwd_hdim64_bf16_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim64_bf16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim64_bf16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_bwd_hdim64_bf16_sm80.cu | 7 +++---- .../flash_bwd_hdim64_fp16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...bwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim64_fp16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim64_fp16_causal_sm80.cu | 7 +++---- .../flash_bwd_hdim64_fp16_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim64_fp16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim64_fp16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_bwd_hdim64_fp16_sm80.cu | 7 +++---- .../flash_bwd_hdim96_bf16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...bwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim96_bf16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim96_bf16_causal_sm80.cu | 7 +++---- .../flash_bwd_hdim96_bf16_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim96_bf16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim96_bf16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_bwd_hdim96_bf16_sm80.cu | 7 +++---- .../flash_bwd_hdim96_fp16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...bwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim96_fp16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim96_fp16_causal_sm80.cu | 7 +++---- .../flash_bwd_hdim96_fp16_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim96_fp16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_bwd_hdim96_fp16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_bwd_hdim96_fp16_sm80.cu | 7 +++---- .../flash_fwd_hdim128_bf16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...wd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim128_bf16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim128_bf16_causal_sm80.cu | 7 +++---- .../flash_fwd_hdim128_bf16_has_bias_sm80.cu | 13 +++++++++++++ ...flash_fwd_hdim128_bf16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim128_bf16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_fwd_hdim128_bf16_sm80.cu | 7 +++---- .../flash_fwd_hdim128_fp16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...wd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim128_fp16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim128_fp16_causal_sm80.cu | 7 +++---- .../flash_fwd_hdim128_fp16_has_bias_sm80.cu | 13 +++++++++++++ ...flash_fwd_hdim128_fp16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim128_fp16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_fwd_hdim128_fp16_sm80.cu | 7 +++---- .../flash_fwd_hdim192_bf16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...wd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim192_bf16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim192_bf16_causal_sm80.cu | 7 +++---- .../flash_fwd_hdim192_bf16_has_bias_sm80.cu | 13 +++++++++++++ ...flash_fwd_hdim192_bf16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim192_bf16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_fwd_hdim192_bf16_sm80.cu | 7 +++---- .../flash_fwd_hdim192_fp16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...wd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim192_fp16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim192_fp16_causal_sm80.cu | 7 +++---- .../flash_fwd_hdim192_fp16_has_bias_sm80.cu | 13 +++++++++++++ ...flash_fwd_hdim192_fp16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim192_fp16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_fwd_hdim192_fp16_sm80.cu | 7 +++---- .../flash_fwd_hdim256_bf16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...wd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim256_bf16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim256_bf16_causal_sm80.cu | 7 +++---- .../flash_fwd_hdim256_bf16_has_bias_sm80.cu | 13 +++++++++++++ ...flash_fwd_hdim256_bf16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim256_bf16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_fwd_hdim256_bf16_sm80.cu | 7 +++---- .../flash_fwd_hdim256_fp16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...wd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim256_fp16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim256_fp16_causal_sm80.cu | 7 +++---- .../flash_fwd_hdim256_fp16_has_bias_sm80.cu | 13 +++++++++++++ ...flash_fwd_hdim256_fp16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim256_fp16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_fwd_hdim256_fp16_sm80.cu | 7 +++---- .../flash_fwd_hdim32_bf16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...fwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim32_bf16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim32_bf16_causal_sm80.cu | 7 +++---- .../flash_fwd_hdim32_bf16_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim32_bf16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim32_bf16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_fwd_hdim32_bf16_sm80.cu | 7 +++---- .../flash_fwd_hdim32_fp16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...fwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim32_fp16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim32_fp16_causal_sm80.cu | 7 +++---- .../flash_fwd_hdim32_fp16_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim32_fp16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim32_fp16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_fwd_hdim32_fp16_sm80.cu | 7 +++---- .../flash_fwd_hdim64_bf16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...fwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim64_bf16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim64_bf16_causal_sm80.cu | 7 +++---- .../flash_fwd_hdim64_bf16_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim64_bf16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim64_bf16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_fwd_hdim64_bf16_sm80.cu | 7 +++---- .../flash_fwd_hdim64_fp16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...fwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim64_fp16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim64_fp16_causal_sm80.cu | 7 +++---- .../flash_fwd_hdim64_fp16_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim64_fp16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim64_fp16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_fwd_hdim64_fp16_sm80.cu | 7 +++---- .../flash_fwd_hdim96_bf16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...fwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim96_bf16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim96_bf16_causal_sm80.cu | 7 +++---- .../flash_fwd_hdim96_bf16_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim96_bf16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim96_bf16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_fwd_hdim96_bf16_sm80.cu | 7 +++---- .../flash_fwd_hdim96_fp16_causal_has_bias_sm80.cu | 13 +++++++++++++ ...fwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim96_fp16_causal_has_mask_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim96_fp16_causal_sm80.cu | 7 +++---- .../flash_fwd_hdim96_fp16_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim96_fp16_has_mask_has_bias_sm80.cu | 13 +++++++++++++ .../flash_fwd_hdim96_fp16_has_mask_sm80.cu | 13 +++++++++++++ .../instantiations/flash_fwd_hdim96_fp16_sm80.cu | 7 +++---- ...h_fwd_split_hdim128_bf16_causal_has_bias_sm80.cu | 10 ++++++++++ ...it_hdim128_bf16_causal_has_mask_has_bias_sm80.cu | 10 ++++++++++ ...h_fwd_split_hdim128_bf16_causal_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim128_bf16_causal_sm80.cu | 5 ++--- .../flash_fwd_split_hdim128_bf16_has_bias_sm80.cu | 10 ++++++++++ ...fwd_split_hdim128_bf16_has_mask_has_bias_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim128_bf16_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim128_bf16_sm80.cu | 5 ++--- ...h_fwd_split_hdim128_fp16_causal_has_bias_sm80.cu | 10 ++++++++++ ...it_hdim128_fp16_causal_has_mask_has_bias_sm80.cu | 10 ++++++++++ ...h_fwd_split_hdim128_fp16_causal_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim128_fp16_causal_sm80.cu | 5 ++--- .../flash_fwd_split_hdim128_fp16_has_bias_sm80.cu | 10 ++++++++++ ...fwd_split_hdim128_fp16_has_mask_has_bias_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim128_fp16_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim128_fp16_sm80.cu | 5 ++--- ...h_fwd_split_hdim192_bf16_causal_has_bias_sm80.cu | 10 ++++++++++ ...it_hdim192_bf16_causal_has_mask_has_bias_sm80.cu | 10 ++++++++++ ...h_fwd_split_hdim192_bf16_causal_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim192_bf16_causal_sm80.cu | 5 ++--- .../flash_fwd_split_hdim192_bf16_has_bias_sm80.cu | 10 ++++++++++ ...fwd_split_hdim192_bf16_has_mask_has_bias_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim192_bf16_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim192_bf16_sm80.cu | 5 ++--- ...h_fwd_split_hdim192_fp16_causal_has_bias_sm80.cu | 10 ++++++++++ ...it_hdim192_fp16_causal_has_mask_has_bias_sm80.cu | 10 ++++++++++ ...h_fwd_split_hdim192_fp16_causal_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim192_fp16_causal_sm80.cu | 5 ++--- .../flash_fwd_split_hdim192_fp16_has_bias_sm80.cu | 10 ++++++++++ ...fwd_split_hdim192_fp16_has_mask_has_bias_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim192_fp16_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim192_fp16_sm80.cu | 5 ++--- ...h_fwd_split_hdim256_bf16_causal_has_bias_sm80.cu | 10 ++++++++++ ...it_hdim256_bf16_causal_has_mask_has_bias_sm80.cu | 10 ++++++++++ ...h_fwd_split_hdim256_bf16_causal_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim256_bf16_causal_sm80.cu | 5 ++--- .../flash_fwd_split_hdim256_bf16_has_bias_sm80.cu | 10 ++++++++++ ...fwd_split_hdim256_bf16_has_mask_has_bias_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim256_bf16_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim256_bf16_sm80.cu | 5 ++--- ...h_fwd_split_hdim256_fp16_causal_has_bias_sm80.cu | 10 ++++++++++ ...it_hdim256_fp16_causal_has_mask_has_bias_sm80.cu | 10 ++++++++++ ...h_fwd_split_hdim256_fp16_causal_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim256_fp16_causal_sm80.cu | 5 ++--- .../flash_fwd_split_hdim256_fp16_has_bias_sm80.cu | 10 ++++++++++ ...fwd_split_hdim256_fp16_has_mask_has_bias_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim256_fp16_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim256_fp16_sm80.cu | 5 ++--- ...sh_fwd_split_hdim32_bf16_causal_has_bias_sm80.cu | 10 ++++++++++ ...lit_hdim32_bf16_causal_has_mask_has_bias_sm80.cu | 10 ++++++++++ ...sh_fwd_split_hdim32_bf16_causal_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim32_bf16_causal_sm80.cu | 5 ++--- .../flash_fwd_split_hdim32_bf16_has_bias_sm80.cu | 10 ++++++++++ ..._fwd_split_hdim32_bf16_has_mask_has_bias_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim32_bf16_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim32_bf16_sm80.cu | 5 ++--- ...sh_fwd_split_hdim32_fp16_causal_has_bias_sm80.cu | 10 ++++++++++ ...lit_hdim32_fp16_causal_has_mask_has_bias_sm80.cu | 10 ++++++++++ ...sh_fwd_split_hdim32_fp16_causal_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim32_fp16_causal_sm80.cu | 5 ++--- .../flash_fwd_split_hdim32_fp16_has_bias_sm80.cu | 10 ++++++++++ ..._fwd_split_hdim32_fp16_has_mask_has_bias_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim32_fp16_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim32_fp16_sm80.cu | 5 ++--- ...sh_fwd_split_hdim64_bf16_causal_has_bias_sm80.cu | 10 ++++++++++ ...lit_hdim64_bf16_causal_has_mask_has_bias_sm80.cu | 10 ++++++++++ ...sh_fwd_split_hdim64_bf16_causal_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim64_bf16_causal_sm80.cu | 5 ++--- .../flash_fwd_split_hdim64_bf16_has_bias_sm80.cu | 10 ++++++++++ ..._fwd_split_hdim64_bf16_has_mask_has_bias_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim64_bf16_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim64_bf16_sm80.cu | 5 ++--- ...sh_fwd_split_hdim64_fp16_causal_has_bias_sm80.cu | 10 ++++++++++ ...lit_hdim64_fp16_causal_has_mask_has_bias_sm80.cu | 10 ++++++++++ ...sh_fwd_split_hdim64_fp16_causal_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim64_fp16_causal_sm80.cu | 5 ++--- .../flash_fwd_split_hdim64_fp16_has_bias_sm80.cu | 10 ++++++++++ ..._fwd_split_hdim64_fp16_has_mask_has_bias_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim64_fp16_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim64_fp16_sm80.cu | 5 ++--- ...sh_fwd_split_hdim96_bf16_causal_has_bias_sm80.cu | 10 ++++++++++ ...lit_hdim96_bf16_causal_has_mask_has_bias_sm80.cu | 10 ++++++++++ ...sh_fwd_split_hdim96_bf16_causal_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim96_bf16_causal_sm80.cu | 5 ++--- .../flash_fwd_split_hdim96_bf16_has_bias_sm80.cu | 10 ++++++++++ ..._fwd_split_hdim96_bf16_has_mask_has_bias_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim96_bf16_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim96_bf16_sm80.cu | 5 ++--- ...sh_fwd_split_hdim96_fp16_causal_has_bias_sm80.cu | 10 ++++++++++ ...lit_hdim96_fp16_causal_has_mask_has_bias_sm80.cu | 10 ++++++++++ ...sh_fwd_split_hdim96_fp16_causal_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim96_fp16_causal_sm80.cu | 5 ++--- .../flash_fwd_split_hdim96_fp16_has_bias_sm80.cu | 10 ++++++++++ ..._fwd_split_hdim96_fp16_has_mask_has_bias_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim96_fp16_has_mask_sm80.cu | 10 ++++++++++ .../flash_fwd_split_hdim96_fp16_sm80.cu | 5 ++--- 288 files changed, 2784 insertions(+), 264 deletions(-) create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_has_bias_sm80.cu create mode 100644 csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_sm80.cu diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..5e0883e --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..31805fb --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..3851a70 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu index 581443c..5f928d3 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..1e148e9 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..5ac4e19 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..a753d58 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu index eae9439..d1f0c9f 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..346c75e --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..97c2016 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..b97517c --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu index ad5d580..282c744 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..2ca811c --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..5c8b491 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..672864c --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu index 8d93480..1a2a991 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..cd8fdf3 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..d9625d7 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..748ce7a --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu index cee3073..1e3ed61 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..d159662 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..91fddab --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..5dc1bc9 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu index bb5063b..0868295 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..38d1dc9 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..2de0e95 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..4358751 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu index 290187a..de93c33 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..71d6a6e --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..d39794d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..b19f9e7 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu index 7bef7a3..3825325 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..be3e201 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..e35bf30 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..852286c --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu index a28c41c..0a54ab1 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..6d27638 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..7896955 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..7cb82c6 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu index 114faf8..3c41f3c 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..05644d3 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..07713a8 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..c64b79e --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu index a89a253..1120012 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..4a4d8aa --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..84f8f1f --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..b6e0e97 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu index db59281..84d4bfb 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..ec7f221 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..6be8d6b --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..f5b0daa --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu index a0cde4f..73c7a18 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim32(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..0a757ff --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..14717b5 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..e01a354 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu index 29fb6f8..b82b8b1 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim32(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..6b7369b --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..ba1236b --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..db9b308 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu index 885067d..c49252e 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim32(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..8a4022f --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..6f5537a --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..be0e3c4 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu index 2c35ae2..f2c9f4f 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim32(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..36a39e9 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..4e735d0 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..c332c96 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu index 5dc4084..502c429 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..367a798 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..000cf37 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..4f1cd88 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu index f410065..d959575 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..4c43cd4 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..662df9b --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..d5be139 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu index d58a0f6..a5a68b7 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..45eb04e --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..0e47df2 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..b082ab7 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu index 47c01c2..2a3f1ee 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..717c2d1 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..0d64a82 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..3296e56 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu index 4c27c0b..1276b88 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..2d30d81 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..f90667d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..15bab62 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu index da96a97..75ba3ee 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..ed84e18 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..ce4da33 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..626aac8 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu index 8e79520..7a2a20e 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..80fec3d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..bc85d1d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..24ef408 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu index e1ac513..1a73dba 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..137a44d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..3126e9f --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..5ceab71 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu index 2bacfd3..e6a99f1 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim128(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..6f00490 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..854c601 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..c137e94 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu index bc6f103..224dd9e 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim128(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..b7afdcb --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..0e8dd9a --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..c3d2cda --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu index 48c2d89..14b0091 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim128(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..ef80d33 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..c749384 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..82ec511 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu index c67fdce..1c75613 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim128(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..e78791d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..05ac29c --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..a8c060a --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu index e957ce0..457c6ae 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim192(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..3467402 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..63172d0 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..57d7b81 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu index f53f7f5..5e67bc9 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim192(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..35bdefd --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..4b3eaa7 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..42f64ee --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu index 805e548..814944a 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim192(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..55c5e60 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..eea672d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..e4e0688 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu index 87a6565..3185d87 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim192(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..78d4313 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..c7e1c19 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..d67103b --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu index ded9ab1..06d28b4 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim256(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..8df713d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..abae76d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..a66fb04 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu index c3d9e04..c6bc287 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim256(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..1b00dd1 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..66ed354 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..6ae0a42 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu index 8780ed4..1b6ad94 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim256(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..6d17fa8 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..fcfe162 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..42bebc8 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu index 293e001..48d732e 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim256(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..7ca3f9a --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..ab5a247 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..5b5df80 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu index f50910b..d94db09 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim32(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..bb16aed --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..f31d044 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..69fb705 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu index 6127386..7c4d01c 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim32(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..664d742 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..6084e18 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..9228bad --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu index 845ab35..5bb2263 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim32(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..9bd62f6 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..0c8c28e --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..b91ecd2 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu index 94a1301..0fe7203 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim32(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..71277e6 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..a54a945 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..6a40c9a --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu index f3d4e94..06aa7f2 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim64(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..b618f90 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..9b2dec0 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..3b65d53 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu index a11cc79..b48cca5 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim64(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..708ddce --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..c935579 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..71e686c --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu index 1681867..4b69c71 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim64(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..ca63d55 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..508e38f --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..61b9621 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu index 7b7d184..2ce28f2 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim64(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..c2ad7c2 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..d2f46b9 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..b605b86 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu index ad62e87..2425d20 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim96(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..6b34039 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..bd6ee43 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..663ba2b --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu index 92c4344..fca0d6b 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim96(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..a8dc858 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..ede9c0d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..4eb2c7f --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu index 36db0da..04ceeff 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim96(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..4b98bf4 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..1417fa2 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..cd07586 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu index 50040b5..be4617c 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim96(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..5c31681 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..0344fe1 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..f955761 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu index b4118c7..e53a721 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..9bc8f83 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..7cd1336 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..9a24313 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu index cdcfbe9..20e665b 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..198fc1d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..08a187a --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..f8bca05 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu index 71e415d..69f7562 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..fd305fe --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..9a144f5 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..3cb985a --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu index df4febe..2787367 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..c2ec067 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..82732fc --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..4155c6f --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu index 83c8f8a..b36962b 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..307a1a1 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..ae704e7 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..e1fe8a5 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu index d3bbf47..71119c0 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..2adbadc --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..154a54b --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..ccf5bc3 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu index 5652982..18316d4 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..fd441c9 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..a56610d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..54430a4 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu index edac2b6..aca36b9 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..dd3c5ed --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..eef7f06 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..6b15c57 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu index 28ab7ad..6303b8c 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..28305b9 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..4e98cb7 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..e60cd70 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu index 751035e..5bbac09 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..63a0e75 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..ba80503 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..0c5168d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu index 502b5cc..aa42659 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..5a70fbc --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..6d539d1 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..008478a --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu index 3153e17..625ba34 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..2cfd687 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..3eb807c --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..8b36fb4 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu index 9910f63..8f461b2 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..e144b49 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..3dd585f --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..6a2eb0b --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu index d498fea..8f6e82d 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..07c8228 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..2e0b252 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..2772cca --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu index a5a713a..f48c612 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..1adfa91 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..ee443da --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..d4d4b85 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu index 4cfc36a..a9ac55e 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..f25e3d3 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..574bc25 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..c0e431e --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu index e89cb76..cca29c5 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..9c676d0 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..2797f2d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..83d5a04 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu index 8d72e93..48ff397 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..6353a34 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..8871c82 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..3d54aad --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu index 76ba0c8..d030ef0 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..72d4513 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..2b7be7e --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..35c2845 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu index ab07719..4211c14 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..0989b27 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..b09b986 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..6bf4d32 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu index 8d44e28..7469a22 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..33381b0 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..7f5d4c0 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..4d52f62 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu index 252b468..787308d 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..2e49991 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..e6c856c --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..b81d1a9 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu index 3eb97b7..1d3a8af 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..68d3a41 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..f0eeead --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..e3530db --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu index 367e12b..fb7d89c 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file From 8a6e78c5b22806e2fb7bbc48eb36db80416166b0 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 18 Sep 2025 16:55:01 +0800 Subject: [PATCH 12/32] Adds mask and bias support to kernel generation Extends the flash attention kernel generator to support additional template parameters for masking and bias operations. Updates all kernel templates to include HAS_MASK and HAS_BIAS parameters, allowing for more flexible attention implementations with optional masking and bias addition. Modifies the kernel filename generation to include mask and bias flags for better organization and identification of generated kernel variants. Changes the default output directory to a more structured path and ensures directory creation before writing files. --- csrc/flash_dmattn/src/generate_kernels.py | 28 +++++++++++++++-------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/csrc/flash_dmattn/src/generate_kernels.py b/csrc/flash_dmattn/src/generate_kernels.py index 00cfc3b..54d5a72 100644 --- a/csrc/flash_dmattn/src/generate_kernels.py +++ b/csrc/flash_dmattn/src/generate_kernels.py @@ -12,6 +12,8 @@ SM = [80] # Sm80 kernels support up to HEAD_DIMENSIONS = [32, 64, 96, 128, 192, 256] IS_CAUSAL = ["false", "true"] +HAS_MASK = ["false", "true"] +HAS_BIAS = ["false", "true"] NAMESPACE_INCLUDE = '#include "namespace_config.h"\n' def get_fwd_template() -> str: @@ -21,8 +23,8 @@ def get_fwd_template() -> str: namespace FLASH_NAMESPACE {{ template<> -void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{ - run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); +void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}, {HAS_MASK}, {HAS_BIAS}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{ + run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}, {HAS_MASK}, {HAS_BIAS}>(params, stream); }} }} // namespace FLASH_NAMESPACE @@ -34,7 +36,7 @@ def get_fwd_split_template() -> str: namespace FLASH_NAMESPACE {{ -template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}, {HAS_MASK}, {HAS_BIAS}>(Flash_fwd_params ¶ms, cudaStream_t stream); }} // namespace FLASH_NAMESPACE """.strip() @@ -46,8 +48,8 @@ def get_bwd_template() -> str: namespace FLASH_NAMESPACE {{ template<> -void run_mha_bwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{ - run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); +void run_mha_bwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}, {HAS_MASK}, {HAS_BIAS}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{ + run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}, {HAS_MASK}, {HAS_BIAS}>(params, stream); }} }} // namespace FLASH_NAMESPACE @@ -59,6 +61,8 @@ class Kernel: dtype: str head_dim: int is_causal: str + has_mask: str + has_bias: str direction: str @property @@ -72,17 +76,19 @@ def template(self) -> str: return template_func().format( DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, - IS_CAUSAL=self.is_causal + IS_CAUSAL=self.is_causal, + HAS_MASK=self.has_mask, + HAS_BIAS=self.has_bias ) @property def filename(self) -> str: - return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu" + return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}{'has_mask_' if self.has_mask == 'true' else ''}{'has_bias_' if self.has_bias == 'true' else ''}sm{self.sm}.cu" def get_all_kernels() -> Generator[Kernel, None, None]: for direction in ["fwd", "fwd_split", "bwd"]: - for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM): - yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction) + for dtype, head_dim, is_causal, has_mask, has_bias, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, HAS_MASK, HAS_BIAS, SM): + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, has_mask=has_mask, has_bias=has_bias, direction=direction) def write_kernel(kernel: Kernel, autogen_dir: Path) -> None: prelude = """ @@ -99,6 +105,8 @@ def main(output_dir: Optional[str]) -> None: else: output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + for kernel in get_all_kernels(): write_kernel(kernel, output_dir) @@ -110,7 +118,7 @@ def main(output_dir: Optional[str]) -> None: parser.add_argument( "-o", "--output_dir", - default="instantiations", + default="src/instantiations", required=False, help="Where to generate the kernels " " will default to the current directory ", From b3b6f80210334a42bba54249ca7b71716d501bec Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 18 Sep 2025 16:55:40 +0800 Subject: [PATCH 13/32] Adds mask and bias support to MHA function templates Extends forward and backward multi-head attention function templates with Has_mask and Has_bias template parameters to enable conditional mask and bias functionality during attention computation. --- csrc/flash_dmattn/src/flash.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/flash_dmattn/src/flash.h b/csrc/flash_dmattn/src/flash.h index d456de1..a1c9bf1 100644 --- a/csrc/flash_dmattn/src/flash.h +++ b/csrc/flash_dmattn/src/flash.h @@ -199,9 +199,9 @@ struct Flash_bwd_params : public Flash_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// -template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); +template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE From aa36f574217ca5bf6681914fde50752d0560776f Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 18 Sep 2025 16:56:15 +0800 Subject: [PATCH 14/32] Adds mask and bias support to attention kernels Extends both forward and backward multi-head attention kernels to support additional mask and bias parameters through new template arguments. Enhances kernel flexibility by allowing attention mechanisms to handle custom masking patterns and bias terms beyond just causal masking. --- csrc/flash_dmattn/flash_api.cpp | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index b375268..fd20920 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -262,14 +262,18 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split FP16_SWITCH(!params.is_bf16, [&] { HEADDIM_SWITCH(params.d, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // splitkv kernel is not supported for head_dim >= 128 in sm89 due to smem limits - bool splitkv_forbidden = (kHeadDim >= 128) && (max_smem_per_block < 112 * 1024); - params.num_splits = splitkv_forbidden ? 1 : params.num_splits; - if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 - run_mha_fwd_(params, stream); - } else { - run_mha_fwd_splitkv_dispatch(params, stream); - } + BOOL_SWITCH(params.has_mask, Has_mask, [&] { + BOOL_SWITCH(params.has_bias, Has_bias, [&] { + // splitkv kernel is not supported for head_dim >= 128 in sm89 due to smem limits + bool splitkv_forbidden = (kHeadDim >= 128) && (max_smem_per_block < 112 * 1024); + params.num_splits = splitkv_forbidden ? 1 : params.num_splits; + if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_splitkv_dispatch(params, stream); + } + }); + }); }); }); }); @@ -799,7 +803,11 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { FP16_SWITCH(!params.is_bf16, [&] { HEADDIM_SWITCH(params.d, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { - run_mha_bwd_(params, stream); + BOOL_SWITCH(params.has_mask, Has_mask, [&] { + BOOL_SWITCH(params.has_bias, Has_bias, [&] { + run_mha_bwd_(params, stream); + }); + }); }); }); }); From 1e51290c69b0aee733a6a0da17da1a40ab81e88f Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 18 Sep 2025 16:57:56 +0800 Subject: [PATCH 15/32] Adds compile-time masking and bias optimizations Introduces template parameters to conditionally compile mask and bias operations, enabling performance optimizations when these features are not needed. Replaces runtime checks with constexpr conditions to eliminate unnecessary computations and memory accesses when mask or bias are disabled. Improves mask condition logic from float comparison to boolean evaluation for more reliable masking behavior. --- csrc/flash_dmattn/src/mask.h | 241 ++++++++++++++++++++++++++++------- 1 file changed, 196 insertions(+), 45 deletions(-) diff --git a/csrc/flash_dmattn/src/mask.h b/csrc/flash_dmattn/src/mask.h index ca4f43b..a1f6c72 100644 --- a/csrc/flash_dmattn/src/mask.h +++ b/csrc/flash_dmattn/src/mask.h @@ -11,7 +11,7 @@ namespace FLASH_NAMESPACE { using namespace cute; -template +template __forceinline__ __device__ void apply_mask( TensorType &tensor, MaskType &mask, @@ -25,29 +25,107 @@ __forceinline__ __device__ void apply_mask( ) { // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) static_assert(TensorType::rank == 2, "Only support 2D Tensor"); - static_assert(MaskType::rank == 2, "Only support 2D Mask"); - static_assert(BiasType::rank == 2, "Only support 2D Bias"); + if constexpr (Has_mask) + static_assert(MaskType::rank == 2, "Only support 2D Mask"); + if constexpr (Has_bias) + static_assert(BiasType::rank == 2, "Only support 2D Bias"); + const int lane_id = threadIdx.x % 32; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - #pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; + + if constexpr (Has_mask && Has_bias) { + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + // Without the "make_coord" we get wrong results + auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); + // Apply scaling and bias or masking + tensor(coord) = (col_idx >= col_idx_limit) || (!mask(coord)) + ? -INFINITY + : tensor(coord) * scale_softmax + bias(coord); + } + } + } + } + } else if constexpr (Has_mask && !Has_bias) { + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + // Without the "make_coord" we get wrong results + auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); + // Apply scaling or masking + tensor(coord) = (col_idx >= col_idx_limit) || (!mask(coord)) + ? -INFINITY + : tensor(coord) * scale_softmax; + } + } + } + } + } else if constexpr (!Has_mask && Has_bias) { + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + // Without the "make_coord" we get wrong results + auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); + // Apply scaling and bias + tensor(coord) = (col_idx >= col_idx_limit) + ? -INFINITY + : tensor(coord) * scale_softmax + bias(coord); + } + } + } + } + } else { // !Has_mask && !Has_bias #pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; - const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - // Without the "make_coord" we get wrong results - auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); - // Apply scaling and bias or masking - tensor(coord) = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f) - ? -INFINITY - : tensor(coord) * scale_softmax + bias(coord); + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + // Without the "make_coord" we get wrong results + auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); + // Apply scaling + tensor(coord) = (col_idx >= col_idx_limit) + ? -INFINITY + : tensor(coord) * scale_softmax; + } } } } @@ -65,53 +143,126 @@ struct Mask { , max_seqlen_q(max_seqlen_q) { }; - template + template __forceinline__ __device__ void apply_mask( TensorType &tensor_, // acc_s (attention scores, MMA=4, MMA_M, MMA_N) - MaskType &tSrMask, // Attention Mask (MMA=4, MMA_M, MMA_N) - BiasType &tSrBias, // Attention Bias (MMA=4, MMA_M, MMA_N) + const MaskType &mask_, // Attention Mask (MMA=4, MMA_M, MMA_N) + const BiasType &bias_, // Attention Bias (MMA=4, MMA_M, MMA_N) const float scale_softmax, // Scale for softmax const int col_idx_offset_, // Column index offset const int row_idx_offset, // Row index offset const int warp_row_stride // Warp row stride ) { static_assert(TensorType::rank == 3, "tensor_ must be 3D Tensor"); - static_assert(MaskType::rank == 3, "Mask must be 3D Tensor"); - static_assert(BiasType::rank == 3, "Bias must be 3D Tensor"); + if constexpr (Has_mask) + static_assert(MaskType::rank == 3, "mask_ must be 3D Tensor"); + if constexpr (Has_bias) + static_assert(BiasType::rank == 3, "Bias must be 3D Tensor"); static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); - // const bool Need_masking = Causal_mask || !Is_even_MN || (keep_window_size < max_seqlen_k); - // Reshape tensors from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout())); - Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout())); - Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout())); const int lane_id = threadIdx.x % 32; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - #pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; + + if constexpr (Has_mask && Has_bias) { + Tensor mask = make_tensor(mask_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(mask_.layout())); + Tensor bias = make_tensor(bias_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(bias_.layout())); #pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; - const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); - // Apply scaling and bias or masking - tensor(coord) = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f) - ? -INFINITY - : tensor(coord) * scale_softmax + bias(coord); + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); + // Apply scaling and bias or masking + tensor(coord) = (col_idx >= col_idx_limit) || (!mask(coord)) + ? -INFINITY + : tensor(coord) * scale_softmax + bias(coord); + } } } } - } - + } else if constexpr (Has_mask && !Has_bias) { + Tensor mask = make_tensor(mask_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(mask_.layout())); + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); + // Apply scaling or masking + tensor(coord) = (col_idx >= col_idx_limit) || (!mask(coord)) + ? -INFINITY + : tensor(coord) * scale_softmax; + } + } + } + } + } else if constexpr (!Has_mask && Has_bias) { + Tensor bias = make_tensor(bias_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(bias_.layout())); + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); + // Apply scaling and bias + tensor(coord) = (col_idx >= col_idx_limit) + ? -INFINITY + : tensor(coord) * scale_softmax + bias(coord); + } + } + } + } + } else { // !Has_mask && !Has_bias + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); + // Apply scaling + tensor(coord) = (col_idx >= col_idx_limit) + ? -INFINITY + : tensor(coord) * scale_softmax; + } + } + } + } + } } }; From fdda8b5cfb7a241e9cd30c025e2642f9118c2fb2 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 18 Sep 2025 17:33:47 +0800 Subject: [PATCH 16/32] Adds mask and bias support to flash attention kernels Extends kernel templates with Has_mask and Has_bias boolean parameters to enable attention masking and bias functionality. Updates all kernel function signatures and template instantiations to accommodate the new parameters while maintaining backward compatibility. Removes commented debug code and consolidates CUDA function attribute setting for improved code clarity. --- .../src/flash_fwd_launch_template.h | 91 +++++++++---------- 1 file changed, 41 insertions(+), 50 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_fwd_launch_template.h b/csrc/flash_dmattn/src/flash_fwd_launch_template.h index 2a7dd4a..55fa2ff 100644 --- a/csrc/flash_dmattn/src/flash_fwd_launch_template.h +++ b/csrc/flash_dmattn/src/flash_fwd_launch_template.h @@ -30,17 +30,17 @@ namespace FLASH_NAMESPACE { template \ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_attn(params); + FLASH_NAMESPACE::compute_attn(params); #else FLASH_UNSUPPORTED_ARCH #endif } -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split) { #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_attn_splitkv(params); + FLASH_NAMESPACE::compute_attn_splitkv(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -51,7 +51,7 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int L FLASH_NAMESPACE::combine_attn_seqk_parallel(params); } -template +template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const size_t smem_size = Kernel_traits::kSmemSize; // printf("smem_size = %d (includes mask memory)\n", int(smem_size)); @@ -72,13 +72,9 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If return_softmax, set IsEvenMNConst to false to reduce number of templates // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - // printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst)); - // auto kernel = &flash_fwd_kernel; + auto kernel = &flash_fwd_kernel; if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } // int ctas_per_sm; // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( @@ -92,7 +88,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { }); } -template +template void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); @@ -106,14 +102,9 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { BOOL_SWITCH(params.num_splits > 1, Split, [&] { SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - // printf("run_flash_splitkv_fwd: Split = %d, Is_causal = %d, IsEvenMNConst = %d, IsEvenKConst = %d, Is_softcap = %d\n", int(Split), int(Is_causal), int(IsEvenMNConst), int(IsEvenKConst), int(Is_softcap)); + auto kernel = &flash_fwd_splitkv_kernel; if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } // int ctas_per_sm; // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( @@ -152,15 +143,15 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } } -template +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int kBlockM = 64; // Fixed for all head dimensions constexpr static int kBlockN = 64; // Fixed for all head dimensions // constexpr static int kBlockN = Headdim <= 32 ? 128 : (Headdim <= 128 ? 128 : 64); - run_flash_splitkv_fwd, Is_causal>(params, stream); + run_flash_splitkv_fwd, Is_causal, Has_mask, Has_bias>(params, stream); } -template +template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; int device; @@ -174,19 +165,19 @@ void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { } if (max_smem_per_block >= 164 * 1024) { // 28KB, 3 CTAs in sm86 and sm 89, 5 CTAs in A100, 8 CTAs in H100. - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); // 48KB, 2 CTAs in sm86 and sm 89, 3 CTAs in A100, 4 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); // 88KB, 1 CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); } else { // 24KB, 4 CTAs in sm86 and sm 89. - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); } } -template +template void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; int device; @@ -200,19 +191,19 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { } if (max_smem_per_block >= 164 * 1024) { // H100 and A100 // 40KB, 2 CTAs in sm86 and sm 89, 4 CTAs in A100, 5 CTAs in H100. - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); // 64KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); // 112KB, N/A in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); } else { // sm86 and sm89 // 32KB, 3 CTAs in sm86 and sm 89. - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); } } -template +template void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 96; int device; @@ -226,18 +217,18 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { } if (max_smem_per_block >= 164 * 1024) { // H100 and A100 // 52KB, 1 CTAs in sm86 and sm 89, 3 CTAs in A100, 4 CTAs in H100. - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); // 80KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 2 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); // 136KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); } else { // sm86 and sm89 // 40KB, 2 CTAs in sm86 and sm 89. - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); } } -template +template void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; int device; @@ -251,29 +242,29 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { } if (max_smem_per_block >= 164 * 1024) { // H100 and A100 // 64KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100. - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); // 96KB, 1 CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); // 160KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); } else { // sm86 and sm89 // 48KB, 2 CTAs in sm86 and sm 89. - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); } } -template +template void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; // 88KB, 1 CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); // 128KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); // 208KB, N/A CTAs in sm86 and sm 89, N/A CTAs in A100, 1 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); } -template +template void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 256; int device; @@ -287,14 +278,14 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { } if (max_smem_per_block >= 112 * 1024) { // H100 and A100 // 112KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); // 192KB, N/A CTAs in sm86 and sm 89, N/A CTAs in A100, 1 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); // 256KB, N/A CTAs in sm86 and sm 89, N/A CTAs in A100, N/A CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); } else { // sm86 and sm89 // 80KB, 1 CTAs in sm86 and sm 89. - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); } } From c8be594a2a0619d770476d7dd43553705c0b1a8e Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 18 Sep 2025 17:37:02 +0800 Subject: [PATCH 17/32] Adds mask and bias support to flash attention backward kernels Extends backward kernel templates with Has_mask and Has_bias parameters to enable attention masking and bias functionality during gradient computation. Updates all kernel instantiations and function signatures to propagate the new template parameters through the call chain, maintaining consistency across different head dimensions and device configurations. Includes minor code formatting improvements for better readability. --- .../src/flash_bwd_launch_template.h | 85 ++++++++++--------- 1 file changed, 44 insertions(+), 41 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_bwd_launch_template.h b/csrc/flash_dmattn/src/flash_bwd_launch_template.h index 06a46fd..00712b8 100644 --- a/csrc/flash_dmattn/src/flash_bwd_launch_template.h +++ b/csrc/flash_dmattn/src/flash_bwd_launch_template.h @@ -31,17 +31,17 @@ namespace FLASH_NAMESPACE { template \ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params) -DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_causal, bool Is_even_M, bool Is_even_K) { +DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_M, bool Is_even_K) { #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_dq_dk_dv(params); + FLASH_NAMESPACE::compute_dq_dk_dv(params); #else FLASH_UNSUPPORTED_ARCH #endif } -DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap) { +DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap) { #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel(params); + FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -68,7 +68,7 @@ __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) { FLASH_NAMESPACE::convert_dKV(params); } -template +template void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) { const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; dim3 grid_m(num_m_block, params.b, params.h); @@ -98,11 +98,9 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - if (smem_size_dq_dk_dv >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + if (smem_size_dq_dk_dv >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); } kernel<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -112,146 +110,151 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) auto kernel_dq = &flash_bwd_convert_dq_kernel; if (Kernel_traits::kSmemdQSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize)); + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize)); } kernel_dq<<>>(params, !params.deterministic ? 1 : gridDimx); C10_CUDA_KERNEL_LAUNCH_CHECK(); } -template +template void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { #ifndef FLASHATTENTION_DISABLE_BACKWARD - run_flash_bwd_seqk_parallel(params, stream); + run_flash_bwd_seqk_parallel(params, stream); #endif } -template +template void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; int device; cudaGetDevice(&device); int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device + ); if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } if (max_smem_per_block >= 104 * 1024) { // H100 and A100 // 104KB, 1 CTAs in A100, 2 CTAs in H100. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } else { // sm86 and sm89 // 96KB, 1 CTAs in sm86 and sm 89. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } } -template +template void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; int device; cudaGetDevice(&device); int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device + ); if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } if (max_smem_per_block >= 144 * 1024) { // H100 and A100 // In fwd, multi-CTA configurations are faster, but in bwd, their speeds are very close. // 56KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 4 CTAs in H100. - // run_flash_bwd, Is_causal>(params, stream); + // run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); // 72KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100. - // run_flash_bwd, Is_causal>(params, stream); + // run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); // 144KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } else { // sm86 and sm89 // 88KB, 1 CTAs in sm86 and sm 89. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times } -template +template void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 96; int device; cudaGetDevice(&device); int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device + ); if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } if (max_smem_per_block >= 116 * 1024) { // H100 and A100 // 116KB, 1 CTAs in A100, 1 CTAs in H100. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } else { // sm86 and sm89 // 76KB, 1 CTAs in sm86 and sm 89. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } } -template +template void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; int device; cudaGetDevice(&device); int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device + ); if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } if (max_smem_per_block >= 144 * 1024) { // H100 and A100 // 144KB, 1 CTAs in A100, 1 CTAs in H100. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } else { // sm86 and sm89 // 80KB, 1 CTAs in sm86 and sm 89. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } } -template +template void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; int device; cudaGetDevice(&device); int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device + ); if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } if (max_smem_per_block >= 136 * 1024) { // H100 and A100 // 136KB, 1 CTAs in A100, 1 CTAs in H100. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } else { // sm86 and sm89 // 96KB, 1 CTAs in sm86 and sm 89. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } } -template +template void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 256; int device; cudaGetDevice(&device); int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device + ); if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } if (max_smem_per_block >= 176 * 1024) { // H100 // 176KB, 1 CTAs in H100. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } else if (max_smem_per_block >= 144 * 1024) { // A100 // 144KB, 1 CTAs in A100. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } else { // sm86 and sm89 // 96KB, 1 CTAs in sm86 and sm 89. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } } From da8d7ed5227f9ebfeee03ab8b8cf49ccf6df50cc Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 18 Sep 2025 17:37:37 +0800 Subject: [PATCH 18/32] Adds mask and bias template parameters to backward kernels Extends the template parameter list to include Has_mask and Has_bias flags for better flexibility in handling attention mechanisms with masks and biases. Updates all function calls to pass through the new template parameters while maintaining backward compatibility with existing functionality. --- csrc/flash_dmattn/src/flash_bwd_kernel.h | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h index 05ec18f..07b8aab 100644 --- a/csrc/flash_dmattn/src/flash_bwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_bwd_kernel.h @@ -76,7 +76,7 @@ CUTE_HOST_DEVICE auto make_tiled_copy_C_warpcontiguousN( //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { using Element = typename Kernel_traits::Element; @@ -1069,7 +1069,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv(const Params ¶ms) { // The block index for the batch. @@ -1083,20 +1083,20 @@ inline __device__ void compute_dq_dk_dv(const Params ¶ms) { const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; if (n_block_max == 1) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); + compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); } else { // Iterating backward from n_block_max - 1 to 0 might save 1 register - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block_max - 1); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block_max - 1); for (int n_block = n_block_max - 2; n_block > 0; n_block--) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); } - compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); + compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // The block index for the batch. @@ -1106,7 +1106,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); } } From 8ed228db43047925705e375f30f23b438ed9de6e Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 18 Sep 2025 17:40:01 +0800 Subject: [PATCH 19/32] Adds mask and bias template parameters to attention kernels Extends kernel templates with Has_mask and Has_bias boolean parameters to support attention masking and bias operations. Updates all affected function signatures and call sites to maintain consistency across the attention computation pipeline. --- csrc/flash_dmattn/src/flash_fwd_kernel.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index 09e031d..f51c4f3 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -50,7 +50,7 @@ __forceinline__ __device__ auto get_lse_tile( return local_tile(mLSE_slice, Shape>{}, make_coord(m_block)); } -template +template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -762,7 +762,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { using Element = typename Kernel_traits::Element; @@ -1492,7 +1492,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1500,12 +1500,12 @@ inline __device__ void compute_attn(const Params ¶ms) { // The block index for the head. const int bidh = blockIdx.z; - FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); + FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1514,7 +1514,7 @@ inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? gridDim.y : 1; - FLASH_NAMESPACE::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + FLASH_NAMESPACE::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// From f3914322bfed788ec2a3539a71691ae02d71e007 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 18 Sep 2025 23:19:38 +0800 Subject: [PATCH 20/32] Adds support for 3D mask and bias tensors in attention Extends mask and bias tensor handling to accept 3-dimensional inputs by automatically expanding them to 4D format with a dummy seqlen_q dimension. Removes rigid shape validation checks that prevented flexible tensor dimensions and updates tensor creation logic to handle both 3D and 4D formats appropriately. Ensures backward pass correctly reduces the dummy dimension when original bias was 3D to maintain output shape consistency. --- csrc/flash_dmattn/flash_api.cpp | 104 +++++++++++++++----------------- 1 file changed, 49 insertions(+), 55 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index fd20920..d06e426 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -365,8 +365,8 @@ mha_fwd( at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - std::optional &mask_, // batch_size x {1|num_heads_k|num_heads} x seqlen_q x seqlen_k - std::optional &bias_, // batch_size x {1|num_heads_k|num_heads} x seqlen_q x seqlen_k + std::optional &mask_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k + std::optional &bias_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const float softmax_scale, bool is_causal, @@ -401,6 +401,10 @@ mha_fwd( TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); CHECK_DEVICE(mask); TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + if (mask.dim() == 3) { + // Add a dummy dimension for seqlen_q + mask = mask.unsqueeze(2).expand({-1, -1, q.size(1), -1}); + } } else { mask = torch::empty({0}, opts); } @@ -411,6 +415,10 @@ mha_fwd( TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); CHECK_DEVICE(bias); TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + if (bias.dim() == 3) { + // Add a dummy dimension for seqlen_q + bias = bias.unsqueeze(2).expand({-1, -1, q.size(1), -1}); + } } else { bias = torch::empty({0}, opts); } @@ -475,24 +483,6 @@ mha_fwd( CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - if (has_mask) { - if (num_heads_mask == 1) { - CHECK_SHAPE(mask, batch_size, 1, seqlen_q, seqlen_k); - } else if (num_heads_mask == num_heads_k) { - CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k); - } else { - CHECK_SHAPE(mask, batch_size, num_heads, seqlen_q, seqlen_k); - } - } - if (has_bias) { - if (num_heads_bias == 1) { - CHECK_SHAPE(bias, batch_size, 1, seqlen_q, seqlen_k); - } else if (num_heads_bias == num_heads_k) { - CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k); - } else { - CHECK_SHAPE(bias, batch_size, num_heads, seqlen_q, seqlen_k); - } - } at::Tensor out; if (out_.has_value()) { @@ -819,14 +809,14 @@ mha_bwd( const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - const std::optional &mask_, // batch_size x {1|num_heads_k|num_heads} x seqlen_q x seqlen_k - const std::optional &bias_, // batch_size x {1|num_heads_k|num_heads} x seqlen_q x seqlen_k + const std::optional &mask_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k + const std::optional &bias_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &softmax_lse, // b x h x seqlen_q std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size - std::optional &dbias_, // batch_size x {1|num_heads_k|num_heads} x seqlen_q x seqlen_k + std::optional &dbias_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k const float softmax_scale, const bool is_causal, const float softcap, @@ -871,6 +861,10 @@ mha_bwd( TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); CHECK_DEVICE(mask); TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + if (mask.dim() == 3) { + // Add a dummy dimension for seqlen_q + mask = mask.unsqueeze(2).expand({-1, -1, q.size(1), -1}); + } } else { mask = torch::empty({0}, opts); } @@ -881,6 +875,10 @@ mha_bwd( TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); CHECK_DEVICE(bias); TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + if (bias.dim() == 3) { + // Add a dummy dimension for seqlen_q + bias = bias.unsqueeze(2).expand({-1, -1, q.size(1), -1}); + } } else { bias = torch::empty({0}, opts); } @@ -915,24 +913,6 @@ mha_bwd( CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - if (has_mask) { - if (num_heads_mask == 1) { - CHECK_SHAPE(mask, batch_size, 1, seqlen_q, seqlen_k); - } else if (num_heads_mask == num_heads_k) { - CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k); - } else { - CHECK_SHAPE(mask, batch_size, num_heads, seqlen_q, seqlen_k); - } - } - if (has_bias) { - if (num_heads_bias == 1) { - CHECK_SHAPE(bias, batch_size, 1, seqlen_q, seqlen_k); - } else if (num_heads_bias == num_heads_k) { - CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k); - } else { - CHECK_SHAPE(bias, batch_size, num_heads, seqlen_q, seqlen_k); - } - } CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size); @@ -970,20 +950,28 @@ mha_bwd( TORCH_CHECK(dbias.dtype() == q_dtype, "dbias must have the same dtype as q"); CHECK_DEVICE(dbias); TORCH_CHECK(dbias.stride(-1) == 1, "dbias must have contiguous last dimension"); - if (num_heads_bias == 1) { - CHECK_SHAPE(dbias, batch_size, 1, seqlen_q, seqlen_k); - } else if (num_heads_bias == num_heads_k) { - CHECK_SHAPE(dbias, batch_size, num_heads_k, seqlen_q, seqlen_k); - } else { - CHECK_SHAPE(dbias, batch_size, num_heads, seqlen_q, seqlen_k); + if (dbias.dim() == 3) { + // Add a dummy dimension for seqlen_q + dbias = dbias.unsqueeze(2).expand({-1, -1, seqlen_q, -1}); } } else { - if (num_heads_bias == 1) { - dbias = torch::empty({batch_size, 1, seqlen_q, seqlen_k}, opts); - } else if (num_heads_bias == num_heads_k) { - dbias = torch::empty({batch_size, num_heads_k, seqlen_q, seqlen_k}, opts); + if (bias.dim() == 4) { + if (num_heads_bias == 1) { + dbias = torch::empty({batch_size, 1, seqlen_q, seqlen_k}, opts); + } else if (num_heads_bias == num_heads_k) { + dbias = torch::empty({batch_size, num_heads_k, seqlen_q, seqlen_k}, opts); + } else { + dbias = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts); + } } else { - dbias = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts); + if (num_heads_bias == 1) { + dbias = torch::empty({batch_size, 1, seqlen_k}, opts); + } else if (num_heads_bias == num_heads_k) { + dbias = torch::empty({batch_size, num_heads_k, seqlen_k}, opts); + } else { + dbias = torch::empty({batch_size, num_heads, seqlen_k}, opts); + } + dbias = dbias.unsqueeze(2).expand({-1, -1, seqlen_q, -1}); } } } else { @@ -1071,9 +1059,15 @@ mha_bwd( at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); } // For MQA/GQA or num_heads_bias != num_heads, we also need to sum dbias across the heads - if (has_bias && num_heads_bias != num_heads) { - at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2}); - } + if (has_bias) { + if (num_heads_bias != num_heads) { + at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2}); + } + if (bias_.value().dim() == 3) { + // Reduce the dummy dimension for seqlen_q + dbias = dbias.sum(2); + } + } return { dq, dk, dv, dbias, softmax_d }; } From b919c4341debc4060472283bcd6807c72e9a1f05 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 19 Sep 2025 18:27:57 +0800 Subject: [PATCH 21/32] Refactors tensor partitioning for mask and bias in attention computations --- csrc/flash_dmattn/src/flash_fwd_kernel.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index f51c4f3..3861dbe 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -300,8 +300,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Repeat the partitioning with identity layouts Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k) - Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) - Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) + Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) + Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); @@ -1000,8 +1000,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Repeat the partitioning with identity layouts Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k) - Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) - Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) + Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) + Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); From 266f5e6020a33ead58245ecf455716e4092c4d5d Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 19 Sep 2025 22:04:58 +0800 Subject: [PATCH 22/32] Adds optional mask and bias support to kernel traits Introduces template parameters to conditionally enable mask and bias functionality in flash attention kernels. Optimizes shared memory allocation by only reserving space for mask and bias when actually needed, reducing memory footprint when these features are disabled. --- csrc/flash_dmattn/src/kernel_traits.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/flash_dmattn/src/kernel_traits.h b/csrc/flash_dmattn/src/kernel_traits.h index 8f648e1..f6a4e42 100644 --- a/csrc/flash_dmattn/src/kernel_traits.h +++ b/csrc/flash_dmattn/src/kernel_traits.h @@ -48,7 +48,7 @@ struct Flash_kernel_traits { // If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true template< int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, - bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t, + bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, bool Has_mask_=false, bool Has_bias_=false, typename elem_type=cutlass::half_t, typename Base=Flash_kernel_traits > struct Flash_fwd_kernel_traits : public Base { @@ -61,6 +61,8 @@ struct Flash_fwd_kernel_traits : public Base { static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + static constexpr bool Has_mask = Has_mask_; + static constexpr bool Has_bias = Has_bias_; // The number of threads. static constexpr int kNWarps = kNWarps_; @@ -153,7 +155,7 @@ struct Flash_fwd_kernel_traits : public Base { static constexpr int kSmemBiasSize = size(SmemLayoutPS{}) * sizeof(Element); // Shared memory size with QKV matrices and mask/bias matrices - static constexpr int kSmemSize = (Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize) + kSmemMaskSize + kSmemBiasSize; + static constexpr int kSmemSize = (Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize) + (Has_mask ? kSmemMaskSize : 0) + (Has_bias ? kSmemBiasSize : 0); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); From e0b3d30f1757d818b40e8b60fa0aef9b05af9bcc Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 19 Sep 2025 22:08:46 +0800 Subject: [PATCH 23/32] Optimizes block size based on attention parameters Improves performance by using larger block sizes when masks and bias are not present. Uses adaptive block sizing strategy that considers head size to maximize throughput for cases without attention masks or bias terms. --- csrc/flash_dmattn/flash_api.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index d06e426..7224cac 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -335,8 +335,9 @@ std::tuple set_params_splitkv( ) { // This needs to match with run_mha_fwd_splitkv_dispatch - const int block_n = 64; - // const int block_n = head_size <= 32 ? 128 : (head_size <= 128 ? 128 : 64); + const int block_n = params.has_mask || params.has_bias + ? 64 + : head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. // In any case we don't expect seqlen_q to be larger than 64 for inference. From 7860c26708c2a8c00eb415776b49a250d244b1e0 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 19 Sep 2025 22:11:47 +0800 Subject: [PATCH 24/32] Optimizes kernel dispatch based on mask and bias flags Refactors flash attention kernel selection to use compile-time conditionals that specialize kernel configurations based on the presence of mask and bias operations. Updates block size calculations to use larger values when mask/bias are absent, improving performance for simpler attention patterns. Replaces runtime shared memory checks with more granular per-configuration memory thresholds, enabling better hardware utilization across different GPU architectures. --- .../src/flash_fwd_launch_template.h | 199 ++++++++++++------ 1 file changed, 138 insertions(+), 61 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_fwd_launch_template.h b/csrc/flash_dmattn/src/flash_fwd_launch_template.h index 55fa2ff..9c3d94b 100644 --- a/csrc/flash_dmattn/src/flash_fwd_launch_template.h +++ b/csrc/flash_dmattn/src/flash_fwd_launch_template.h @@ -146,9 +146,10 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int kBlockM = 64; // Fixed for all head dimensions - constexpr static int kBlockN = 64; // Fixed for all head dimensions - // constexpr static int kBlockN = Headdim <= 32 ? 128 : (Headdim <= 128 ? 128 : 64); - run_flash_splitkv_fwd, Is_causal, Has_mask, Has_bias>(params, stream); + constexpr static int kBlockN = Has_mask || Has_bias + ? 64 + : Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_flash_splitkv_fwd, Is_causal, Has_mask, Has_bias>(params, stream); } template @@ -163,18 +164,24 @@ void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - if (max_smem_per_block >= 164 * 1024) { - // 28KB, 3 CTAs in sm86 and sm 89, 5 CTAs in A100, 8 CTAs in H100. - run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); - // 48KB, 2 CTAs in sm86 and sm 89, 3 CTAs in A100, 4 CTAs in H100. - // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); - // 88KB, 1 CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. - // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); + if constexpr (Has_mask && Has_bias) { + if (max_smem_per_block >= 112 * 1024) { + // 28KB, 5 CTAs in A100, 8 CTAs in H100. + run_flash_fwd, Is_causal, true, true>(params, stream); + } else { + // 24KB, 4 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal, true, true>(params, stream); + } + } else if constexpr (Has_mask && !Has_bias) { + // 20KB, 5 CTAs in sm86 and sm 89, 8 CTAs in A100, 11 CTAs in H100. + run_flash_fwd, Is_causal, true, false>(params, stream); + } else if constexpr (!Has_mask && Has_bias) { + // 56KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 4 CTAs in H100. + run_flash_fwd, Is_causal, false, true>(params, stream); } else { - // 24KB, 4 CTAs in sm86 and sm 89. - run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); + // 24KB, 4 CTAs in sm86 and sm 89, 6 CTAs in A100, 9 CTAs in H100. + run_flash_fwd, Is_causal, false, false>(params, stream); } - } template @@ -189,18 +196,24 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - if (max_smem_per_block >= 164 * 1024) { // H100 and A100 - // 40KB, 2 CTAs in sm86 and sm 89, 4 CTAs in A100, 5 CTAs in H100. - run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); - // 64KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100. - // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); - // 112KB, N/A in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. - // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); - } else { // sm86 and sm89 - // 32KB, 3 CTAs in sm86 and sm 89. - run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); - } - + if constexpr (Has_mask && Has_bias) { + if (max_smem_per_block >= 160 * 1024) { + // 40KB, 4 CTAs in A100, 5 CTAs in H100. + run_flash_fwd, Is_causal, true, true>(params, stream); + } else { + // 32KB, 3 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal, true, true>(params, stream); + } + } else if constexpr (Has_mask && !Has_bias) { + // 32KB, 3 CTAs in sm86 and sm 89, 5 CTAs in A100, 7 CTAs in H100. + run_flash_fwd, Is_causal, true, false>(params, stream); + } else if constexpr (!Has_mask && Has_bias) { + // 48KB, 2 CTAs in sm86 and sm 89, 3 CTAs in A100, 4 CTAs in H100. + run_flash_fwd, Is_causal, false, true>(params, stream); + } else { + // 48KB, 2 CTAs in sm86 and sm 89, 3 CTAs in A100, 4 CTAs in H100. + run_flash_fwd, Is_causal, false, false>(params, stream); + } } template @@ -215,16 +228,23 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - if (max_smem_per_block >= 164 * 1024) { // H100 and A100 - // 52KB, 1 CTAs in sm86 and sm 89, 3 CTAs in A100, 4 CTAs in H100. - run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); - // 80KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 2 CTAs in H100. - // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); - // 136KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100. - // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); - } else { // sm86 and sm89 - // 40KB, 2 CTAs in sm86 and sm 89. - run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); + if constexpr (Has_mask && Has_bias) { + if (max_smem_per_block >= 156 * 1024) { + // 52KB, 3 CTAs in A100, 4 CTAs in H100. + run_flash_fwd, Is_causal, true, true>(params, stream); + } else { + // 40KB, 2 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal, true, true>(params, stream); + } + } else if constexpr (Has_mask && !Has_bias) { + // 44KB, 2 CTAs in sm86 and sm 89, 3 CTAs in A100, 5 CTAs in H100. + run_flash_fwd, Is_causal, true, false>(params, stream); + } else if constexpr (!Has_mask && Has_bias) { + // 44KB, 2 CTAs in sm86 and sm 89, 3 CTAs in A100, 5 CTAs in H100. + run_flash_fwd, Is_causal, false, true>(params, stream); + } else { + // 48KB, 2 CTAs in sm86 and sm 89, 3 CTAs in A100, 4 CTAs in H100. + run_flash_fwd, Is_causal, false, false>(params, stream); } } @@ -240,28 +260,63 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - if (max_smem_per_block >= 164 * 1024) { // H100 and A100 - // 64KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100. - run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); - // 96KB, 1 CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. - // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); - // 160KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100. - // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); - } else { // sm86 and sm89 - // 48KB, 2 CTAs in sm86 and sm 89. - run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); + if constexpr (Has_mask && Has_bias) { + if (max_smem_per_block >= 128 * 1024) { + // 64KB, 2 CTAs in A100, 3 CTAs in H100. + run_flash_fwd, Is_causal, true, true>(params, stream); + } else { + // 48KB, 2 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal, true, true>(params, stream); + } + } else if constexpr (Has_mask && !Has_bias) { + if (max_smem_per_block >= 112 * 1024) { + // 56KB, 2 CTAs in A100, 4 CTAs in H100. + run_flash_fwd, Is_causal, true, false>(params, stream); + } else { + // 40KB, 2 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal, true, false>(params, stream); + } + } else if constexpr (!Has_mask && Has_bias) { + // 80KB, 2 CTAs in A100, 2 CTAs in H100. + run_flash_fwd, Is_causal, false, true>(params, stream); + } else { + if (max_smem_per_block >= 128 * 1024) { + // 64KB, 2 CTAs in A100, 3 CTAs in H100. + run_flash_fwd, Is_causal, false, false>(params, stream); + return; + } else { + // 48KB, 2 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal, false, false>(params, stream); + return; + } } } template void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; - // 88KB, 1 CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. - run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); - // 128KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100. - // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); - // 208KB, N/A CTAs in sm86 and sm 89, N/A CTAs in A100, 1 CTAs in H100. - // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device + ); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + if constexpr (Has_mask && Has_bias) { + // 88KB, 1 CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. + run_flash_fwd, Is_causal, true, true>(params, stream); + } else if constexpr (Has_mask && !Has_bias) { + // 80KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 2 CTAs in H100. + run_flash_fwd, Is_causal, true, false>(params, stream); + } else if constexpr (!Has_mask && Has_bias) { + // 80KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 2 CTAs in H100. + run_flash_fwd, Is_causal, false, true>(params, stream); + } else { + // 72KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100. + run_flash_fwd, Is_causal, false, false>(params, stream); + } } template @@ -276,16 +331,38 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - if (max_smem_per_block >= 112 * 1024) { // H100 and A100 - // 112KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. - run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); - // 192KB, N/A CTAs in sm86 and sm 89, N/A CTAs in A100, 1 CTAs in H100. - // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); - // 256KB, N/A CTAs in sm86 and sm 89, N/A CTAs in A100, N/A CTAs in H100. - // run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); - } else { // sm86 and sm89 - // 80KB, 1 CTAs in sm86 and sm 89. - run_flash_fwd, Is_causal, Has_mask, Has_bias>(params, stream); + if constexpr (Has_mask && Has_bias) { + if (max_smem_per_block >= 112 * 1024) { + // 112KB, 1 CTAs in A100, 2 CTAs in H100. + run_flash_fwd, Is_causal, true, true>(params, stream); + } else { + // 80KB, 1 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal, true, true>(params, stream); + } + } else if constexpr (Has_mask && !Has_bias) { + if (max_smem_per_block >= 104 * 1024) { + // 104KB, 1 CTAs in A100, 2 CTAs in H100. + run_flash_fwd, Is_causal, true, false>(params, stream); + } else { + // 72KB, 1 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal, true, false>(params, stream); + } + } else if constexpr (!Has_mask && Has_bias) { + if (max_smem_per_block >= 104 * 1024) { + // 104KB, 1 CTAs in A100, 2 CTAs in H100. + run_flash_fwd, Is_causal, false, true>(params, stream); + } else { + // 72KB, 1 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal, false, true>(params, stream); + } + } else { + if (max_smem_per_block >= 128 * 1024) { + // 128KB, 1 CTAs in A100, 1 CTAs in H100. + run_flash_fwd, Is_causal, false, false>(params, stream); + } else { + // 96KB, 1 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal, false, false>(params, stream); + } } } From f00d1ff3a6f6a790e851f29976f79893ba7fde28 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Fri, 19 Sep 2025 22:14:37 +0800 Subject: [PATCH 25/32] Adds conditional compilation for mask and bias support Optimizes memory usage and performance by making mask and bias operations conditional based on template parameters Has_mask and Has_bias. Prevents unnecessary memory allocation and computation when mask or bias features are not needed, reducing shared memory footprint and eliminating redundant operations. Updates tensor creation logic to avoid allocating memory for unused mask/bias tensors and wraps all mask/bias related operations in compile-time conditionals. --- csrc/flash_dmattn/src/flash_fwd_kernel.h | 639 ++++++++++++++--------- 1 file changed, 405 insertions(+), 234 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index 3861dbe..0e4d69b 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -217,11 +217,11 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{} ); Tensor sMask = make_tensor( - sV.data() + size(sV), + Has_mask ? sV.data() + size(sV) : sV.data(), typename Kernel_traits::SmemLayoutAtomPS{} ); Tensor sBias = make_tensor( - sMask.data() + size(sMask), + Has_bias ? (Has_mask ? sMask.data() + size(sMask) : sV.data() + size(sV)) : sV.data(), typename Kernel_traits::SmemLayoutAtomPS{} ); @@ -324,10 +324,14 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Set predicates for n bounds if (!Is_even_MN) { - #pragma unroll - for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } - #pragma unroll - for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } + if constexpr (Has_mask) { + #pragma unroll + for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } + } + if constexpr (Has_bias) { + #pragma unroll + for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } + } } @@ -359,24 +363,27 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi __syncthreads(); } - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask(_, _, _, n_block), tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads - - FLASH_NAMESPACE::copy_mask_with_or_reduce( - gmem_tiled_copy_Mask, - tMaskgMask(_, _, _, n_block), tMasksMask, - any_active, - tMaskcMask, tMaskpMask, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + if constexpr (Has_mask) { + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask(_, _, _, n_block), tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + // // Do OR-reduce on the mask to see if any active threads + + + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask(_, _, _, n_block), tMasksMask, + any_active, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. + } // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. if (any_active) { @@ -386,15 +393,17 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, - tBiasgBias(_, _, _, n_block), tBiassBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); + if constexpr (Has_bias) { + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, + tBiasgBias(_, _, _, n_block), tBiassBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); + } cute::cp_async_fence(); } // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } @@ -468,43 +477,75 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - // Copy mask and bias from smem to registers - Tensor tSrMask = make_tensor(shape(acc_s)); - Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); - cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); - cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); - - // Scale attention scores and apply mask/bias - mask.template apply_mask( - acc_s, tSrMask, tSrBias, params.scale_softmax, - n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); + if constexpr (Has_mask && Has_bias) { + // Copy mask and bias from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + + // Scale attention scores and apply mask and add bias + mask.template apply_mask( + acc_s, tSrMask, tSrBias, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else if constexpr (Has_mask && !Has_bias) { + // Copy mask from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + + // Scale attention scores and apply mask + mask.template apply_mask( + acc_s, tSrMask, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else if constexpr (!Has_mask && Has_bias) { + // Copy bias from smem to registers + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + + // Scale attention scores and add bias + mask.template apply_mask( + acc_s, /*mask=*/nullptr, tSrBias, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else { + // Scale attention scores only + mask.template apply_mask( + acc_s, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); } if (n_block > n_block_min) { - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask(_, _, _, n_block - 1), tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration. - - FLASH_NAMESPACE::copy_mask_with_or_reduce( - gmem_tiled_copy_Mask, - tMaskgMask(_, _, _, n_block - 1), tMasksMask, - any_active_next, - tMaskcMask, tMaskpMask, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + if constexpr (Has_mask) { + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask(_, _, _, n_block - 1), tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + // // Do OR-reduce on the mask to see if any active threads for next iteration. + + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask(_, _, _, n_block - 1), tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. + } if (any_active_next) { FLASH_NAMESPACE::copy( @@ -512,15 +553,17 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV ); - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, - tBiasgBias(_, _, _, n_block - 1), tBiassBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); + if constexpr (Has_bias) { + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, + tBiasgBias(_, _, _, n_block - 1), tBiassBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); + } // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -597,43 +640,75 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - // Copy mask and bias from smem to registers - Tensor tSrMask = make_tensor(shape(acc_s)); - Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); - cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); - cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); - - // Scale attention scores and apply dynamic mask - mask.template apply_mask( - acc_s, tSrMask, tSrBias, params.scale_softmax, - n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); + if constexpr (Has_mask && Has_bias) { + // Copy mask and bias from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + + // Scale attention scores and apply mask and add bias + mask.template apply_mask( + acc_s, tSrMask, tSrBias, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else if constexpr (Has_mask && !Has_bias) { + // Copy mask from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + + // Scale attention scores and apply mask + mask.template apply_mask( + acc_s, tSrMask, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else if constexpr (!Has_mask && Has_bias) { + // Copy bias from smem to registers + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + + // Scale attention scores and apply bias + mask.template apply_mask( + acc_s, /*mask=*/nullptr, tSrBias, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else { + // Scale attention scores only + mask.template apply_mask( + acc_s, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); } if (n_block > n_block_min) { - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask(_, _, _, n_block - 1), tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration. - - FLASH_NAMESPACE::copy_mask_with_or_reduce( - gmem_tiled_copy_Mask, - tMaskgMask(_, _, _, n_block - 1), tMasksMask, - any_active_next, - tMaskcMask, tMaskpMask, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads + if constexpr (Has_mask) { + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask(_, _, _, n_block - 1), tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + // // Do OR-reduce on the mask to see if any active threads for next iteration. + + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask(_, _, _, n_block - 1), tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads + } if (any_active_next) { FLASH_NAMESPACE::copy( @@ -641,15 +716,17 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV ); - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, - tBiasgBias(_, _, _, n_block - 1), tBiassBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); + if constexpr (Has_bias) { + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, + tBiasgBias(_, _, _, n_block - 1), tBiassBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); + } // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -934,11 +1011,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{} ); Tensor sMask = make_tensor( - sV.data() + size(sV), + Has_mask ? sV.data() + size(sV) : sV.data(), typename Kernel_traits::SmemLayoutAtomPS{} ); Tensor sBias = make_tensor( - sMask.data() + size(sMask), + Has_bias ? (Has_mask ? sMask.data() + size(sMask) : sV.data() + size(sV)) : sV.data(), typename Kernel_traits::SmemLayoutAtomPS{} ); @@ -1023,10 +1100,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Set predicates for n bounds if (!Is_even_MN) { - #pragma unroll - for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < binfo.actual_seqlen_k - n_block * kBlockN; } - #pragma unroll - for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < binfo.actual_seqlen_k - n_block * kBlockN; } + if constexpr (Has_mask) { + #pragma unroll + for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < binfo.actual_seqlen_k - n_block * kBlockN; } + } + if constexpr (Has_bias) { + #pragma unroll + for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < binfo.actual_seqlen_k - n_block * kBlockN; } + } } // Prologue @@ -1043,24 +1124,26 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons binfo.actual_seqlen_q - m_block * kBlockM ); - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads - - FLASH_NAMESPACE::copy_mask_with_or_reduce( - gmem_tiled_copy_Mask, - tMaskgMask, tMasksMask, - any_active, - tMaskcMask, tMaskpMask, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + if constexpr (Has_mask) { + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + // // Do OR-reduce on the mask to see if any active threads + + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. + } // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. if (any_active) { @@ -1070,15 +1153,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, - tBiasgBias, tBiassBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); + if constexpr (Has_bias) { + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, + tBiasgBias, tBiassBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); + } cute::cp_async_fence(); } @@ -1158,19 +1243,49 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - // Copy mask and bias from smem to registers - Tensor tSrMask = make_tensor(shape(acc_s)); - Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); - cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); - cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); - - // Scale attention scores and apply dynamic mask - mask.template apply_mask( - acc_s, tSrMask, tSrBias, params.scale_softmax, - n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); + if constexpr (Has_mask && Has_bias) { + // Copy mask and bias from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + + // Scale attention scores and apply mask and bias + mask.template apply_mask( + acc_s, tSrMask, tSrBias, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else if constexpr (Has_mask && !Has_bias) { + // Copy mask from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + + // Scale attention scores and apply mask + mask.template apply_mask( + acc_s, tSrMask, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else if constexpr (!Has_mask && Has_bias) { + // Copy bias from smem to registers + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + + // Scale attention scores and add bias + mask.template apply_mask( + acc_s, /*mask=*/nullptr, tSrBias, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else { + // Scale attention scores only + mask.template apply_mask( + acc_s, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); @@ -1182,35 +1297,46 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Advance gK, gMask, gBias if (block_table == nullptr) { tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockN)); - tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockN)); + if constexpr (Has_mask) { + tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockN)); + } + if constexpr (Has_bias) { + tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockN)); + } } else { const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; - tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur); - tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur); + if constexpr (Has_mask) { + tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur); + } + if constexpr (Has_bias) { + tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur); + } + } + + if constexpr (Has_mask) { + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + // // Do OR-reduce on the mask to see if any active threads for next iteration. + + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. } - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration. - - FLASH_NAMESPACE::copy_mask_with_or_reduce( - gmem_tiled_copy_Mask, - tMaskgMask, tMasksMask, - any_active_next, - tMaskcMask, tMaskpMask, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. if (any_active_next) { FLASH_NAMESPACE::copy( @@ -1218,15 +1344,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKgK, tKsK, tKVcKV, tKVpKV ); - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, - tBiasgBias, tBiassBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); + if constexpr (Has_bias) { + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, + tBiasgBias, tBiassBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); + } // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -1307,19 +1435,49 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - // Copy mask and bias from smem to registers - Tensor tSrMask = make_tensor(shape(acc_s)); - Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); - cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); - cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); - - // Scale attention scores and apply dynamic mask - mask.template apply_mask( - acc_s, tSrMask, tSrBias, params.scale_softmax, - n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); + if constexpr (Has_mask && Has_bias) { + // Copy mask and bias from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + + // Scale attention scores and apply mask and bias + mask.template apply_mask( + acc_s, tSrMask, tSrBias, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else if constexpr (Has_mask && !Has_bias) { + // Copy mask from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + + // Scale attention scores and apply mask + mask.template apply_mask( + acc_s, tSrMask, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else if constexpr (!Has_mask && Has_bias) { + // Copy bias from smem to registers + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + + // Scale attention scores and add bias + mask.template apply_mask( + acc_s, /*mask=*/nullptr, tSrBias, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else { + // Scale attention scores only + mask.template apply_mask( + acc_s, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); @@ -1329,35 +1487,46 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Advance gK, gMask, gBias if (block_table == nullptr) { tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockN)); - tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockN)); + if constexpr (Has_mask) { + tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockN)); + } + if constexpr (Has_bias) { + tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockN)); + } } else { const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; - tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur); - tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur); + if constexpr (Has_mask) { + tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur); + } + if constexpr (Has_bias) { + tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur); + } + } + + if constexpr (Has_mask) { + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + // // Do OR-reduce on the mask to see if any active threads for next iteration. + + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. } - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration. - - FLASH_NAMESPACE::copy_mask_with_or_reduce( - gmem_tiled_copy_Mask, - tMaskgMask, tMasksMask, - any_active_next, - tMaskcMask, tMaskpMask, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. if (any_active_next) { FLASH_NAMESPACE::copy( @@ -1365,15 +1534,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKgK, tKsK, tKVcKV, tKVpKV ); - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, - tBiasgBias, tBiassBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); + if constexpr (Has_bias) { + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, + tBiasgBias, tBiassBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); + } // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); From 6ef3de8cf652f21716ce2929cbf459a17e813ada Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 20 Sep 2025 00:22:42 +0800 Subject: [PATCH 26/32] Optimizes backward kernel with compile-time checks for mask and bias Replaces runtime conditionals with compile-time `constexpr` checks for mask and bias operations to improve performance by eliminating unnecessary computations when features are disabled. Reduces code duplication by creating specialized branches for different combinations of mask and bias availability, allowing the compiler to optimize out unused code paths. Eliminates redundant tensor copies and memory operations when mask or bias are not present, leading to better register usage and reduced memory bandwidth. --- csrc/flash_dmattn/src/flash_bwd_kernel.h | 262 ++++++++++++++--------- 1 file changed, 161 insertions(+), 101 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h index 07b8aab..780e616 100644 --- a/csrc/flash_dmattn/src/flash_bwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_bwd_kernel.h @@ -422,10 +422,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Set predicates for n bounds if (!Is_even_MN) { - #pragma unroll - for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } - #pragma unroll - for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } + if constexpr (Has_mask) { + #pragma unroll + for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } + } + if constexpr (Has_bias) { + #pragma unroll + for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } + } } @@ -568,24 +572,26 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK); // // if (cute::thread(1, 0)) { print(tKrK); } - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads - - FLASH_NAMESPACE::copy_mask_with_or_reduce( - gmem_tiled_copy_Mask, - tMaskgMask, tMasksMask, - any_active, - tMaskcMask, tMaskpMask, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + if constexpr (Has_mask) { + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + // // Do OR-reduce on the mask to see if any active threads + + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. + } FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, @@ -595,15 +601,17 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in ); if (any_active) { - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, - tBiasgBias, tBiassBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); + if constexpr (Has_bias) { + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, + tBiasgBias, tBiassBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); + } } if (!Kernel_traits::Is_V_in_regs) { @@ -669,19 +677,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - // Copy mask and bias from smem to registers - Tensor tSrMask = make_tensor(shape(acc_s)); - Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask); - cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view); - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias); - cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view); - - // Reshape acc_s, mask, bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); - Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout())); - Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout())); - // if (cute::thread(32, 0)) { print(scores); } // Softcapping - calculating dTanh and scaling dS later with it [[maybe_unused]] Tensor dtanh = make_tensor_like(scores); @@ -689,21 +686,78 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap); } - // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond - // actual_seqlen_k, because acc_s would be some finite value for those indices. - // In the end when we multiply with K to get dQ, the corresponding values of K would be 0, - // so the result would still be correct. - // However, it's possible that the values in acc_s are so large that they overflow - // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ. - // So we need to mask out the elements beyond actual_seqlen_k. - FLASH_NAMESPACE::apply_mask( - scores, mask, bias, params.scale_softmax, - n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, - binfo.actual_seqlen_k, - m_block * kBlockM + get<0>(taccScS_row(0)), - binfo.actual_seqlen_q, - AtomLayoutMS * 16 - ); + if constexpr (Has_mask && Has_bias) { + // Copy mask and bias from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask); + cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view); + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias); + cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view); + + // Reshape mask, bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) + Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout())); + Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout())); + + // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond + // actual_seqlen_k, because acc_s would be some finite value for those indices. + // In the end when we multiply with K to get dQ, the corresponding values of K would be 0, + // so the result would still be correct. + // However, it's possible that the values in acc_s are so large that they overflow + // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ. + // So we need to mask out the elements beyond actual_seqlen_k. + FLASH_NAMESPACE::apply_mask( + scores, mask, bias, params.scale_softmax, + n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, + m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, + AtomLayoutMS * 16 + ); + } else if constexpr (Has_mask && !Has_bias) { + // Copy mask from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask); + cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view); + + // Reshape mask from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) + Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout())); + + FLASH_NAMESPACE::apply_mask( + scores, mask, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, + m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, + AtomLayoutMS * 16 + ); + } else if constexpr (!Has_mask && Has_bias) { + // Copy bias from smem to registers + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias); + cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view); + + // Reshape bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) + Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout())); + + FLASH_NAMESPACE::apply_mask( + scores, /*mask=*/nullptr, bias, params.scale_softmax, + n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, + m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, + AtomLayoutMS * 16 + ); + } else { + FLASH_NAMESPACE::apply_mask( + scores, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, + m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, + AtomLayoutMS * 16 + ); + } // if (cute::thread(32, 0)) { print(scores); } // Compute the exponential value. @@ -796,16 +850,18 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom, AtomNum), MMA_M, MMA_N) cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); __syncthreads(); - // Write dS to dBias - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, - tBiassBias, tdBiasgdBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); + if constexpr (Has_bias) { + // Write dS to dBias + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, + tBiassBias, tdBiasgdBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); + } // if (cute::thread0()) { print(tPrP); } // Layout p_l = tPrP.layout(); @@ -827,26 +883,28 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in } if (m_block > m_block_min) { - // Advance gMask - tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride)); - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - (m_block - 1) * kBlockM - // ); - // FLASH_NAMESPACE::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration - - FLASH_NAMESPACE::copy_mask_with_or_reduce( - gmem_tiled_copy_Mask, - tMaskgMask, tMasksMask, - any_active_next, - tMaskcMask, tMaskpMask, - binfo.actual_seqlen_q - (m_block - 1) * kBlockM - ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + if constexpr (Has_mask) { + // Advance gMask + tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride)); + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - (m_block - 1) * kBlockM + // ); + // FLASH_NAMESPACE::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + // // Do OR-reduce on the mask to see if any active threads for next iteration + + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - (m_block - 1) * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. + } // Advance gdO tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride)); @@ -944,19 +1002,21 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in } if (m_block > m_block_min) { - // Advance gBias and gdBias - tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride)); - tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride)); - if (any_active_next) { - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, - tBiasgBias, tBiassBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - (m_block - 1) * kBlockM - ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); + if constexpr (Has_bias) { + // Advance gBias and gdBias + tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride)); + tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride)); + if (any_active_next) { + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, + tBiasgBias, tBiassBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - (m_block - 1) * kBlockM + ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); + } } } From 3884cfe54e99dc3982ce860806a2189cbdd48f73 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 20 Sep 2025 00:22:54 +0800 Subject: [PATCH 27/32] Adds const qualifiers to mask and bias parameters Prevents accidental modification of mask and bias parameters in the apply_mask function by making them const references, improving code safety and expressing intent more clearly. --- csrc/flash_dmattn/src/mask.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/flash_dmattn/src/mask.h b/csrc/flash_dmattn/src/mask.h index a1f6c72..9f81cae 100644 --- a/csrc/flash_dmattn/src/mask.h +++ b/csrc/flash_dmattn/src/mask.h @@ -14,8 +14,8 @@ using namespace cute; template __forceinline__ __device__ void apply_mask( TensorType &tensor, - MaskType &mask, - BiasType &bias, + const MaskType &mask, + const BiasType &bias, const float scale_softmax, const int col_idx_offset_, const int max_seqlen_k, From ca66b6cec9c9beb1362819fbd2b487da4478d35a Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 20 Sep 2025 00:42:38 +0800 Subject: [PATCH 28/32] Fixes bias gradient handling for 3D bias tensors Corrects the backward pass logic to properly handle bias tensors with missing sequence length dimension. Previously, 3D bias tensors were incorrectly expanded during gradient computation, leading to shape mismatches. Now properly detects when bias lacks the sequence length dimension and sums gradients across that dimension appropriately. Ensures gradient tensor is properly zeroed and handles both MQA/GQA cases and 3D bias tensors correctly. --- csrc/flash_dmattn/flash_api.cpp | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 7224cac..babe590 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -951,9 +951,10 @@ mha_bwd( TORCH_CHECK(dbias.dtype() == q_dtype, "dbias must have the same dtype as q"); CHECK_DEVICE(dbias); TORCH_CHECK(dbias.stride(-1) == 1, "dbias must have contiguous last dimension"); - if (dbias.dim() == 3) { - // Add a dummy dimension for seqlen_q - dbias = dbias.unsqueeze(2).expand({-1, -1, seqlen_q, -1}); + if (dbias.dim() == 4) { + CHECK_SHAPE(dbias, batch_size, num_heads_bias, seqlen_q, seqlen_k); + } else { + CHECK_SHAPE(dbias, batch_size, num_heads_bias, seqlen_k); } } else { if (bias.dim() == 4) { @@ -972,7 +973,6 @@ mha_bwd( } else { dbias = torch::empty({batch_size, num_heads, seqlen_k}, opts); } - dbias = dbias.unsqueeze(2).expand({-1, -1, seqlen_q, -1}); } } } else { @@ -1006,11 +1006,14 @@ mha_bwd( : dv; dbias_expanded = has_bias ? ( - num_heads_bias != num_heads // MQA / GQA + (num_heads_bias != num_heads) || (bias_.has_value() && bias_.value().dim() == 3) // MQA / GQA or bias has no seqlen_q dimension ? torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts) : dbias ) : torch::empty({0}, opts); + if (has_bias) { + dbias_expanded.zero_(); + } Flash_bwd_params params; @@ -1061,14 +1064,24 @@ mha_bwd( } // For MQA/GQA or num_heads_bias != num_heads, we also need to sum dbias across the heads if (has_bias) { + bool sum_seqlen_q = bias_.has_value() && bias_.value().dim() == 3; if (num_heads_bias != num_heads) { - at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2}); + if (sum_seqlen_q) { + dbias_expanded = at::sum( + at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2} + ); + } else { + at::sum_out( + dbias, + at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2} + ); + } } - if (bias_.value().dim() == 3) { - // Reduce the dummy dimension for seqlen_q - dbias = dbias.sum(2); + if (sum_seqlen_q) { + // We need to sum across the seqlen_q dimension + at::sum_out(dbias, dbias_expanded, {2}); } - } + } return { dq, dk, dv, dbias, softmax_d }; } From a61db8cc6603a0087684070164ef1a7b1439de92 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 20 Sep 2025 00:54:44 +0800 Subject: [PATCH 29/32] Refines documentation for attention_mask and attention_bias parameters in flash_dynamic_mask_attention_forward --- flash_dmattn/integrations/flash_dynamic_mask_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_dmattn/integrations/flash_dynamic_mask_attention.py b/flash_dmattn/integrations/flash_dynamic_mask_attention.py index 898950b..16631aa 100644 --- a/flash_dmattn/integrations/flash_dynamic_mask_attention.py +++ b/flash_dmattn/integrations/flash_dynamic_mask_attention.py @@ -29,8 +29,8 @@ def flash_dynamic_mask_attention_forward( query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim). key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim). value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim). - attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, num_heads, query_len, key_len), also supports (batch_size, num_kv_heads, query_len, key_len) or (batch_size, 1, query_len, key_len) for MQA/GQA. - attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, num_heads, query_len, key_len), also supports (batch_size, num_kv_heads, query_len, key_len) or (batch_size, 1, query_len, key_len) for MQA/GQA. + attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, {num_heads|num_kv_heads|1}, query_len, key_len). + attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, {num_heads|num_kv_heads|1}, query_len, key_len), if attention_mask is None, also supports (batch_size, {num_heads|num_kv_heads|1}, key_len). scaling (Optional[float]): The scaling factor for the attention scores. softcap (Optional[float]): The softcap value for the attention scores. **kwargs: Additional keyword arguments. From 87ce7cc94e66b143730f1aa45f50dde083e91f02 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 20 Sep 2025 00:56:33 +0800 Subject: [PATCH 30/32] Removes default initialization of attention_bias in _flash_dynamic_mask_attention_forward --- .../modeling_flash_dynamic_mask_attention_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py index e3a8a3b..852a80d 100644 --- a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py +++ b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py @@ -80,9 +80,6 @@ def _flash_dynamic_mask_attention_forward( flash_kwargs["deterministic"] = deterministic if softcap is not None: flash_kwargs["softcap"] = softcap - - if attention_bias is None: - attention_bias = torch.zeros((batch_size, num_kv_heads, query_length, key_length), dtype=dtype, device=query_states.device) query_states, key_states, value_states, attention_bias = fdma_peft_integration_check( query_states, key_states, value_states, attention_bias, target_dtype From 77edcb06c2e21abeaf5b4ab11911b6315e083219 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 20 Sep 2025 00:56:50 +0800 Subject: [PATCH 31/32] Fixes bias tensor initialization in FlashDMAttnFunc to handle None case --- flash_dmattn/flash_dmattn_interface.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index 9fa64d2..5b5a889 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -287,7 +287,8 @@ def backward( *args: Any, ): q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors - dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias) + dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v) + dbias = torch.zeros_like(bias) if bias is not None else None head_size_og = dout.size(3) dout_padded = dout From 96d6da04a900d9654f427ac8c6629ff6f2565b38 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 20 Sep 2025 00:57:48 +0800 Subject: [PATCH 32/32] Refactors CUDA extension sources in setup.py to use glob for dynamic file inclusion --- setup.py | 87 ++++---------------------------------------------------- 1 file changed, 6 insertions(+), 81 deletions(-) diff --git a/setup.py b/setup.py index 32e0936..cd325d8 100644 --- a/setup.py +++ b/setup.py @@ -206,87 +206,12 @@ def append_nvcc_threads(nvcc_extra_args): ext_modules.append( CUDAExtension( name="flash_dmattn_cuda", - sources=[ - "csrc/flash_dmattn/flash_api.cpp", - # Forward kernels - regular - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu", - # Forward kernels - causal - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu", - # Forward kernels - split - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu", - # Forward kernels - split causal - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu", - # Backward kernels - regular - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu", - # Backward kernels - causal - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu", - ], + sources=( + [ + "csrc/flash_dmattn/flash_api.cpp", + ] + + sorted(glob.glob("csrc/flash_dmattn/src/instantiations/flash_*.cu")) + ), extra_compile_args={ "cxx": compiler_c17_flag, "nvcc": append_nvcc_threads(nvcc_flags + cc_flag),