diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 4108cb4..babe590 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -50,6 +50,8 @@ void set_params_fprop( float softmax_scale, bool is_causal, const float softcap, + bool has_mask, + bool has_bias, bool seqlenq_ngroups_swapped=false, const bool unpadded_lse=false ) { @@ -63,34 +65,36 @@ void set_params_fprop( params.q_ptr = q.data_ptr(); params.k_ptr = k.data_ptr(); params.v_ptr = v.data_ptr(); - params.mask_ptr = mask.data_ptr(); - params.bias_ptr = bias.data_ptr(); + params.mask_ptr = has_mask ? mask.data_ptr() : nullptr; + params.bias_ptr = has_bias ? bias.data_ptr() : nullptr; params.o_ptr = out.data_ptr(); // All stride are in elements, not bytes. params.q_row_stride = q.stride(-3); - params.k_row_stride = k.stride(-3); - params.v_row_stride = v.stride(-3); - params.mask_row_stride = mask.stride(-2); - params.bias_row_stride = bias.stride(-2); - params.o_row_stride = out.stride(-3); params.q_head_stride = q.stride(-2); + params.k_row_stride = k.stride(-3); params.k_head_stride = k.stride(-2); + params.v_row_stride = v.stride(-3); params.v_head_stride = v.stride(-2); - params.mask_head_stride = mask.stride(-3); - params.bias_head_stride = bias.stride(-3); + params.mask_head_stride = has_mask ? mask.stride(-3) : 0; + params.mask_row_stride = has_mask ? mask.stride(-2) : 0; + params.bias_head_stride = has_bias ? bias.stride(-3) : 0; + params.bias_row_stride = has_bias ? bias.stride(-2) : 0; + params.o_row_stride = out.stride(-3); params.o_head_stride = out.stride(-2); if (cu_seqlens_q_d == nullptr) { params.q_batch_stride = q.stride(0); params.k_batch_stride = k.stride(0); params.v_batch_stride = v.stride(0); - params.mask_batch_stride = mask.stride(0); - params.bias_batch_stride = bias.stride(0); + params.mask_batch_stride = has_mask ? mask.stride(0) : 0; + params.bias_batch_stride = has_bias ? bias.stride(0) : 0; params.o_batch_stride = out.stride(0); if (seqlenq_ngroups_swapped) { - params.q_batch_stride *= seqlen_q; - params.o_batch_stride *= seqlen_q; + params.q_batch_stride *= seqlen_q; + params.mask_batch_stride *= seqlen_q; + params.bias_batch_stride *= seqlen_q; + params.o_batch_stride *= seqlen_q; } } @@ -108,9 +112,11 @@ void set_params_fprop( params.b = b; params.h = h; params.h_k = h_k; - params.h_h_k_ratio = h / h_k; params.h_mask = h_mask; params.h_bias = h_bias; + params.h_h_k_ratio = h / h_k; + params.h_h_mask_ratio = h / h_mask; + params.h_h_bias_ratio = h / h_bias; params.seqlen_q = seqlen_q; params.seqlen_k = seqlen_k; params.seqlen_q_rounded = seqlen_q_rounded; @@ -134,6 +140,8 @@ void set_params_fprop( } params.is_causal = is_causal; + params.has_mask = has_mask; + params.has_bias = has_bias; params.is_seqlens_k_cumulative = true; #ifdef FLASHATTENTION_DISABLE_UNEVEN_K @@ -180,6 +188,8 @@ void set_params_dgrad( float softmax_scale, bool is_causal, const float softcap, + bool has_mask, + bool has_bias, bool deterministic, const bool unpadded_lse ) { @@ -195,33 +205,37 @@ void set_params_dgrad( softmax_scale, is_causal, softcap, + has_mask, + has_bias, false, // seqlenq_ngroups_swapped unpadded_lse ); // Set the pointers and strides. params.do_ptr = dout.data_ptr(); - params.do_row_stride = dout.stride(-3); - params.do_head_stride = dout.stride(-2); params.dq_ptr = dq.data_ptr(); params.dk_ptr = dk.data_ptr(); params.dv_ptr = dv.data_ptr(); - params.dbias_ptr = dbias.data_ptr(); + params.dbias_ptr = has_bias ? 
dbias.data_ptr() : nullptr; + + // All stride are in elements, not bytes. + params.do_row_stride = dout.stride(-3); + params.do_head_stride = dout.stride(-2); params.dq_row_stride = dq.stride(-3); - params.dk_row_stride = dk.stride(-3); - params.dv_row_stride = dv.stride(-3); - params.dbias_row_stride = dbias.stride(-2); params.dq_head_stride = dq.stride(-2); + params.dk_row_stride = dk.stride(-3); params.dk_head_stride = dk.stride(-2); + params.dv_row_stride = dv.stride(-3); params.dv_head_stride = dv.stride(-2); - params.dbias_head_stride = dbias.stride(-3); + params.dbias_head_stride = has_bias ? dbias.stride(-3) : 0; + params.dbias_row_stride = has_bias ? dbias.stride(-2) : 0; if (cu_seqlens_q_d == nullptr) { params.do_batch_stride = dout.stride(0); params.dq_batch_stride = dq.stride(0); params.dk_batch_stride = dk.stride(0); params.dv_batch_stride = dv.stride(0); - params.dbias_batch_stride = dbias.stride(0); + params.dbias_batch_stride = has_bias ? dbias.stride(0) : 0; } params.dq_accum_ptr = dq_accum_d; @@ -248,14 +262,18 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split FP16_SWITCH(!params.is_bf16, [&] { HEADDIM_SWITCH(params.d, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // splitkv kernel is not supported for head_dim >= 128 in sm89 due to smem limits - bool splitkv_forbidden = (kHeadDim >= 128) && (max_smem_per_block < 112 * 1024); - params.num_splits = splitkv_forbidden ? 1 : params.num_splits; - if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 - run_mha_fwd_(params, stream); - } else { - run_mha_fwd_splitkv_dispatch(params, stream); - } + BOOL_SWITCH(params.has_mask, Has_mask, [&] { + BOOL_SWITCH(params.has_bias, Has_bias, [&] { + // splitkv kernel is not supported for head_dim >= 128 in sm89 due to smem limits + bool splitkv_forbidden = (kHeadDim >= 128) && (max_smem_per_block < 112 * 1024); + params.num_splits = splitkv_forbidden ? 1 : params.num_splits; + if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_splitkv_dispatch(params, stream); + } + }); + }); }); }); }); @@ -317,8 +335,9 @@ std::tuple set_params_splitkv( ) { // This needs to match with run_mha_fwd_splitkv_dispatch - const int block_n = 64; - // const int block_n = head_size <= 32 ? 128 : (head_size <= 128 ? 128 : 64); + const int block_n = params.has_mask || params.has_bias + ? 64 + : head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. // In any case we don't expect seqlen_q to be larger than 64 for inference. 
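The nested BOOL_SWITCH calls above turn the runtime has_mask / has_bias flags into compile-time Has_mask / Has_bias template parameters, so mask-free and bias-free attention get their own kernel instantiations with those code paths compiled out. A minimal sketch of that dispatch pattern, with a simplified stand-in for the BOOL_SWITCH macro and a stub in place of run_mha_fwd_:

```cpp
// Minimal sketch of the two-level boolean dispatch added in run_mha_fwd /
// run_mha_bwd: runtime flags become compile-time template parameters, so the
// mask/bias code paths can be pruned with `if constexpr` inside the kernels.
// BOOL_SWITCH below is a simplified stand-in for the real macro.
#include <cstdio>

#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

template <bool Has_mask, bool Has_bias>
void run_kernel_stub() {  // stands in for the templated run_mha_fwd_ launcher
    std::printf("Has_mask=%d Has_bias=%d\n", Has_mask, Has_bias);
}

void dispatch(bool has_mask, bool has_bias) {
    BOOL_SWITCH(has_mask, Has_mask, [&] {
        BOOL_SWITCH(has_bias, Has_bias, [&] {
            run_kernel_stub<Has_mask, Has_bias>();
        });
    });
}

int main() {
    dispatch(true, false);   // mask-only instantiation
    dispatch(false, false);  // plain attention instantiation
    return 0;
}
```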
@@ -344,12 +363,12 @@ std::tuple set_params_splitkv( std::vector mha_fwd( - at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - at::Tensor &mask, // batch_size x num_heads x seqlen_q x seqlen_k or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k - at::Tensor &bias, // batch_size x num_heads x seqlen_q x seqlen_k, or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k - std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + std::optional &mask_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k + std::optional &bias_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k + std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const float softmax_scale, bool is_causal, const float softcap, @@ -367,16 +386,43 @@ mha_fwd( TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); - TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + auto opts = q.options(); + + bool has_mask = mask_.has_value(); + at::Tensor mask; + if (has_mask) { + mask = mask_.value(); + TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); + CHECK_DEVICE(mask); + TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + if (mask.dim() == 3) { + // Add a dummy dimension for seqlen_q + mask = mask.unsqueeze(2).expand({-1, -1, q.size(1), -1}); + } + } else { + mask = torch::empty({0}, opts); + } + bool has_bias = bias_.has_value(); + at::Tensor bias; + if (has_bias) { + bias = bias_.value(); + TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); + CHECK_DEVICE(bias); + TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + if (bias.dim() == 3) { + // Add a dummy dimension for seqlen_q + bias = bias.unsqueeze(2).expand({-1, -1, q.size(1), -1}); + } + } else { + bias = torch::empty({0}, opts); + } const auto sizes = q.sizes(); @@ -386,14 +432,19 @@ mha_fwd( const int head_size = sizes[3]; const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); - int 
num_heads_mask = mask.size(1); - int num_heads_bias = bias.size(1); + int num_heads_mask = has_mask ? mask.size(1) : 1; + int num_heads_bias = has_bias ? bias.size(1) : 1; + TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256"); TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, "Number of heads in mask must be 1, h_k or h"); - TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, "Number of heads in bias must be 1, h_k or h"); + if (has_mask) { + TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, "Number of heads in mask must be 1, h_k or h"); + } + if (has_bias) { + TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, "Number of heads in bias must be 1, h_k or h"); + } // causal=true is the same as causal=false in this case if (seqlen_q == 1) { is_causal = false; } @@ -406,22 +457,26 @@ mha_fwd( const int orig_num_heads_bias = num_heads_bias; if (seqlenq_ngroups_swapped) { q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); - if (num_heads_mask == 1) { - mask = mask.expand({batch_size, 1, ngroups, seqlen_k}); - } else if (num_heads_mask == num_heads_k) { - mask = mask.expand({batch_size, num_heads_k, ngroups, seqlen_k}); - } else { // num_heads_mask == num_heads - mask = mask.reshape({batch_size, num_heads_k, ngroups, seqlen_k}); + if (has_mask) { + mask = num_heads_mask == 1 + ? mask.expand({batch_size, 1, ngroups, seqlen_k}) + : ( + num_heads_mask == num_heads_k + ? mask.expand({batch_size, num_heads_k, ngroups, seqlen_k}) + : mask.reshape({batch_size, num_heads_k, ngroups, seqlen_k}) + ); } - if (num_heads_bias == 1) { - bias = bias.expand({batch_size, 1, ngroups, seqlen_k}); - } else if (num_heads_bias == num_heads_k) { - bias = bias.expand({batch_size, num_heads_k, ngroups, seqlen_k}); - } else { // num_heads_bias == num_heads - bias = bias.reshape({batch_size, num_heads_k, ngroups, seqlen_k}); + if (has_bias) { + bias = num_heads_bias == 1 + ? bias.expand({batch_size, 1, ngroups, seqlen_k}) + : ( + num_heads_bias == num_heads_k + ? bias.expand({batch_size, num_heads_k, ngroups, seqlen_k}) + : bias.reshape({batch_size, num_heads_k, ngroups, seqlen_k}) + ); } - num_heads_mask = (num_heads_mask == num_heads) ? num_heads_k : num_heads_mask; - num_heads_bias = (num_heads_bias == num_heads) ? num_heads_k : num_heads_bias; + num_heads_mask = has_mask ? ((num_heads_mask == num_heads) ? num_heads_k : num_heads_mask) : 1; + num_heads_bias = has_bias ? ((num_heads_bias == num_heads) ? 
num_heads_k : num_heads_bias) : 1; seqlen_q = ngroups; num_heads = num_heads_k; } @@ -429,20 +484,6 @@ mha_fwd( CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - if (num_heads_mask == 1) { - CHECK_SHAPE(mask, batch_size, 1, seqlen_q, seqlen_k); - } else if (num_heads_mask == num_heads_k) { - CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k); - } else { - CHECK_SHAPE(mask, batch_size, num_heads, seqlen_q, seqlen_k); - } - if (num_heads_bias == 1) { - CHECK_SHAPE(bias, batch_size, 1, seqlen_q, seqlen_k); - } else if (num_heads_bias == num_heads_k) { - CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k); - } else { - CHECK_SHAPE(bias, batch_size, num_heads, seqlen_q, seqlen_k); - } at::Tensor out; if (out_.has_value()) { @@ -463,8 +504,6 @@ mha_fwd( const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - auto opts = q.options(); - auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); at::Tensor p; @@ -490,7 +529,9 @@ mha_fwd( softmax_lse.data_ptr(), softmax_scale, is_causal, - softcap + softcap, + has_mask, + has_bias ); // Keep references to these tensors to extend their lifetime @@ -513,15 +554,15 @@ mha_fwd( out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); - if (orig_num_heads_mask == 1 || orig_num_heads_mask == num_heads_k) { - mask = mask.narrow(2, 0, 1); - } else { // orig_num_heads_mask == num_heads - mask = mask.reshape({batch_size, num_heads_k * seqlen_q, 1, seqlen_k}); + if (has_mask) { + mask = (orig_num_heads_mask == 1 || orig_num_heads_mask == num_heads_k) + ? mask.narrow(2, 0, 1) + : mask.reshape({batch_size, num_heads_k * seqlen_q, 1, seqlen_k}); } - if (orig_num_heads_bias == 1 || orig_num_heads_bias == num_heads_k) { - bias = bias.narrow(2, 0, 1); - } else { // orig_num_heads_bias == num_heads - bias = bias.reshape({batch_size, num_heads_k * seqlen_q, 1, seqlen_k}); + if (has_bias) { + bias = (orig_num_heads_bias == 1 || orig_num_heads_bias == num_heads_k) + ? 
bias.narrow(2, 0, 1) + : bias.reshape({batch_size, num_heads_k * seqlen_q, 1, seqlen_k}); } } return {out, softmax_lse, p}; @@ -753,7 +794,11 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { FP16_SWITCH(!params.is_bf16, [&] { HEADDIM_SWITCH(params.d, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { - run_mha_bwd_(params, stream); + BOOL_SWITCH(params.has_mask, Has_mask, [&] { + BOOL_SWITCH(params.has_bias, Has_bias, [&] { + run_mha_bwd_(params, stream); + }); + }); }); }); }); @@ -761,18 +806,18 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { std::vector mha_bwd( - const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8) - const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &mask, // batch_size x num_heads x seqlen_q x seqlen_k or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k - const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x num_heads x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k - const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &softmax_lse, // b x h x seqlen_q - std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size - std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size - std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size - std::optional &dbias_, // batch_size x num_heads_k x seqlen_q x seqlen_k + const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8) + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const std::optional &mask_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k + const std::optional &bias_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dbias_, // batch_size x {1|num_heads_k|num_heads} x {seqlen_q|0} x seqlen_k const float softmax_scale, const bool is_causal, const float softcap, @@ -796,39 +841,70 @@ mha_bwd( TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); - TORCH_CHECK(bias.dtype() == q_dtype, "query and bias must have the same dtype"); TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias); + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); 
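Throughout the patch, mask and bias may carry 1, num_heads_k, or num_heads heads; the old per-case h_idx_mask / h_idx_bias branches are replaced by dividing the query-head index by the precomputed h_h_mask_ratio / h_h_bias_ratio. A small standalone check that one division covers all three broadcast cases:

```cpp
// Sketch of the head-broadcast indexing behind h_h_mask_ratio = h / h_mask
// (and likewise for bias): dividing the query-head index bidh by the ratio
// maps it to the right mask/bias head whether the mask has 1, h_k, or h heads,
// so the kernels no longer need a three-way branch.
#include <cassert>
#include <cstdio>

int mask_head_index(int bidh, int num_heads, int num_heads_mask) {
    assert(num_heads % num_heads_mask == 0);         // holds for 1, h_k, and h
    const int h_h_mask_ratio = num_heads / num_heads_mask;
    return bidh / h_h_mask_ratio;
}

int main() {
    // 8 query heads, mask broadcast over all of them: every bidh maps to head 0.
    std::printf("%d\n", mask_head_index(5, 8, 1));   // 0
    // 8 query heads, 2 KV heads (GQA), one mask head per KV head: bidh 5 -> 1.
    std::printf("%d\n", mask_head_index(5, 8, 2));   // 1
    // One mask head per query head: identity mapping.
    std::printf("%d\n", mask_head_index(5, 8, 8));   // 5
    return 0;
}
```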
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); - const auto sizes = q.sizes(); auto opts = q.options(); + bool has_mask = mask_.has_value(); + at::Tensor mask; + if (has_mask) { + mask = mask_.value(); + TORCH_CHECK(mask.dtype() == torch::kBool, "mask must have dtype bool"); + CHECK_DEVICE(mask); + TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + if (mask.dim() == 3) { + // Add a dummy dimension for seqlen_q + mask = mask.unsqueeze(2).expand({-1, -1, q.size(1), -1}); + } + } else { + mask = torch::empty({0}, opts); + } + bool has_bias = bias_.has_value(); + at::Tensor bias; + if (has_bias) { + bias = bias_.value(); + TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs"); + CHECK_DEVICE(bias); + TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + if (bias.dim() == 3) { + // Add a dummy dimension for seqlen_q + bias = bias.unsqueeze(2).expand({-1, -1, q.size(1), -1}); + } + } else { + bias = torch::empty({0}, opts); + } + + const auto sizes = q.sizes(); + const int batch_size = sizes[0]; const int seqlen_q = sizes[1]; const int num_heads = sizes[2]; const int head_size = sizes[3]; const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); - const int num_heads_mask = mask.size(1); - const int num_heads_bias = bias.size(1); + int num_heads_mask = has_mask ? mask.size(1) : 1; + int num_heads_bias = has_bias ? bias.size(1) : 1; + TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, "Number of heads in mask must be 1, h_k or h"); - TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, "Number of heads in bias must be 1, h_k or h"); + if (has_mask) { + TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads, "Number of heads in mask must be 1, h_k or h"); + } + if (has_bias) { + TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads, "Number of heads in bias must be 1, h_k or h"); + } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 
32 : 64); @@ -838,20 +914,6 @@ mha_bwd( CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - if (num_heads_mask == 1) { - CHECK_SHAPE(mask, batch_size, 1, seqlen_q, seqlen_k); - } else if (num_heads_mask == num_heads_k) { - CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k); - } else { - CHECK_SHAPE(mask, batch_size, num_heads, seqlen_q, seqlen_k); - } - if (num_heads_bias == 1) { - CHECK_SHAPE(bias, batch_size, 1, seqlen_q, seqlen_k); - } else if (num_heads_bias == num_heads_k) { - CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k); - } else { - CHECK_SHAPE(bias, batch_size, num_heads, seqlen_q, seqlen_k); - } CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size); @@ -883,26 +945,38 @@ mha_bwd( } else { dv = torch::empty_like(v); } - if (dbias_.has_value()) { - dbias = dbias_.value(); - TORCH_CHECK(dbias.dtype() == q_dtype, "dbias must have the same dtype as q"); - CHECK_DEVICE(dbias); - TORCH_CHECK(dbias.stride(-1) == 1, "dbias must have contiguous last dimension"); - if (num_heads_bias == 1) { - CHECK_SHAPE(dbias, batch_size, 1, seqlen_q, seqlen_k); - } else if (num_heads_bias == num_heads_k) { - CHECK_SHAPE(dbias, batch_size, num_heads_k, seqlen_q, seqlen_k); + if (has_bias) { + if (dbias_.has_value()) { + dbias = dbias_.value(); + TORCH_CHECK(dbias.dtype() == q_dtype, "dbias must have the same dtype as q"); + CHECK_DEVICE(dbias); + TORCH_CHECK(dbias.stride(-1) == 1, "dbias must have contiguous last dimension"); + if (dbias.dim() == 4) { + CHECK_SHAPE(dbias, batch_size, num_heads_bias, seqlen_q, seqlen_k); + } else { + CHECK_SHAPE(dbias, batch_size, num_heads_bias, seqlen_k); + } } else { - CHECK_SHAPE(dbias, batch_size, num_heads, seqlen_q, seqlen_k); + if (bias.dim() == 4) { + if (num_heads_bias == 1) { + dbias = torch::empty({batch_size, 1, seqlen_q, seqlen_k}, opts); + } else if (num_heads_bias == num_heads_k) { + dbias = torch::empty({batch_size, num_heads_k, seqlen_q, seqlen_k}, opts); + } else { + dbias = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts); + } + } else { + if (num_heads_bias == 1) { + dbias = torch::empty({batch_size, 1, seqlen_k}, opts); + } else if (num_heads_bias == num_heads_k) { + dbias = torch::empty({batch_size, num_heads_k, seqlen_k}, opts); + } else { + dbias = torch::empty({batch_size, num_heads, seqlen_k}, opts); + } + } } } else { - if (num_heads_bias == 1) { - dbias = torch::empty({batch_size, 1, seqlen_q, seqlen_k}, opts); - } else if (num_heads_bias == num_heads_k) { - dbias = torch::empty({batch_size, num_heads_k, seqlen_q, seqlen_k}, opts); - } else { - dbias = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts); - } + dbias = torch::empty({0}, opts); } // bool loop = seqlen_k > blocksize_c; @@ -924,17 +998,21 @@ mha_bwd( } at::Tensor dk_expanded, dv_expanded, dbias_expanded; - if (num_heads_k != num_heads) { // MQA / GQA - dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - } else { - dk_expanded = dk; - dv_expanded = dv; - } - if (num_heads_bias != num_heads) { - dbias_expanded = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts); - } else { - dbias_expanded = dbias; + dk_expanded = num_heads_k != num_heads // MQA / GQA + ? 
torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts) + : dk; + dv_expanded = num_heads_k != num_heads // MQA / GQA + ? torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts) + : dv; + dbias_expanded = has_bias + ? ( + (num_heads_bias != num_heads) || (bias_.has_value() && bias_.value().dim() == 3) // MQA / GQA or bias has no seqlen_q dimension + ? torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts) + : dbias + ) + : torch::empty({0}, opts); + if (has_bias) { + dbias_expanded.zero_(); } Flash_bwd_params params; @@ -960,6 +1038,8 @@ mha_bwd( softmax_scale, is_causal, softcap, + has_mask, + has_bias, deterministic, /*unpadded_lse*/false ); @@ -983,8 +1063,24 @@ mha_bwd( at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); } // For MQA/GQA or num_heads_bias != num_heads, we also need to sum dbias across the heads - if (num_heads_bias != num_heads) { - at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2}); + if (has_bias) { + bool sum_seqlen_q = bias_.has_value() && bias_.value().dim() == 3; + if (num_heads_bias != num_heads) { + if (sum_seqlen_q) { + dbias_expanded = at::sum( + at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2} + ); + } else { + at::sum_out( + dbias, + at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k}), {2} + ); + } + } + if (sum_seqlen_q) { + // We need to sum across the seqlen_q dimension + at::sum_out(dbias, dbias_expanded, {2}); + } } return { dq, dk, dv, dbias, softmax_d }; diff --git a/csrc/flash_dmattn/src/flash.h b/csrc/flash_dmattn/src/flash.h index a533cc3..a1c9bf1 100644 --- a/csrc/flash_dmattn/src/flash.h +++ b/csrc/flash_dmattn/src/flash.h @@ -38,13 +38,13 @@ struct QKV_params { int h, h_k; // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be // different from nheads (query). - int h_h_k_ratio; // precompute h / h_k, + int h_h_k_ratio; // precompute h / h_k, }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Mask_params { - void * __restrict__ mask_ptr; // Attention mask tensor [batch_size, num_kv_heads, query_len, key_len] + void * __restrict__ mask_ptr; // Attention mask tensor [batch_size, num_mask_heads, query_len, key_len] // The stride of the attention mask tensors. index_t mask_batch_stride; // Stride between batches of attention mask @@ -53,12 +53,15 @@ struct Mask_params { // The number of heads in the mask. int h_mask; + int h_h_mask_ratio; // precompute h / h_mask + + bool has_mask; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Bias_params { - void *__restrict__ bias_ptr; // Attention bias tensor [batch_size, num_kv_heads, query_len, key_len] + void *__restrict__ bias_ptr; // Attention bias tensor [batch_size, num_bias_heads, query_len, key_len] // The stride of the attention bias tensor. index_t bias_batch_stride; // Stride between batches of attention bias @@ -67,13 +70,16 @@ struct Bias_params { // The number of heads in the bias. 
int h_bias; + int h_h_bias_ratio; // precompute h / h_bias + + bool has_bias; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_params { - // The O matrix (output). + // The O matrix. void * __restrict__ o_ptr; void * __restrict__ oaccum_ptr; @@ -90,7 +96,7 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par void * __restrict__ softmax_lseaccum_ptr; // The dimensions. - int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q; + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, total_q, total_k; // The scaling factors for the kernel. float scale_softmax; @@ -105,6 +111,7 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par // If provided, the actual length of each k sequence. int * __restrict__ seqused_k; + // TODO: block mask for less memory usage int *__restrict__ blockmask; // The K_new and V_new matrices. @@ -192,9 +199,9 @@ struct Flash_bwd_params : public Flash_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// -template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); +template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h index dc93b76..780e616 100644 --- a/csrc/flash_dmattn/src/flash_bwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_bwd_kernel.h @@ -76,7 +76,7 @@ CUTE_HOST_DEVICE auto make_tiled_copy_C_warpcontiguousN( //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { using Element = typename Kernel_traits::Element; @@ -107,12 +107,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in + n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); const index_t row_offset_mask = binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb) - + h_idx_mask * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN; - const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? 
(bidh / params.h_h_k_ratio) : bidh); + + (bidh / params.h_h_mask_ratio) * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN; const index_t row_offset_bias = binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb) - + h_idx_bias * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN; + + (bidh / params.h_h_bias_ratio) * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN; const index_t row_offset_dbias = binfo.bias_offset(params.dbias_batch_stride, params.dbias_row_stride, bidb) + bidh * params.dbias_head_stride + (m_block_max - 1) * kBlockM * params.dbias_row_stride + n_block * kBlockN; const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) @@ -424,10 +422,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Set predicates for n bounds if (!Is_even_MN) { - #pragma unroll - for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } - #pragma unroll - for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } + if constexpr (Has_mask) { + #pragma unroll + for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } + } + if constexpr (Has_bias) { + #pragma unroll + for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } + } } @@ -570,24 +572,26 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK); // // if (cute::thread(1, 0)) { print(tKrK); } - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads - - FLASH_NAMESPACE::copy_mask_with_or_reduce( - gmem_tiled_copy_Mask, - tMaskgMask, tMasksMask, - any_active, - tMaskcMask, tMaskpMask, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + if constexpr (Has_mask) { + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + // // Do OR-reduce on the mask to see if any active threads + + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. + } FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, @@ -597,15 +601,17 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in ); if (any_active) { - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, - tBiasgBias, tBiassBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // Because copy_bias currently uses scalar loads, we need to sync here. 
- // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); + if constexpr (Has_bias) { + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, + tBiasgBias, tBiassBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); + } } if (!Kernel_traits::Is_V_in_regs) { @@ -671,19 +677,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - // Copy mask and bias from smem to registers - Tensor tSrMask = make_tensor(shape(acc_s)); - Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask); - cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view); - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias); - cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view); - - // Reshape acc_s, mask, bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); - Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout())); - Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout())); - // if (cute::thread(32, 0)) { print(scores); } // Softcapping - calculating dTanh and scaling dS later with it [[maybe_unused]] Tensor dtanh = make_tensor_like(scores); @@ -691,21 +686,78 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap); } - // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond - // actual_seqlen_k, because acc_s would be some finite value for those indices. - // In the end when we multiply with K to get dQ, the corresponding values of K would be 0, - // so the result would still be correct. - // However, it's possible that the values in acc_s are so large that they overflow - // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ. - // So we need to mask out the elements beyond actual_seqlen_k. 
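The rationale comment here (kept verbatim inside the new Has_mask && Has_bias branch) explains why out-of-range columns must still be masked even though the matching K values are zero: the raw scores can overflow once multiplied with dP and converted to fp16. For reference, a scalar stand-in for what apply_mask does per element under the four Has_mask / Has_bias combinations; the exact ordering of scaling and bias addition is an assumption of this sketch:

```cpp
// Scalar reference for the per-element effect of apply_mask: out-of-range or
// masked-out positions are forced to -inf so they vanish in the softmax;
// active positions are scaled and, when a bias is present, offset by it.
// The scale/bias ordering is an assumption of this sketch, not taken from
// the kernel.
#include <limits>

float apply_mask_ref(float raw_score, float scale_softmax,
                     bool in_bounds,          // row < actual_seqlen_q && col < actual_seqlen_k
                     const bool *mask,        // nullptr when Has_mask == false
                     const float *bias) {     // nullptr when Has_bias == false
    const bool active = in_bounds && (mask == nullptr || *mask);
    if (!active) {
        return -std::numeric_limits<float>::infinity();
    }
    float s = raw_score * scale_softmax;
    if (bias != nullptr) {
        s += *bias;
    }
    return s;
}
```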
- FLASH_NAMESPACE::apply_mask( - scores, mask, bias, params.scale_softmax, - n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, - binfo.actual_seqlen_k, - m_block * kBlockM + get<0>(taccScS_row(0)), - binfo.actual_seqlen_q, - AtomLayoutMS * 16 - ); + if constexpr (Has_mask && Has_bias) { + // Copy mask and bias from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask); + cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view); + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias); + cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view); + + // Reshape mask, bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) + Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout())); + Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout())); + + // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond + // actual_seqlen_k, because acc_s would be some finite value for those indices. + // In the end when we multiply with K to get dQ, the corresponding values of K would be 0, + // so the result would still be correct. + // However, it's possible that the values in acc_s are so large that they overflow + // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ. + // So we need to mask out the elements beyond actual_seqlen_k. + FLASH_NAMESPACE::apply_mask( + scores, mask, bias, params.scale_softmax, + n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, + m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, + AtomLayoutMS * 16 + ); + } else if constexpr (Has_mask && !Has_bias) { + // Copy mask from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask); + cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view); + + // Reshape mask from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) + Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout())); + + FLASH_NAMESPACE::apply_mask( + scores, mask, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, + m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, + AtomLayoutMS * 16 + ); + } else if constexpr (!Has_mask && Has_bias) { + // Copy bias from smem to registers + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias); + cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view); + + // Reshape bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) + Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout())); + + FLASH_NAMESPACE::apply_mask( + scores, /*mask=*/nullptr, bias, params.scale_softmax, + n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, + m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, + AtomLayoutMS * 16 + ); + } else { + FLASH_NAMESPACE::apply_mask( + scores, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, + m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, + 
AtomLayoutMS * 16 + ); + } // if (cute::thread(32, 0)) { print(scores); } // Compute the exponential value. @@ -798,16 +850,18 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom, AtomNum), MMA_M, MMA_N) cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); __syncthreads(); - // Write dS to dBias - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, - tBiassBias, tdBiasgdBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); + if constexpr (Has_bias) { + // Write dS to dBias + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, + tBiassBias, tdBiasgdBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); + } // if (cute::thread0()) { print(tPrP); } // Layout p_l = tPrP.layout(); @@ -829,26 +883,28 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in } if (m_block > m_block_min) { - // Advance gMask - tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride)); - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - (m_block - 1) * kBlockM - // ); - // FLASH_NAMESPACE::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration - - FLASH_NAMESPACE::copy_mask_with_or_reduce( - gmem_tiled_copy_Mask, - tMaskgMask, tMasksMask, - any_active_next, - tMaskcMask, tMaskpMask, - binfo.actual_seqlen_q - (m_block - 1) * kBlockM - ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + if constexpr (Has_mask) { + // Advance gMask + tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockM * params.mask_row_stride)); + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - (m_block - 1) * kBlockM + // ); + // FLASH_NAMESPACE::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + // // Do OR-reduce on the mask to see if any active threads for next iteration + + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - (m_block - 1) * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. + } // Advance gdO tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride)); @@ -946,19 +1002,21 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in } if (m_block > m_block_min) { - // Advance gBias and gdBias - tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride)); - tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride)); - if (any_active_next) { - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, - tBiasgBias, tBiassBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - (m_block - 1) * kBlockM - ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. 
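Since the bias enters the pre-softmax scores additively, its gradient is simply dS, which is why the Has_bias-guarded block above writes the freshly computed dS tile straight out as dBias. A scalar, bounds-predicated stand-in for that tile write (layout and names are illustrative):

```cpp
// Stand-in for the guarded dS -> dBias tile write in the backward kernel:
// dBias equals dS elementwise for an additive bias, so the tile is copied out
// with bounds predication on the last partial tile; when Has_bias is false the
// whole path is removed via `if constexpr`.
#include <algorithm>
#include <cstddef>
#include <vector>

void write_dbias_tile(const std::vector<float> &dS_tile,   // kBlockM x kBlockN, row-major
                      std::vector<float> &dbias_out,       // seqlen_q x seqlen_k, row-major
                      int row0, int col0, int kBlockM, int kBlockN,
                      int seqlen_q, int seqlen_k, bool has_bias) {
    if (!has_bias) return;                                  // compiled out in the kernel
    const int m_max = std::min(kBlockM, seqlen_q - row0);   // predicate rows of a partial tile
    const int n_max = std::min(kBlockN, seqlen_k - col0);   // predicate cols of a partial tile
    for (int m = 0; m < m_max; ++m) {
        for (int n = 0; n < n_max; ++n) {
            dbias_out[static_cast<std::size_t>(row0 + m) * seqlen_k + (col0 + n)] =
                dS_tile[static_cast<std::size_t>(m) * kBlockN + n];  // dBias = dS
        }
    }
}
```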
- __syncthreads(); + if constexpr (Has_bias) { + // Advance gBias and gdBias + tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride)); + tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride)); + if (any_active_next) { + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, + tBiasgBias, tBiassBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - (m_block - 1) * kBlockM + ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); + } } } @@ -1071,7 +1129,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv(const Params ¶ms) { // The block index for the batch. @@ -1085,20 +1143,20 @@ inline __device__ void compute_dq_dk_dv(const Params ¶ms) { const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; if (n_block_max == 1) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); + compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); } else { // Iterating backward from n_block_max - 1 to 0 might save 1 register - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block_max - 1); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block_max - 1); for (int n_block = n_block_max - 2; n_block > 0; n_block--) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); } - compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); + compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // The block index for the batch. @@ -1108,7 +1166,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. 
for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); } } diff --git a/csrc/flash_dmattn/src/flash_bwd_launch_template.h b/csrc/flash_dmattn/src/flash_bwd_launch_template.h index 06a46fd..00712b8 100644 --- a/csrc/flash_dmattn/src/flash_bwd_launch_template.h +++ b/csrc/flash_dmattn/src/flash_bwd_launch_template.h @@ -31,17 +31,17 @@ namespace FLASH_NAMESPACE { template \ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params) -DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_causal, bool Is_even_M, bool Is_even_K) { +DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_M, bool Is_even_K) { #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_dq_dk_dv(params); + FLASH_NAMESPACE::compute_dq_dk_dv(params); #else FLASH_UNSUPPORTED_ARCH #endif } -DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap) { +DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap) { #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel(params); + FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -68,7 +68,7 @@ __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) { FLASH_NAMESPACE::convert_dKV(params); } -template +template void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) { const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; dim3 grid_m(num_m_block, params.b, params.h); @@ -98,11 +98,9 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - if (smem_size_dq_dk_dv >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + if (smem_size_dq_dk_dv >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); } kernel<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -112,146 +110,151 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) auto kernel_dq = &flash_bwd_convert_dq_kernel; if (Kernel_traits::kSmemdQSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize)); + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize)); } kernel_dq<<>>(params, !params.deterministic ? 
1 : gridDimx); C10_CUDA_KERNEL_LAUNCH_CHECK(); } -template +template void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { #ifndef FLASHATTENTION_DISABLE_BACKWARD - run_flash_bwd_seqk_parallel(params, stream); + run_flash_bwd_seqk_parallel(params, stream); #endif } -template +template void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; int device; cudaGetDevice(&device); int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device + ); if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } if (max_smem_per_block >= 104 * 1024) { // H100 and A100 // 104KB, 1 CTAs in A100, 2 CTAs in H100. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } else { // sm86 and sm89 // 96KB, 1 CTAs in sm86 and sm 89. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } } -template +template void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; int device; cudaGetDevice(&device); int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device + ); if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } if (max_smem_per_block >= 144 * 1024) { // H100 and A100 // In fwd, multi-CTA configurations are faster, but in bwd, their speeds are very close. // 56KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 4 CTAs in H100. - // run_flash_bwd, Is_causal>(params, stream); + // run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); // 72KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100. - // run_flash_bwd, Is_causal>(params, stream); + // run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); // 144KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } else { // sm86 and sm89 // 88KB, 1 CTAs in sm86 and sm 89. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times } -template +template void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 96; int device; cudaGetDevice(&device); int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device + ); if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } if (max_smem_per_block >= 116 * 1024) { // H100 and A100 // 116KB, 1 CTAs in A100, 1 CTAs in H100. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } else { // sm86 and sm89 // 76KB, 1 CTAs in sm86 and sm 89. 
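Every run_mha_bwd_hdim* dispatcher follows the same launch pattern: read the opt-in shared-memory limit, pick the tile configuration that fits, and raise the kernel's dynamic shared-memory cap when the request exceeds the default 48 KB. A self-contained CUDA sketch of that pattern with a stub kernel and placeholder sizes (compile with nvcc):

```cpp
// Sketch of the launch pattern repeated in the run_mha_bwd_hdim* dispatchers:
// query cudaDevAttrMaxSharedMemoryPerBlockOptin, choose a tile size that fits,
// and opt in to more than 48 KB of dynamic shared memory when needed. The
// kernel body and the two byte counts are placeholders.
#include <cuda_runtime.h>

__global__ void bwd_kernel_stub() {
    extern __shared__ char smem[];
    (void)smem;  // real kernels carve Q/K/V/mask/bias tiles out of this buffer
}

void launch_bwd_stub(cudaStream_t stream) {
    int device = 0, max_smem_per_block = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&max_smem_per_block,
                           cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

    // Hypothetical sizes: a large tile for H100/A100-class limits, a smaller
    // one for sm86/sm89, which top out below 100 KB of opt-in shared memory.
    const int smem_size = (max_smem_per_block >= 116 * 1024) ? 116 * 1024 : 76 * 1024;

    if (smem_size >= 48 * 1024) {
        cudaFuncSetAttribute(bwd_kernel_stub,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    }
    bwd_kernel_stub<<<dim3(1), dim3(128), smem_size, stream>>>();
}
```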
- run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } } -template +template void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; int device; cudaGetDevice(&device); int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device + ); if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } if (max_smem_per_block >= 144 * 1024) { // H100 and A100 // 144KB, 1 CTAs in A100, 1 CTAs in H100. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } else { // sm86 and sm89 // 80KB, 1 CTAs in sm86 and sm 89. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } } -template +template void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; int device; cudaGetDevice(&device); int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device + ); if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } if (max_smem_per_block >= 136 * 1024) { // H100 and A100 // 136KB, 1 CTAs in A100, 1 CTAs in H100. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } else { // sm86 and sm89 // 96KB, 1 CTAs in sm86 and sm 89. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } } -template +template void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 256; int device; cudaGetDevice(&device); int max_smem_per_block; cudaError status_ = cudaDeviceGetAttribute( - &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device + ); if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } if (max_smem_per_block >= 176 * 1024) { // H100 // 176KB, 1 CTAs in H100. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } else if (max_smem_per_block >= 144 * 1024) { // A100 // 144KB, 1 CTAs in A100. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } else { // sm86 and sm89 // 96KB, 1 CTAs in sm86 and sm 89. - run_flash_bwd, Is_causal>(params, stream); + run_flash_bwd, Is_causal, Has_mask, Has_bias>(params, stream); } } diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index 5b701bf..0e4d69b 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -50,7 +50,7 @@ __forceinline__ __device__ auto get_lse_tile( return local_tile(mLSE_slice, Shape>{}, make_coord(m_block)); } -template +template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -136,8 +136,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // might save us 1 register (we just need n_block instead of both n_block and n_block_max). 
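In the forward-kernel hunks that follow, sMask and sBias are only placed past sV when the corresponding feature is compiled in; otherwise they alias sV and add nothing to the shared-memory footprint. A pointer-arithmetic stand-in for those conditional make_tensor offsets (tile_bytes is hypothetical):

```cpp
// Stand-in for the conditional shared-memory placement of sMask/sBias in the
// forward kernel: when Has_mask/Has_bias are false the tensors alias sV and
// cost no extra bytes. Plain pointers instead of CuTe tensors.
#include <cstddef>

struct SmemSlices {
    char *s_mask;
    char *s_bias;
    std::size_t extra_bytes;  // smem added on top of the Q/K/V tiles
};

SmemSlices place_mask_bias(char *sV_begin, char *sV_end, std::size_t tile_bytes,
                           bool has_mask, bool has_bias) {
    char *s_mask = has_mask ? sV_end : sV_begin;                     // aliases sV when absent
    char *s_bias = has_bias ? (has_mask ? s_mask + tile_bytes : sV_end)
                            : sV_begin;                              // aliases sV when absent
    const std::size_t extra = (has_mask ? tile_bytes : 0)
                            + (has_bias ? tile_bytes : 0);
    return {s_mask, s_bias, extra};
}
```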
const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; - const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); - const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); // Global memory tensor configuration Tensor mQ = make_tensor( @@ -176,7 +174,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi make_stride(params.mask_head_stride, params.mask_row_stride, _1{}) ); Tensor gMask = local_tile( - mMask(h_idx_mask, _, _), + mMask(bidh / params.h_h_mask_ratio, _, _), Shape, Int>{}, make_coord(m_block, _) ); // (kBlockM, kBlockN, nblocksN) @@ -186,7 +184,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi make_stride(params.bias_head_stride, params.bias_row_stride, _1{}) ); Tensor gBias = local_tile( - mBias(h_idx_bias, _, _), + mBias(bidh / params.h_h_bias_ratio, _, _), Shape, Int>{}, make_coord(m_block, _) ); // (kBlockM, kBlockN, nblocksN) @@ -219,11 +217,11 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{} ); Tensor sMask = make_tensor( - sV.data() + size(sV), + Has_mask ? sV.data() + size(sV) : sV.data(), typename Kernel_traits::SmemLayoutAtomPS{} ); Tensor sBias = make_tensor( - sMask.data() + size(sMask), + Has_bias ? (Has_mask ? sMask.data() + size(sMask) : sV.data() + size(sV)) : sV.data(), typename Kernel_traits::SmemLayoutAtomPS{} ); @@ -302,8 +300,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Repeat the partitioning with identity layouts Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k) - Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) - Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) + Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) + Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); @@ -326,10 +324,14 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Set predicates for n bounds if (!Is_even_MN) { - #pragma unroll - for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } - #pragma unroll - for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } + if constexpr (Has_mask) { + #pragma unroll + for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } + } + if constexpr (Has_bias) { + #pragma unroll + for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < std::max(0, binfo.actual_seqlen_k - n_block * kBlockN); } + } } @@ -361,24 +363,27 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi __syncthreads(); } - // 
FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask(_, _, _, n_block), tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads - - FLASH_NAMESPACE::copy_mask_with_or_reduce( - gmem_tiled_copy_Mask, - tMaskgMask(_, _, _, n_block), tMasksMask, - any_active, - tMaskcMask, tMaskpMask, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + if constexpr (Has_mask) { + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask(_, _, _, n_block), tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + // // Do OR-reduce on the mask to see if any active threads + + + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask(_, _, _, n_block), tMasksMask, + any_active, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. + } // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. if (any_active) { @@ -388,15 +393,17 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, - tBiasgBias(_, _, _, n_block), tBiassBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); + if constexpr (Has_bias) { + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, + tBiasgBias(_, _, _, n_block), tBiassBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. 
+ __syncthreads(); + } cute::cp_async_fence(); } // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } @@ -470,43 +477,75 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - // Copy mask and bias from smem to registers - Tensor tSrMask = make_tensor(shape(acc_s)); - Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); - cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); - cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); - - // Scale attention scores and apply mask/bias - mask.template apply_mask( - acc_s, tSrMask, tSrBias, params.scale_softmax, - n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); + if constexpr (Has_mask && Has_bias) { + // Copy mask and bias from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + + // Scale attention scores and apply mask and add bias + mask.template apply_mask( + acc_s, tSrMask, tSrBias, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else if constexpr (Has_mask && !Has_bias) { + // Copy mask from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + + // Scale attention scores and apply mask + mask.template apply_mask( + acc_s, tSrMask, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else if constexpr (!Has_mask && Has_bias) { + // Copy bias from smem to registers + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + + // Scale attention scores and add bias + mask.template apply_mask( + acc_s, /*mask=*/nullptr, tSrBias, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else { + // Scale attention scores only + mask.template apply_mask( + acc_s, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); } if (n_block > n_block_min) { - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask(_, _, _, n_block - 1), tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration. - - FLASH_NAMESPACE::copy_mask_with_or_reduce( - gmem_tiled_copy_Mask, - tMaskgMask(_, _, _, n_block - 1), tMasksMask, - any_active_next, - tMaskcMask, tMaskpMask, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. 
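// Reduced sketch of the compile-time (Has_mask, Has_bias) dispatch wrapped around
// apply_mask above. The branch bodies are placeholders, not code from this kernel; the
// point is that each specialization drops the smem-to-register copies it does not need,
// so the mask-only, bias-only, and plain-softmax variants carry no dead registers.
template <bool Has_mask, bool Has_bias, typename Acc>
__device__ void scale_mask_bias_sketch(Acc &acc_s, float scale) {
    if constexpr (Has_mask && Has_bias) {
        // copy mask and bias tiles to registers, then scale + mask + add bias
    } else if constexpr (Has_mask && !Has_bias) {
        // copy mask tile only, then scale + mask
    } else if constexpr (!Has_mask && Has_bias) {
        // copy bias tile only, then scale + add bias
    } else {
        // scale only
    }
}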
+ if constexpr (Has_mask) { + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask(_, _, _, n_block - 1), tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + // // Do OR-reduce on the mask to see if any active threads for next iteration. + + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask(_, _, _, n_block - 1), tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. + } if (any_active_next) { FLASH_NAMESPACE::copy( @@ -514,15 +553,17 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV ); - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, - tBiasgBias(_, _, _, n_block - 1), tBiassBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); + if constexpr (Has_bias) { + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, + tBiasgBias(_, _, _, n_block - 1), tBiassBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); + } // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -599,43 +640,75 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - // Copy mask and bias from smem to registers - Tensor tSrMask = make_tensor(shape(acc_s)); - Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); - cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); - cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); - - // Scale attention scores and apply dynamic mask - mask.template apply_mask( - acc_s, tSrMask, tSrBias, params.scale_softmax, - n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); + if constexpr (Has_mask && Has_bias) { + // Copy mask and bias from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + + // Scale attention scores and apply mask and add bias + mask.template apply_mask( + acc_s, tSrMask, tSrBias, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else if constexpr (Has_mask && !Has_bias) { + // Copy mask from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + + // Scale attention scores and apply mask + mask.template apply_mask( + acc_s, tSrMask, /*bias=*/nullptr, params.scale_softmax, 
+ n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else if constexpr (!Has_mask && Has_bias) { + // Copy bias from smem to registers + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + + // Scale attention scores and apply bias + mask.template apply_mask( + acc_s, /*mask=*/nullptr, tSrBias, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else { + // Scale attention scores only + mask.template apply_mask( + acc_s, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); } if (n_block > n_block_min) { - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask(_, _, _, n_block - 1), tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration. - - FLASH_NAMESPACE::copy_mask_with_or_reduce( - gmem_tiled_copy_Mask, - tMaskgMask(_, _, _, n_block - 1), tMasksMask, - any_active_next, - tMaskcMask, tMaskpMask, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads + if constexpr (Has_mask) { + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask(_, _, _, n_block - 1), tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + // // Do OR-reduce on the mask to see if any active threads for next iteration. + + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask(_, _, _, n_block - 1), tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads + } if (any_active_next) { FLASH_NAMESPACE::copy( @@ -643,15 +716,17 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV ); - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, - tBiasgBias(_, _, _, n_block - 1), tBiassBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); + if constexpr (Has_bias) { + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, + tBiasgBias(_, _, _, n_block - 1), tBiassBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); + } // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
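// Sketch of the cp.async pipelining discipline the loop above depends on: the loads for
// the next K/bias tile are committed as one group inside the `any_active_next` branch,
// compute proceeds on the tile already resident in shared memory, and the wait pairs with
// that committed group. `issue_next_tile` and `compute_current` stand in for the
// FLASH_NAMESPACE::copy / copy_bias calls and the GEMM/softmax work, respectively.
#include <cute/arch/copy_sm80.hpp>

template <typename IssueFn, typename ComputeFn>
__device__ void pipeline_step_sketch(bool prefetch, IssueFn issue_next_tile, ComputeFn compute_current) {
    if (prefetch) {
        issue_next_tile();          // cp.async loads into the smem tiles for the next block
        cute::cp_async_fence();     // commit; keeping the fence inside this branch is what
                                    // ties the later wait to the right group and avoids races
    }
    compute_current();              // work on the tile already resident in smem
    cute::cp_async_wait<0>();       // all committed groups have completed
    __syncthreads();                // make the freshly loaded tiles visible to every warp
}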
cute::cp_async_fence(); @@ -764,7 +839,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { using Element = typename Kernel_traits::Element; @@ -871,18 +946,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); const index_t col_offset_mask = (block_table == nullptr) ? binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb_cache) - + h_idx_mask * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1) * kBlockN + + (bidh / params.h_h_mask_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1) * kBlockN : binfo.q_offset(/*batch_stride=*/index_t(0), params.mask_row_stride, bidb_cache) - + h_idx_mask * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset; - const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh); + + (bidh / params.h_h_mask_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset; const index_t col_offset_bias = (block_table == nullptr) ? binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb_cache) - + h_idx_bias * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1) * kBlockN + + (bidh / params.h_h_bias_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1) * kBlockN : binfo.q_offset(/*batch_stride=*/index_t(0), params.bias_row_stride, bidb_cache) - + h_idx_bias * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset; + + (bidh / params.h_h_bias_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset; // Global memory tensor configuration Tensor mQ = make_tensor( @@ -938,11 +1011,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{} ); Tensor sMask = make_tensor( - sV.data() + size(sV), + Has_mask ? sV.data() + size(sV) : sV.data(), typename Kernel_traits::SmemLayoutAtomPS{} ); Tensor sBias = make_tensor( - sMask.data() + size(sMask), + Has_bias ? (Has_mask ? 
sMask.data() + size(sMask) : sV.data() + size(sV)) : sV.data(), typename Kernel_traits::SmemLayoutAtomPS{} ); @@ -1004,8 +1077,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Repeat the partitioning with identity layouts Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k) - Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) - Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) + Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) + Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); @@ -1027,10 +1100,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Set predicates for n bounds if (!Is_even_MN) { - #pragma unroll - for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < binfo.actual_seqlen_k - n_block * kBlockN; } - #pragma unroll - for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < binfo.actual_seqlen_k - n_block * kBlockN; } + if constexpr (Has_mask) { + #pragma unroll + for (int n = 0; n < size(tMaskpMask); ++n) { tMaskpMask(n) = get<1>(tMaskcMask(0, 0, n)) < binfo.actual_seqlen_k - n_block * kBlockN; } + } + if constexpr (Has_bias) { + #pragma unroll + for (int n = 0; n < size(tBiaspBias); ++n) { tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < binfo.actual_seqlen_k - n_block * kBlockN; } + } } // Prologue @@ -1047,24 +1124,26 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons binfo.actual_seqlen_q - m_block * kBlockM ); - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads - - FLASH_NAMESPACE::copy_mask_with_or_reduce( - gmem_tiled_copy_Mask, - tMaskgMask, tMasksMask, - any_active, - tMaskcMask, tMaskpMask, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. + if constexpr (Has_mask) { + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + // // Do OR-reduce on the mask to see if any active threads + + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. + } // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. 
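// Sketch of the shared-memory packing shown above for sMask/sBias: a tile is appended
// after sV only when its feature is enabled; otherwise the (unused) Tensor view simply
// aliases sV's storage, so disabling mask and/or bias shrinks the smem footprint instead
// of leaving dead tiles. Pointer arithmetic on char* stands in for the cute tensor setup.
template <bool Has_mask, bool Has_bias>
__device__ void place_mask_bias_sketch(char *smem_v_begin, char *smem_v_end,
                                       int mask_bytes, int bias_bytes,
                                       char *&smem_mask, char *&smem_bias) {
    smem_mask = Has_mask ? smem_v_end : smem_v_begin;                       // alias when unused
    smem_bias = Has_bias ? (Has_mask ? smem_mask + mask_bytes : smem_v_end)
                         : smem_v_begin;                                    // alias when unused
}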
if (any_active) { @@ -1074,15 +1153,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, - tBiasgBias, tBiassBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); + if constexpr (Has_bias) { + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, + tBiasgBias, tBiassBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); + } cute::cp_async_fence(); } @@ -1162,19 +1243,49 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - // Copy mask and bias from smem to registers - Tensor tSrMask = make_tensor(shape(acc_s)); - Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); - cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); - cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); - - // Scale attention scores and apply dynamic mask - mask.template apply_mask( - acc_s, tSrMask, tSrBias, params.scale_softmax, - n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); + if constexpr (Has_mask && Has_bias) { + // Copy mask and bias from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + + // Scale attention scores and apply mask and bias + mask.template apply_mask( + acc_s, tSrMask, tSrBias, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else if constexpr (Has_mask && !Has_bias) { + // Copy mask from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + + // Scale attention scores and apply mask + mask.template apply_mask( + acc_s, tSrMask, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else if constexpr (!Has_mask && Has_bias) { + // Copy bias from smem to registers + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + + // Scale attention scores and add bias + mask.template apply_mask( + acc_s, /*mask=*/nullptr, tSrBias, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else { + // Scale attention scores only + mask.template apply_mask( + acc_s, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } FLASH_NAMESPACE::cp_async_wait<0>(); 
__syncthreads(); @@ -1186,35 +1297,46 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Advance gK, gMask, gBias if (block_table == nullptr) { tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockN)); - tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockN)); + if constexpr (Has_mask) { + tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockN)); + } + if constexpr (Has_bias) { + tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockN)); + } } else { const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; - tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur); - tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur); + if constexpr (Has_mask) { + tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur); + } + if constexpr (Has_bias) { + tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur); + } + } + + if constexpr (Has_mask) { + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + // // Do OR-reduce on the mask to see if any active threads for next iteration. + + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. } - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration. - - FLASH_NAMESPACE::copy_mask_with_or_reduce( - gmem_tiled_copy_Mask, - tMaskgMask, tMasksMask, - any_active_next, - tMaskcMask, tMaskpMask, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. 
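// Sketch of the paged-KV pointer advance guarded above: stepping from key block n_block to
// n_block - 1 may cross a page boundary, so the delta combines a page jump (batch stride)
// with the change of offset inside the page (row stride). For the mask/bias tensors the
// same formula is applied with a row stride of 1. Plain index arithmetic, no library calls.
__host__ __device__ inline long long paged_advance_delta(
        const int *block_table, int n_block, int kBlockN,
        int page_block_size, long long batch_stride, long long row_stride) {
    const int cur_tok   = n_block * kBlockN;
    const int next_tok  = (n_block - 1) * kBlockN;        // the K loop walks backwards
    const int cur_page  = cur_tok  / page_block_size;
    const int next_page = next_tok / page_block_size;
    const int cur_off   = cur_tok  - cur_page  * page_block_size;
    const int next_off  = next_tok - next_page * page_block_size;
    return (static_cast<long long>(block_table[next_page]) - block_table[cur_page]) * batch_stride
         + static_cast<long long>(next_off - cur_off) * row_stride;
}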
if (any_active_next) { FLASH_NAMESPACE::copy( @@ -1222,15 +1344,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKgK, tKsK, tKVcKV, tKVpKV ); - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, - tBiasgBias, tBiassBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); + if constexpr (Has_bias) { + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, + tBiasgBias, tBiassBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); + } // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -1311,19 +1435,49 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - // Copy mask and bias from smem to registers - Tensor tSrMask = make_tensor(shape(acc_s)); - Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); - cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); - cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); - - // Scale attention scores and apply dynamic mask - mask.template apply_mask( - acc_s, tSrMask, tSrBias, params.scale_softmax, - n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); + if constexpr (Has_mask && Has_bias) { + // Copy mask and bias from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + + // Scale attention scores and apply mask and bias + mask.template apply_mask( + acc_s, tSrMask, tSrBias, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else if constexpr (Has_mask && !Has_bias) { + // Copy mask from smem to registers + Tensor tSrMask = make_tensor(shape(acc_s)); + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + + // Scale attention scores and apply mask + mask.template apply_mask( + acc_s, tSrMask, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else if constexpr (!Has_mask && Has_bias) { + // Copy bias from smem to registers + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + + // Scale attention scores and add bias + mask.template apply_mask( + acc_s, /*mask=*/nullptr, tSrBias, params.scale_softmax, + n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } else { + // Scale attention scores only + mask.template apply_mask( + acc_s, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax, + n_block * kBlockN, 
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + } FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); @@ -1333,35 +1487,46 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Advance gK, gMask, gBias if (block_table == nullptr) { tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockN)); - tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockN)); + if constexpr (Has_mask) { + tMaskgMask.data() = tMaskgMask.data() + (-int(kBlockN)); + } + if constexpr (Has_bias) { + tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockN)); + } } else { const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; - tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur); - tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur); + if constexpr (Has_mask) { + tMaskgMask.data() = tMaskgMask.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.mask_batch_stride + (block_table_offset_next - block_table_offset_cur); + } + if constexpr (Has_bias) { + tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.bias_batch_stride + (block_table_offset_next - block_table_offset_cur); + } + } + + if constexpr (Has_mask) { + // FLASH_NAMESPACE::copy_MN( + // gmem_tiled_copy_Mask, + // tMaskgMask, tMasksMask, + // tMaskcMask, tMaskpMask, + // binfo.actual_seqlen_q - m_block * kBlockM + // ); + // cute::cp_async_fence(); + // FLASH_NAMESPACE::cp_async_wait<0>(); + // // Do OR-reduce on the mask to see if any active threads for next iteration. + + FLASH_NAMESPACE::copy_mask_with_or_reduce( + gmem_tiled_copy_Mask, + tMaskgMask, tMasksMask, + any_active_next, + tMaskcMask, tMaskpMask, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // We don't need to syncthreads here because copy_mask is already or_syncthreads. } - // FLASH_NAMESPACE::copy_MN( - // gmem_tiled_copy_Mask, - // tMaskgMask, tMasksMask, - // tMaskcMask, tMaskpMask, - // binfo.actual_seqlen_q - m_block * kBlockM - // ); - // cute::cp_async_fence(); - // FLASH_NAMESPACE::cp_async_wait<0>(); - // // Do OR-reduce on the mask to see if any active threads for next iteration. - - FLASH_NAMESPACE::copy_mask_with_or_reduce( - gmem_tiled_copy_Mask, - tMaskgMask, tMasksMask, - any_active_next, - tMaskcMask, tMaskpMask, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // We don't need to syncthreads here because copy_mask is already or_syncthreads. 
if (any_active_next) { FLASH_NAMESPACE::copy( @@ -1369,15 +1534,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKgK, tKsK, tKVcKV, tKVpKV ); - FLASH_NAMESPACE::copy_bias( - gmem_tiled_copy_Bias, - tBiasgBias, tBiassBias, - tBiascBias, tBiaspBias, - binfo.actual_seqlen_q - m_block * kBlockM - ); - // Because copy_bias currently uses scalar loads, we need to sync here. - // TODO: Remove sync after fixing to vectorized loads. - __syncthreads(); + if constexpr (Has_bias) { + FLASH_NAMESPACE::copy_bias( + gmem_tiled_copy_Bias, + tBiasgBias, tBiassBias, + tBiascBias, tBiaspBias, + binfo.actual_seqlen_q - m_block * kBlockM + ); + // Because copy_bias currently uses scalar loads, we need to sync here. + // TODO: Remove sync after fixing to vectorized loads. + __syncthreads(); + } // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -1496,7 +1663,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1504,12 +1671,12 @@ inline __device__ void compute_attn(const Params ¶ms) { // The block index for the head. const int bidh = blockIdx.z; - FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); + FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1518,7 +1685,7 @@ inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? 
gridDim.y : 1; - FLASH_NAMESPACE::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + FLASH_NAMESPACE::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_dmattn/src/flash_fwd_launch_template.h b/csrc/flash_dmattn/src/flash_fwd_launch_template.h index 2a7dd4a..9c3d94b 100644 --- a/csrc/flash_dmattn/src/flash_fwd_launch_template.h +++ b/csrc/flash_dmattn/src/flash_fwd_launch_template.h @@ -30,17 +30,17 @@ namespace FLASH_NAMESPACE { template \ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_attn(params); + FLASH_NAMESPACE::compute_attn(params); #else FLASH_UNSUPPORTED_ARCH #endif } -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Has_mask, bool Has_bias, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split) { #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_attn_splitkv(params); + FLASH_NAMESPACE::compute_attn_splitkv(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -51,7 +51,7 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int L FLASH_NAMESPACE::combine_attn_seqk_parallel(params); } -template +template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const size_t smem_size = Kernel_traits::kSmemSize; // printf("smem_size = %d (includes mask memory)\n", int(smem_size)); @@ -72,13 +72,9 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
// If return_softmax, set IsEvenMNConst to false to reduce number of templates // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - // printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst)); - // auto kernel = &flash_fwd_kernel; + auto kernel = &flash_fwd_kernel; if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } // int ctas_per_sm; // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( @@ -92,7 +88,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { }); } -template +template void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); @@ -106,14 +102,9 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { BOOL_SWITCH(params.num_splits > 1, Split, [&] { SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - // printf("run_flash_splitkv_fwd: Split = %d, Is_causal = %d, IsEvenMNConst = %d, IsEvenKConst = %d, Is_softcap = %d\n", int(Split), int(Is_causal), int(IsEvenMNConst), int(IsEvenKConst), int(Is_softcap)); + auto kernel = &flash_fwd_splitkv_kernel; if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } // int ctas_per_sm; // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( @@ -152,15 +143,16 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } } -template +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int kBlockM = 64; // Fixed for all head dimensions - constexpr static int kBlockN = 64; // Fixed for all head dimensions - // constexpr static int kBlockN = Headdim <= 32 ? 128 : (Headdim <= 128 ? 128 : 64); - run_flash_splitkv_fwd, Is_causal>(params, stream); + constexpr static int kBlockN = Has_mask || Has_bias + ? 64 + : Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_flash_splitkv_fwd, Is_causal, Has_mask, Has_bias>(params, stream); } -template +template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; int device; @@ -172,21 +164,27 @@ void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - if (max_smem_per_block >= 164 * 1024) { - // 28KB, 3 CTAs in sm86 and sm 89, 5 CTAs in A100, 8 CTAs in H100. - run_flash_fwd, Is_causal>(params, stream); - // 48KB, 2 CTAs in sm86 and sm 89, 3 CTAs in A100, 4 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); - // 88KB, 1 CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. 
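// Sketch of the dynamic shared-memory opt-in performed in run_flash_fwd /
// run_flash_splitkv_fwd above: a launch requesting more than the default 48 KB of dynamic
// smem fails unless the per-kernel attribute is raised first. The wrapper name is
// illustrative; the attribute and runtime call are the standard CUDA ones used in the code.
#include <cuda_runtime.h>

template <typename KernelPtr>
cudaError_t maybe_raise_dynamic_smem(KernelPtr kernel, size_t smem_size) {
    if (smem_size >= 48 * 1024) {
        return cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                    static_cast<int>(smem_size));
    }
    return cudaSuccess;
}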
- // run_flash_fwd, Is_causal>(params, stream); + if constexpr (Has_mask && Has_bias) { + if (max_smem_per_block >= 112 * 1024) { + // 28KB, 5 CTAs in A100, 8 CTAs in H100. + run_flash_fwd, Is_causal, true, true>(params, stream); + } else { + // 24KB, 4 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal, true, true>(params, stream); + } + } else if constexpr (Has_mask && !Has_bias) { + // 20KB, 5 CTAs in sm86 and sm 89, 8 CTAs in A100, 11 CTAs in H100. + run_flash_fwd, Is_causal, true, false>(params, stream); + } else if constexpr (!Has_mask && Has_bias) { + // 56KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 4 CTAs in H100. + run_flash_fwd, Is_causal, false, true>(params, stream); } else { - // 24KB, 4 CTAs in sm86 and sm 89. - run_flash_fwd, Is_causal>(params, stream); + // 24KB, 4 CTAs in sm86 and sm 89, 6 CTAs in A100, 9 CTAs in H100. + run_flash_fwd, Is_causal, false, false>(params, stream); } - } -template +template void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; int device; @@ -198,21 +196,27 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - if (max_smem_per_block >= 164 * 1024) { // H100 and A100 - // 40KB, 2 CTAs in sm86 and sm 89, 4 CTAs in A100, 5 CTAs in H100. - run_flash_fwd, Is_causal>(params, stream); - // 64KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); - // 112KB, N/A in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); - } else { // sm86 and sm89 - // 32KB, 3 CTAs in sm86 and sm 89. - run_flash_fwd, Is_causal>(params, stream); - } - + if constexpr (Has_mask && Has_bias) { + if (max_smem_per_block >= 160 * 1024) { + // 40KB, 4 CTAs in A100, 5 CTAs in H100. + run_flash_fwd, Is_causal, true, true>(params, stream); + } else { + // 32KB, 3 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal, true, true>(params, stream); + } + } else if constexpr (Has_mask && !Has_bias) { + // 32KB, 3 CTAs in sm86 and sm 89, 5 CTAs in A100, 7 CTAs in H100. + run_flash_fwd, Is_causal, true, false>(params, stream); + } else if constexpr (!Has_mask && Has_bias) { + // 48KB, 2 CTAs in sm86 and sm 89, 3 CTAs in A100, 4 CTAs in H100. + run_flash_fwd, Is_causal, false, true>(params, stream); + } else { + // 48KB, 2 CTAs in sm86 and sm 89, 3 CTAs in A100, 4 CTAs in H100. + run_flash_fwd, Is_causal, false, false>(params, stream); + } } -template +template void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 96; int device; @@ -224,20 +228,27 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - if (max_smem_per_block >= 164 * 1024) { // H100 and A100 - // 52KB, 1 CTAs in sm86 and sm 89, 3 CTAs in A100, 4 CTAs in H100. - run_flash_fwd, Is_causal>(params, stream); - // 80KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 2 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); - // 136KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); - } else { // sm86 and sm89 - // 40KB, 2 CTAs in sm86 and sm 89. - run_flash_fwd, Is_causal>(params, stream); + if constexpr (Has_mask && Has_bias) { + if (max_smem_per_block >= 156 * 1024) { + // 52KB, 3 CTAs in A100, 4 CTAs in H100. + run_flash_fwd, Is_causal, true, true>(params, stream); + } else { + // 40KB, 2 CTAs in sm86 and sm 89. 
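// Constexpr sketch of the key-block-width choice used by run_mha_fwd_splitkv_dispatch
// above: when a mask or bias tile must also fit in shared memory, the key block stays at
// 64, while the mask-free/bias-free path can afford wider key blocks for small head dims.
constexpr int splitkv_block_n_sketch(int headdim, bool has_mask, bool has_bias) {
    return (has_mask || has_bias) ? 64
         : headdim <= 64          ? 256
         : headdim <= 128         ? 128
                                  : 64;
}
static_assert(splitkv_block_n_sketch(64, false, false) == 256, "wider tile when no mask/bias");
static_assert(splitkv_block_n_sketch(128, true, false) == 64, "mask tile forces the narrow block");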
+ run_flash_fwd, Is_causal, true, true>(params, stream); + } + } else if constexpr (Has_mask && !Has_bias) { + // 44KB, 2 CTAs in sm86 and sm 89, 3 CTAs in A100, 5 CTAs in H100. + run_flash_fwd, Is_causal, true, false>(params, stream); + } else if constexpr (!Has_mask && Has_bias) { + // 44KB, 2 CTAs in sm86 and sm 89, 3 CTAs in A100, 5 CTAs in H100. + run_flash_fwd, Is_causal, false, true>(params, stream); + } else { + // 48KB, 2 CTAs in sm86 and sm 89, 3 CTAs in A100, 4 CTAs in H100. + run_flash_fwd, Is_causal, false, false>(params, stream); } } -template +template void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; int device; @@ -249,31 +260,66 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - if (max_smem_per_block >= 164 * 1024) { // H100 and A100 - // 64KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100. - run_flash_fwd, Is_causal>(params, stream); - // 96KB, 1 CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); - // 160KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); - } else { // sm86 and sm89 - // 48KB, 2 CTAs in sm86 and sm 89. - run_flash_fwd, Is_causal>(params, stream); + if constexpr (Has_mask && Has_bias) { + if (max_smem_per_block >= 128 * 1024) { + // 64KB, 2 CTAs in A100, 3 CTAs in H100. + run_flash_fwd, Is_causal, true, true>(params, stream); + } else { + // 48KB, 2 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal, true, true>(params, stream); + } + } else if constexpr (Has_mask && !Has_bias) { + if (max_smem_per_block >= 112 * 1024) { + // 56KB, 2 CTAs in A100, 4 CTAs in H100. + run_flash_fwd, Is_causal, true, false>(params, stream); + } else { + // 40KB, 2 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal, true, false>(params, stream); + } + } else if constexpr (!Has_mask && Has_bias) { + // 80KB, 2 CTAs in A100, 2 CTAs in H100. + run_flash_fwd, Is_causal, false, true>(params, stream); + } else { + if (max_smem_per_block >= 128 * 1024) { + // 64KB, 2 CTAs in A100, 3 CTAs in H100. + run_flash_fwd, Is_causal, false, false>(params, stream); + return; + } else { + // 48KB, 2 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal, false, false>(params, stream); + return; + } } } -template +template void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; - // 88KB, 1 CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. - run_flash_fwd, Is_causal>(params, stream); - // 128KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); - // 208KB, N/A CTAs in sm86 and sm 89, N/A CTAs in A100, 1 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device + ); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + if constexpr (Has_mask && Has_bias) { + // 88KB, 1 CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. + run_flash_fwd, Is_causal, true, true>(params, stream); + } else if constexpr (Has_mask && !Has_bias) { + // 80KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 2 CTAs in H100. 
+ run_flash_fwd, Is_causal, true, false>(params, stream); + } else if constexpr (!Has_mask && Has_bias) { + // 80KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 2 CTAs in H100. + run_flash_fwd, Is_causal, false, true>(params, stream); + } else { + // 72KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100. + run_flash_fwd, Is_causal, false, false>(params, stream); + } } -template +template void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 256; int device; @@ -285,16 +331,38 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - if (max_smem_per_block >= 112 * 1024) { // H100 and A100 - // 112KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100. - run_flash_fwd, Is_causal>(params, stream); - // 192KB, N/A CTAs in sm86 and sm 89, N/A CTAs in A100, 1 CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); - // 256KB, N/A CTAs in sm86 and sm 89, N/A CTAs in A100, N/A CTAs in H100. - // run_flash_fwd, Is_causal>(params, stream); - } else { // sm86 and sm89 - // 80KB, 1 CTAs in sm86 and sm 89. - run_flash_fwd, Is_causal>(params, stream); + if constexpr (Has_mask && Has_bias) { + if (max_smem_per_block >= 112 * 1024) { + // 112KB, 1 CTAs in A100, 2 CTAs in H100. + run_flash_fwd, Is_causal, true, true>(params, stream); + } else { + // 80KB, 1 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal, true, true>(params, stream); + } + } else if constexpr (Has_mask && !Has_bias) { + if (max_smem_per_block >= 104 * 1024) { + // 104KB, 1 CTAs in A100, 2 CTAs in H100. + run_flash_fwd, Is_causal, true, false>(params, stream); + } else { + // 72KB, 1 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal, true, false>(params, stream); + } + } else if constexpr (!Has_mask && Has_bias) { + if (max_smem_per_block >= 104 * 1024) { + // 104KB, 1 CTAs in A100, 2 CTAs in H100. + run_flash_fwd, Is_causal, false, true>(params, stream); + } else { + // 72KB, 1 CTAs in sm86 and sm 89. + run_flash_fwd, Is_causal, false, true>(params, stream); + } + } else { + if (max_smem_per_block >= 128 * 1024) { + // 128KB, 1 CTAs in A100, 1 CTAs in H100. + run_flash_fwd, Is_causal, false, false>(params, stream); + } else { + // 96KB, 1 CTAs in sm86 and sm 89. 
+ run_flash_fwd, Is_causal, false, false>(params, stream); + } } } diff --git a/csrc/flash_dmattn/src/generate_kernels.py b/csrc/flash_dmattn/src/generate_kernels.py index 00cfc3b..54d5a72 100644 --- a/csrc/flash_dmattn/src/generate_kernels.py +++ b/csrc/flash_dmattn/src/generate_kernels.py @@ -12,6 +12,8 @@ SM = [80] # Sm80 kernels support up to HEAD_DIMENSIONS = [32, 64, 96, 128, 192, 256] IS_CAUSAL = ["false", "true"] +HAS_MASK = ["false", "true"] +HAS_BIAS = ["false", "true"] NAMESPACE_INCLUDE = '#include "namespace_config.h"\n' def get_fwd_template() -> str: @@ -21,8 +23,8 @@ def get_fwd_template() -> str: namespace FLASH_NAMESPACE {{ template<> -void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{ - run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); +void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}, {HAS_MASK}, {HAS_BIAS}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{ + run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}, {HAS_MASK}, {HAS_BIAS}>(params, stream); }} }} // namespace FLASH_NAMESPACE @@ -34,7 +36,7 @@ def get_fwd_split_template() -> str: namespace FLASH_NAMESPACE {{ -template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}, {HAS_MASK}, {HAS_BIAS}>(Flash_fwd_params ¶ms, cudaStream_t stream); }} // namespace FLASH_NAMESPACE """.strip() @@ -46,8 +48,8 @@ def get_bwd_template() -> str: namespace FLASH_NAMESPACE {{ template<> -void run_mha_bwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{ - run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); +void run_mha_bwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}, {HAS_MASK}, {HAS_BIAS}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{ + run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}, {HAS_MASK}, {HAS_BIAS}>(params, stream); }} }} // namespace FLASH_NAMESPACE @@ -59,6 +61,8 @@ class Kernel: dtype: str head_dim: int is_causal: str + has_mask: str + has_bias: str direction: str @property @@ -72,17 +76,19 @@ def template(self) -> str: return template_func().format( DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, - IS_CAUSAL=self.is_causal + IS_CAUSAL=self.is_causal, + HAS_MASK=self.has_mask, + HAS_BIAS=self.has_bias ) @property def filename(self) -> str: - return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu" + return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}{'has_mask_' if self.has_mask == 'true' else ''}{'has_bias_' if self.has_bias == 'true' else ''}sm{self.sm}.cu" def get_all_kernels() -> Generator[Kernel, None, None]: for direction in ["fwd", "fwd_split", "bwd"]: - for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM): - yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction) + for dtype, head_dim, is_causal, has_mask, has_bias, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, HAS_MASK, HAS_BIAS, SM): + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, has_mask=has_mask, has_bias=has_bias, direction=direction) def write_kernel(kernel: Kernel, autogen_dir: Path) -> None: prelude = """ @@ -99,6 +105,8 @@ def main(output_dir: Optional[str]) -> None: else: output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + 
for kernel in get_all_kernels(): write_kernel(kernel, output_dir) @@ -110,7 +118,7 @@ def main(output_dir: Optional[str]) -> None: parser.add_argument( "-o", "--output_dir", - default="instantiations", + default="src/instantiations", required=False, help="Where to generate the kernels " " will default to the current directory ", diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..5e0883e --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..31805fb --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..3851a70 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu index 581443c..5f928d3 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..1e148e9 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..5ac4e19 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..a753d58 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu index eae9439..d1f0c9f 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..346c75e --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..97c2016 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..b97517c --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu index ad5d580..282c744 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..2ca811c --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..5c8b491 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..672864c --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu index 8d93480..1a2a991 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim128(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..cd8fdf3 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..d9625d7 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..748ce7a --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu index cee3073..1e3ed61 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..d159662 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..91fddab --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..5dc1bc9 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu index bb5063b..0868295 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..38d1dc9 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..2de0e95 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..4358751 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu index 290187a..de93c33 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..71d6a6e --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..d39794d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..b19f9e7 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu index 7bef7a3..3825325 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim192(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..be3e201 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..e35bf30 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..852286c --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu index a28c41c..0a54ab1 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..6d27638 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..7896955 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..7cb82c6 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu index 114faf8..3c41f3c 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..05644d3 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..07713a8 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..c64b79e --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu index a89a253..1120012 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..4a4d8aa --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..84f8f1f --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..b6e0e97 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu index db59281..84d4bfb 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim256(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..ec7f221 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..6be8d6b --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..f5b0daa --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu index a0cde4f..73c7a18 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim32(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..0a757ff --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..14717b5 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..e01a354 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu index 29fb6f8..b82b8b1 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim32(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..6b7369b --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..ba1236b --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..db9b308 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu index 885067d..c49252e 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim32(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..8a4022f --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..6f5537a --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..be0e3c4 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu index 2c35ae2..f2c9f4f 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim32(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..36a39e9 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..4e735d0 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..c332c96 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu index 5dc4084..502c429 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..367a798 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..000cf37 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..4f1cd88 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu index f410065..d959575 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..4c43cd4 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..662df9b --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..d5be139 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu index d58a0f6..a5a68b7 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..45eb04e --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..0e47df2 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..b082ab7 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu index 47c01c2..2a3f1ee 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim64(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..717c2d1 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..0d64a82 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..3296e56 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu index 4c27c0b..1276b88 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..2d30d81 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..f90667d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..15bab62 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu index da96a97..75ba3ee 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..ed84e18 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..ce4da33 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..626aac8 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu index 8e79520..7a2a20e 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..80fec3d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..bc85d1d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..24ef408 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu index e1ac513..1a73dba 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { - run_mha_bwd_hdim96(params, stream); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..137a44d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..3126e9f --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..5ceab71 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu index 2bacfd3..e6a99f1 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim128(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..6f00490 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..854c601 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..c137e94 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu index bc6f103..224dd9e 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim128(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..b7afdcb --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..0e8dd9a --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..c3d2cda --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu index 48c2d89..14b0091 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim128(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..ef80d33 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..c749384 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..82ec511 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu index c67fdce..1c75613 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim128(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..e78791d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..05ac29c --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..a8c060a --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu index e957ce0..457c6ae 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim192(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..3467402 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..63172d0 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..57d7b81 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu index f53f7f5..5e67bc9 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim192(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..35bdefd --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..4b3eaa7 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..42f64ee --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu index 805e548..814944a 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim192(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..55c5e60 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..eea672d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..e4e0688 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu index 87a6565..3185d87 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim192(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..78d4313 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..c7e1c19 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..d67103b --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu index ded9ab1..06d28b4 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim256(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..8df713d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..abae76d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..a66fb04 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu index c3d9e04..c6bc287 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim256(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..1b00dd1 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..66ed354 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..6ae0a42 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu index 8780ed4..1b6ad94 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim256(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..6d17fa8 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..fcfe162 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..42bebc8 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu index 293e001..48d732e 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim256(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..7ca3f9a --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..ab5a247 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..5b5df80 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu index f50910b..d94db09 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim32(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..bb16aed --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..f31d044 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..69fb705 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu index 6127386..7c4d01c 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim32(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..664d742 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..6084e18 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..9228bad --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu index 845ab35..5bb2263 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim32(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..9bd62f6 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..0c8c28e --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..b91ecd2 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu index 94a1301..0fe7203 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim32(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..71277e6 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..a54a945 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..6a40c9a --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu index f3d4e94..06aa7f2 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim64(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..b618f90 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..9b2dec0 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..3b65d53 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu index a11cc79..b48cca5 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim64(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..708ddce --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..c935579 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..71e686c --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu index 1681867..4b69c71 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim64(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..ca63d55 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..508e38f --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..61b9621 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu index 7b7d184..2ce28f2 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim64(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..c2ad7c2 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..d2f46b9 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..b605b86 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu index ad62e87..2425d20 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim96(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..6b34039 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..bd6ee43 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..663ba2b --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu index 92c4344..fca0d6b 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim96(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..a8dc858 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..ede9c0d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..4eb2c7f --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu index 36db0da..04ceeff 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim96(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..4b98bf4 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..1417fa2 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..cd07586 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_has_mask_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu index 50040b5..be4617c 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu @@ -1,14 +1,13 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim96(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..5c31681 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..0344fe1 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..f955761 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu index b4118c7..e53a721 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..9bc8f83 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..7cd1336 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..9a24313 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu index cdcfbe9..20e665b 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..198fc1d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..08a187a --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..f8bca05 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu index 71e415d..69f7562 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..fd305fe --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..9a144f5 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..3cb985a --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu index df4febe..2787367 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..c2ec067 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..82732fc --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..4155c6f --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu index 83c8f8a..b36962b 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..307a1a1 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..ae704e7 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..e1fe8a5 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu index d3bbf47..71119c0 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..2adbadc --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..154a54b --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..ccf5bc3 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu index 5652982..18316d4 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..fd441c9 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..a56610d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..54430a4 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu index edac2b6..aca36b9 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..dd3c5ed --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..eef7f06 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..6b15c57 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu index 28ab7ad..6303b8c 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..28305b9 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..4e98cb7 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..e60cd70 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu index 751035e..5bbac09 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..63a0e75 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..ba80503 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..0c5168d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu index 502b5cc..aa42659 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..5a70fbc --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..6d539d1 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..008478a --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu index 3153e17..625ba34 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..2cfd687 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..3eb807c --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..8b36fb4 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu index 9910f63..8f461b2 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..e144b49 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..3dd585f --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..6a2eb0b --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu index d498fea..8f6e82d 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..07c8228 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..2e0b252 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..2772cca --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu index a5a713a..f48c612 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..1adfa91 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..ee443da --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..d4d4b85 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu index 4cfc36a..a9ac55e 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..f25e3d3 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..574bc25 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..c0e431e --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu index e89cb76..cca29c5 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..9c676d0 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..2797f2d --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..83d5a04 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu index 8d72e93..48ff397 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..6353a34 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..8871c82 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..3d54aad --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu index 76ba0c8..d030ef0 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..72d4513 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..2b7be7e --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..35c2845 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu index ab07719..4211c14 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..0989b27 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..b09b986 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..6bf4d32 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu index 8d44e28..7469a22 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_bias_sm80.cu new file mode 100644 index 0000000..33381b0 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..7f5d4c0 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_sm80.cu new file mode 100644 index 0000000..4d52f62 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu index 252b468..787308d 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_bias_sm80.cu new file mode 100644 index 0000000..2e49991 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..e6c856c --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_sm80.cu new file mode 100644 index 0000000..b81d1a9 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu index 3eb97b7..1d3a8af 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_bias_sm80.cu new file mode 100644 index 0000000..68d3a41 --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_has_bias_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_has_bias_sm80.cu new file mode 100644 index 0000000..f0eeead --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_has_bias_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_sm80.cu new file mode 100644 index 0000000..e3530db --- /dev/null +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_has_mask_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu index 367e12b..fb7d89c 100644 --- a/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu +++ b/csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu @@ -1,11 +1,10 @@ // Copyright (c) 2025, Jingze Shi and Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. -// This file is auto-generated. See "generate_kernels.py" -#include "namespace_config.h" +// This file is auto-generated. 
See "generate_kernels.py"#include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/csrc/flash_dmattn/src/kernel_traits.h b/csrc/flash_dmattn/src/kernel_traits.h index 8f648e1..f6a4e42 100644 --- a/csrc/flash_dmattn/src/kernel_traits.h +++ b/csrc/flash_dmattn/src/kernel_traits.h @@ -48,7 +48,7 @@ struct Flash_kernel_traits { // If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true template< int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, - bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t, + bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, bool Has_mask_=false, bool Has_bias_=false, typename elem_type=cutlass::half_t, typename Base=Flash_kernel_traits > struct Flash_fwd_kernel_traits : public Base { @@ -61,6 +61,8 @@ struct Flash_fwd_kernel_traits : public Base { static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + static constexpr bool Has_mask = Has_mask_; + static constexpr bool Has_bias = Has_bias_; // The number of threads. static constexpr int kNWarps = kNWarps_; @@ -153,7 +155,7 @@ struct Flash_fwd_kernel_traits : public Base { static constexpr int kSmemBiasSize = size(SmemLayoutPS{}) * sizeof(Element); // Shared memory size with QKV matrices and mask/bias matrices - static constexpr int kSmemSize = (Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize) + kSmemMaskSize + kSmemBiasSize; + static constexpr int kSmemSize = (Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize) + (Has_mask ? kSmemMaskSize : 0) + (Has_bias ? 
diff --git a/csrc/flash_dmattn/src/kernel_traits.h b/csrc/flash_dmattn/src/kernel_traits.h
index 8f648e1..f6a4e42 100644
--- a/csrc/flash_dmattn/src/kernel_traits.h
+++ b/csrc/flash_dmattn/src/kernel_traits.h
@@ -48,7 +48,7 @@ struct Flash_kernel_traits {
 // If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
 template<
     int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
-    bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
+    bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, bool Has_mask_=false, bool Has_bias_=false, typename elem_type=cutlass::half_t,
     typename Base=Flash_kernel_traits
 >
 struct Flash_fwd_kernel_traits : public Base {
@@ -61,6 +61,8 @@ struct Flash_fwd_kernel_traits : public Base {
     static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
     static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
+    static constexpr bool Has_mask = Has_mask_;
+    static constexpr bool Has_bias = Has_bias_;
 
     // The number of threads.
     static constexpr int kNWarps = kNWarps_;
@@ -153,7 +155,7 @@ struct Flash_fwd_kernel_traits : public Base {
     static constexpr int kSmemBiasSize = size(SmemLayoutPS{}) * sizeof(Element);
 
     // Shared memory size with QKV matrices and mask/bias matrices
-    static constexpr int kSmemSize = (Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize) + kSmemMaskSize + kSmemBiasSize;
+    static constexpr int kSmemSize = (Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize) + (Has_mask ? kSmemMaskSize : 0) + (Has_bias ? kSmemBiasSize : 0);
 
     static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
     static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
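To see why the mask and bias tiles are now carved out of shared memory only when the corresponding flag is set, here is a back-of-the-envelope calculator that mirrors the kSmemSize expression above. The tile sizes and the 2-byte element width are illustrative assumptions; the real kBlockM/kBlockN come from the launch templates.

```python
# Mirrors the kSmemSize formula above with made-up tile sizes; not the
# authoritative computation used by the kernel traits.
def fwd_smem_bytes(k_block_m, k_block_n, head_dim,
                   share_q_k_smem=False, has_mask=False, has_bias=False,
                   elem_bytes=2):
    smem_q = k_block_m * head_dim * elem_bytes
    smem_kv = 2 * k_block_n * head_dim * elem_bytes
    smem_mask = k_block_m * k_block_n * elem_bytes  # (kBlockM, kBlockN) tile
    smem_bias = k_block_m * k_block_n * elem_bytes
    total = max(smem_q, smem_kv) if share_q_k_smem else smem_q + smem_kv
    return total + (smem_mask if has_mask else 0) + (smem_bias if has_bias else 0)

# With hypothetical 128x128 tiles at head_dim 96, mask + bias alone would add
# 2 * 128 * 128 * 2 B = 64 KiB, so dropping them when Has_mask/Has_bias are
# false frees a large share of the per-block shared-memory budget.
print(fwd_smem_bytes(128, 128, 96, has_mask=True, has_bias=True))  # 139264
```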
diff --git a/csrc/flash_dmattn/src/mask.h b/csrc/flash_dmattn/src/mask.h
index ca4f43b..9f81cae 100644
--- a/csrc/flash_dmattn/src/mask.h
+++ b/csrc/flash_dmattn/src/mask.h
@@ -11,11 +11,11 @@ namespace FLASH_NAMESPACE {
 
 using namespace cute;
 
-template
+template
 __forceinline__ __device__ void apply_mask(
     TensorType &tensor,
-    MaskType &mask,
-    BiasType &bias,
+    const MaskType &mask,
+    const BiasType &bias,
     const float scale_softmax,
     const int col_idx_offset_,
     const int max_seqlen_k,
@@ -25,29 +25,107 @@
 ) {
     // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
     static_assert(TensorType::rank == 2, "Only support 2D Tensor");
-    static_assert(MaskType::rank == 2, "Only support 2D Mask");
-    static_assert(BiasType::rank == 2, "Only support 2D Bias");
+    if constexpr (Has_mask)
+        static_assert(MaskType::rank == 2, "Only support 2D Mask");
+    if constexpr (Has_bias)
+        static_assert(BiasType::rank == 2, "Only support 2D Bias");
+
     const int lane_id = threadIdx.x % 32;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
-    #pragma unroll
-    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+
+    if constexpr (Has_mask && Has_bias) {
+        #pragma unroll
+        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+            #pragma unroll
+            for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                const int row_idx = row_idx_base + i * 8;
+                const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
+                #pragma unroll
+                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                    const int col_idx_base = col_idx_offset + nj * 8;
+                    #pragma unroll
+                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                        const int col_idx = col_idx_base + j;
+                        // Without the "make_coord" we get wrong results
+                        auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+                        // Apply scaling and bias or masking
+                        tensor(coord) = (col_idx >= col_idx_limit) || (!mask(coord))
+                            ? -INFINITY
+                            : tensor(coord) * scale_softmax + bias(coord);
+                    }
+                }
+            }
+        }
+    } else if constexpr (Has_mask && !Has_bias) {
+        #pragma unroll
+        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+            #pragma unroll
+            for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                const int row_idx = row_idx_base + i * 8;
+                const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
+                #pragma unroll
+                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                    const int col_idx_base = col_idx_offset + nj * 8;
+                    #pragma unroll
+                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                        const int col_idx = col_idx_base + j;
+                        // Without the "make_coord" we get wrong results
+                        auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+                        // Apply scaling or masking
+                        tensor(coord) = (col_idx >= col_idx_limit) || (!mask(coord))
+                            ? -INFINITY
+                            : tensor(coord) * scale_softmax;
+                    }
+                }
+            }
+        }
+    } else if constexpr (!Has_mask && Has_bias) {
+        #pragma unroll
+        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+            #pragma unroll
+            for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                const int row_idx = row_idx_base + i * 8;
+                const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
+                #pragma unroll
+                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                    const int col_idx_base = col_idx_offset + nj * 8;
+                    #pragma unroll
+                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                        const int col_idx = col_idx_base + j;
+                        // Without the "make_coord" we get wrong results
+                        auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+                        // Apply scaling and bias
+                        tensor(coord) = (col_idx >= col_idx_limit)
+                            ? -INFINITY
+                            : tensor(coord) * scale_softmax + bias(coord);
+                    }
+                }
+            }
+        }
+    } else {    // !Has_mask && !Has_bias
         #pragma unroll
-        for (int i = 0; i < size<0, 0>(tensor); ++i) {
-            const int row_idx = row_idx_base + i * 8;
-            const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
+        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
             #pragma unroll
-            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                const int col_idx_base = col_idx_offset + nj * 8;
+            for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                const int row_idx = row_idx_base + i * 8;
+                const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
                 #pragma unroll
-                for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                    const int col_idx = col_idx_base + j;
-                    // Without the "make_coord" we get wrong results
-                    auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-                    // Apply scaling and bias or masking
-                    tensor(coord) = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f)
-                        ? -INFINITY
-                        : tensor(coord) * scale_softmax + bias(coord);
+                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                    const int col_idx_base = col_idx_offset + nj * 8;
+                    #pragma unroll
+                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                        const int col_idx = col_idx_base + j;
+                        // Without the "make_coord" we get wrong results
+                        auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+                        // Apply scaling
+                        tensor(coord) = (col_idx >= col_idx_limit)
+                            ? -INFINITY
+                            : tensor(coord) * scale_softmax;
+                    }
                 }
             }
         }
@@ -65,53 +143,126 @@ struct Mask {
         , max_seqlen_q(max_seqlen_q) {
     };
 
-    template
+    template
     __forceinline__ __device__ void apply_mask(
         TensorType &tensor_,        // acc_s (attention scores, MMA=4, MMA_M, MMA_N)
-        MaskType &tSrMask,          // Attention Mask (MMA=4, MMA_M, MMA_N)
-        BiasType &tSrBias,          // Attention Bias (MMA=4, MMA_M, MMA_N)
+        const MaskType &mask_,      // Attention Mask (MMA=4, MMA_M, MMA_N)
+        const BiasType &bias_,      // Attention Bias (MMA=4, MMA_M, MMA_N)
        const float scale_softmax,  // Scale for softmax
         const int col_idx_offset_,  // Column index offset
         const int row_idx_offset,   // Row index offset
         const int warp_row_stride   // Warp row stride
     ) {
         static_assert(TensorType::rank == 3, "tensor_ must be 3D Tensor");
-        static_assert(MaskType::rank == 3, "Mask must be 3D Tensor");
-        static_assert(BiasType::rank == 3, "Bias must be 3D Tensor");
+        if constexpr (Has_mask)
+            static_assert(MaskType::rank == 3, "mask_ must be 3D Tensor");
+        if constexpr (Has_bias)
+            static_assert(BiasType::rank == 3, "Bias must be 3D Tensor");
         static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
-        // const bool Need_masking = Causal_mask || !Is_even_MN || (keep_window_size < max_seqlen_k);
-
         // Reshape tensors from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout()));
-        Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout()));
-        Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout()));
 
         const int lane_id = threadIdx.x % 32;
         const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
-        #pragma unroll
-        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+
+        if constexpr (Has_mask && Has_bias) {
+            Tensor mask = make_tensor(mask_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(mask_.layout()));
+            Tensor bias = make_tensor(bias_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(bias_.layout()));
             #pragma unroll
-            for (int i = 0; i < size<0, 0>(tensor); ++i) {
-                const int row_idx = row_idx_base + i * 8;
-                const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
+            for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+                const int row_idx_base = row_idx_offset + mi * warp_row_stride;
                 #pragma unroll
-                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                    const int col_idx_base = col_idx_offset + nj * 8;
+                for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                    const int row_idx = row_idx_base + i * 8;
+                    const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
                     #pragma unroll
-                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                        const int col_idx = col_idx_base + j;
-                        auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
-                        // Apply scaling and bias or masking
-                        tensor(coord) = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f)
-                            ? -INFINITY
-                            : tensor(coord) * scale_softmax + bias(coord);
+                    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                        const int col_idx_base = col_idx_offset + nj * 8;
+                        #pragma unroll
+                        for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                            const int col_idx = col_idx_base + j;
+                            auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+                            // Apply scaling and bias or masking
+                            tensor(coord) = (col_idx >= col_idx_limit) || (!mask(coord))
+                                ? -INFINITY
+                                : tensor(coord) * scale_softmax + bias(coord);
+                        }
                     }
                 }
             }
-        }
-
+        } else if constexpr (Has_mask && !Has_bias) {
+            Tensor mask = make_tensor(mask_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(mask_.layout()));
+            #pragma unroll
+            for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+                const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+                #pragma unroll
+                for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                    const int row_idx = row_idx_base + i * 8;
+                    const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
+                    #pragma unroll
+                    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                        const int col_idx_base = col_idx_offset + nj * 8;
+                        #pragma unroll
+                        for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                            const int col_idx = col_idx_base + j;
+                            auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+                            // Apply scaling or masking
+                            tensor(coord) = (col_idx >= col_idx_limit) || (!mask(coord))
+                                ? -INFINITY
+                                : tensor(coord) * scale_softmax;
+                        }
+                    }
+                }
+            }
+        } else if constexpr (!Has_mask && Has_bias) {
+            Tensor bias = make_tensor(bias_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(bias_.layout()));
+            #pragma unroll
+            for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+                const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+                #pragma unroll
+                for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                    const int row_idx = row_idx_base + i * 8;
+                    const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
+                    #pragma unroll
+                    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                        const int col_idx_base = col_idx_offset + nj * 8;
+                        #pragma unroll
+                        for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                            const int col_idx = col_idx_base + j;
+                            auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+                            // Apply scaling and bias
+                            tensor(coord) = (col_idx >= col_idx_limit)
+                                ? -INFINITY
+                                : tensor(coord) * scale_softmax + bias(coord);
+                        }
+                    }
+                }
+            }
+        } else {    // !Has_mask && !Has_bias
+            #pragma unroll
+            for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+                const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+                #pragma unroll
+                for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                    const int row_idx = row_idx_base + i * 8;
+                    const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
+                    #pragma unroll
+                    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                        const int col_idx_base = col_idx_offset + nj * 8;
+                        #pragma unroll
+                        for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                            const int col_idx = col_idx_base + j;
+                            auto coord = make_coord(make_coord(i, mi), make_coord(j, nj));
+                            // Apply scaling
+                            tensor(coord) = (col_idx >= col_idx_limit)
+                                ? -INFINITY
+                                : tensor(coord) * scale_softmax;
+                        }
                     }
                 }
             }
+        }
     }
 };
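All four constexpr branches in apply_mask compute the same per-element update; only the mask test and the bias term are statically compiled out. A minimal PyTorch reference of that update for one tile, with illustrative names that are not part of the CUDA code:

```python
import torch

def apply_mask_reference(scores, mask, bias, scale_softmax, col_idx, col_idx_limit):
    # scores/mask/bias are (rows, cols) tiles; col_idx and col_idx_limit broadcast
    # against them, with col_idx_limit encoding the causal or sequence-length bound per row.
    dead = col_idx >= col_idx_limit          # out-of-range columns are always dropped
    if mask is not None:                     # Has_mask: false/zero entries are dropped too
        dead = dead | ~mask.bool()
    out = scores * scale_softmax
    if bias is not None:                     # Has_bias: bias is added after scaling
        out = out + bias
    return out.masked_fill(dead, float("-inf"))
```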
diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py
index 9fa64d2..5b5a889 100644
--- a/flash_dmattn/flash_dmattn_interface.py
+++ b/flash_dmattn/flash_dmattn_interface.py
@@ -287,7 +287,8 @@ def backward(
         *args: Any,
     ):
         q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors
-        dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias)
+        dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
+        dbias = torch.zeros_like(bias) if bias is not None else None
 
         head_size_og = dout.size(3)
         dout_padded = dout
diff --git a/flash_dmattn/integrations/flash_dynamic_mask_attention.py b/flash_dmattn/integrations/flash_dynamic_mask_attention.py
index 898950b..16631aa 100644
--- a/flash_dmattn/integrations/flash_dynamic_mask_attention.py
+++ b/flash_dmattn/integrations/flash_dynamic_mask_attention.py
@@ -29,8 +29,8 @@ def flash_dynamic_mask_attention_forward(
         query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim).
         key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
         value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
-        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, num_heads, query_len, key_len), also supports (batch_size, num_kv_heads, query_len, key_len) or (batch_size, 1, query_len, key_len) for MQA/GQA.
-        attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, num_heads, query_len, key_len), also supports (batch_size, num_kv_heads, query_len, key_len) or (batch_size, 1, query_len, key_len) for MQA/GQA.
+        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, {num_heads|num_kv_heads|1}, query_len, key_len).
+        attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, {num_heads|num_kv_heads|1}, query_len, key_len), if attention_mask is None, also supports (batch_size, {num_heads|num_kv_heads|1}, key_len).
         scaling (Optional[float]): The scaling factor for the attention scores.
         softcap (Optional[float]): The softcap value for the attention scores.
         **kwargs: Additional keyword arguments.
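With the bias optional end to end, callers no longer need to materialize a zero bias just to satisfy the kernel: passing attention_bias=None skips the bias shared-memory tile in forward and the dbias allocation in backward. A hedged usage sketch follows; the leading module argument and the return convention are assumed to match the transformers-style attention interface these integrations plug into, and are not confirmed by the hunks above.

```python
import torch
from flash_dmattn.integrations.flash_dynamic_mask_attention import (
    flash_dynamic_mask_attention_forward,
)

batch, num_heads, num_kv_heads, q_len, k_len, head_dim = 2, 8, 2, 512, 512, 96
dev, dtype = "cuda", torch.bfloat16

query = torch.randn(batch, num_heads, q_len, head_dim, device=dev, dtype=dtype)
key = torch.randn(batch, num_kv_heads, k_len, head_dim, device=dev, dtype=dtype)
value = torch.randn(batch, num_kv_heads, k_len, head_dim, device=dev, dtype=dtype)
# Boolean keep-mask, broadcastable over query heads as described in the docstring.
keep_mask = torch.ones(batch, num_kv_heads, q_len, k_len, device=dev, dtype=torch.bool)

attn_module = None  # hypothetical stand-in for whatever module object the caller has
outputs = flash_dynamic_mask_attention_forward(
    attn_module, query, key, value,
    attention_mask=keep_mask,
    attention_bias=None,   # no bias: the has_bias kernel path is skipped entirely
    scaling=head_dim ** -0.5,
)
```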
diff --git a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
index e3a8a3b..852a80d 100644
--- a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
+++ b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
@@ -80,9 +80,6 @@ def _flash_dynamic_mask_attention_forward(
         flash_kwargs["deterministic"] = deterministic
     if softcap is not None:
         flash_kwargs["softcap"] = softcap
-
-    if attention_bias is None:
-        attention_bias = torch.zeros((batch_size, num_kv_heads, query_length, key_length), dtype=dtype, device=query_states.device)
 
     query_states, key_states, value_states, attention_bias = fdma_peft_integration_check(
         query_states, key_states, value_states, attention_bias, target_dtype
diff --git a/setup.py b/setup.py
index 32e0936..cd325d8 100644
--- a/setup.py
+++ b/setup.py
@@ -206,87 +206,12 @@ def append_nvcc_threads(nvcc_extra_args):
 ext_modules.append(
     CUDAExtension(
         name="flash_dmattn_cuda",
-        sources=[
-            "csrc/flash_dmattn/flash_api.cpp",
-            # Forward kernels - regular
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_sm80.cu",
-            # Forward kernels - causal
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_fp16_causal_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim32_bf16_causal_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_fp16_causal_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim64_bf16_causal_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_fp16_causal_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim96_bf16_causal_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_fp16_causal_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim128_bf16_causal_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_fp16_causal_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim192_bf16_causal_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_fp16_causal_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_hdim256_bf16_causal_sm80.cu",
-            # Forward kernels - split
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_sm80.cu",
-            "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_sm80.cu",
"csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_sm80.cu", - # Forward kernels - split causal - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim32_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim64_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim96_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim128_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim192_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_fwd_split_hdim256_bf16_causal_sm80.cu", - # Backward kernels - regular - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_sm80.cu", - # Backward kernels - causal - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim32_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim64_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim96_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim128_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim192_bf16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_fp16_causal_sm80.cu", - "csrc/flash_dmattn/src/instantiations/flash_bwd_hdim256_bf16_causal_sm80.cu", - ], + sources=( + [ + "csrc/flash_dmattn/flash_api.cpp", + ] + + sorted(glob.glob("csrc/flash_dmattn/src/instantiations/flash_*.cu")) + ), extra_compile_args={ "cxx": compiler_c17_flag, "nvcc": 
             "nvcc": append_nvcc_threads(nvcc_flags + cc_flag),