
Commit b3ac56f

Renames flash attention variant

Updates the sparse attention backend to drop the old dynamic mask name so that error messages, source paths, and docs consistently refer to FlashSparseAttention.
1 parent a0ed87d commit b3ac56f

304 files changed (+19, -19 lines)


csrc/flash_dmattn/flash_api.cpp renamed to csrc/flash_sparse_attn/flash_api.cpp

Lines changed: 17 additions & 17 deletions
@@ -126,7 +126,7 @@ void set_params_fprop(

 // Set the different scale values.
 #ifdef FLASHATTENTION_DISABLE_SOFTCAP
-TORCH_CHECK(softcap <= 0.0, "This flash dynamic mask attention build does not support softcap.");
+TORCH_CHECK(softcap <= 0.0, "This flash sparse attention build does not support softcap.");
 #endif
 if (softcap > 0.0) {
 params.softcap = softmax_scale / softcap;
@@ -145,7 +145,7 @@ void set_params_fprop(
 params.is_seqlens_k_cumulative = true;

 #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
-TORCH_CHECK(d == d_rounded, "This flash dynamic mask attention build does not support headdim not being a multiple of 32.");
+TORCH_CHECK(d == d_rounded, "This flash sparse attention build does not support headdim not being a multiple of 32.");
 #endif

 params.unpadded_lse = unpadded_lse;
@@ -366,10 +366,10 @@ mha_fwd(
 at::cuda::CUDAGuard device_guard{q.device()};
 auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
 bool is_sm8x_min = cc_major >= 8;
-TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer.");
+TORCH_CHECK(is_sm8x_min, "FlashSparseAttention only supports Ampere GPUs or newer.");

 auto q_dtype = q.dtype();
-TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
+TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashSparseAttention only support fp16 and bf16 data type");
 TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
 TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");

@@ -420,7 +420,7 @@ mha_fwd(
 const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

 TORCH_CHECK(batch_size > 0, "batch size must be positive");
-TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256");
+TORCH_CHECK(head_size <= 256, "FlashSparseAttention forward only supports head dimension at most 256");
 TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
 TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

@@ -577,10 +577,10 @@ mha_varlen_fwd(
 at::cuda::CUDAGuard device_guard{q.device()};
 auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
 bool is_sm8x_min = cc_major >= 8;
-TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer.");
+TORCH_CHECK(is_sm8x_min, "FlashSparseAttention only supports Ampere GPUs or newer.");

 auto q_dtype = q.dtype();
-TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
+TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashSparseAttention only support fp16 and bf16 data type");
 TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
 TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
 TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
@@ -644,7 +644,7 @@ mha_varlen_fwd(
 const int total_q = q.sizes()[0];

 TORCH_CHECK(batch_size > 0, "batch size must be positive");
-TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256");
+TORCH_CHECK(head_size <= 256, "FlashSparseAttention forward only supports head dimension at most 256");
 TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
 TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

@@ -810,19 +810,19 @@ mha_bwd(
 ) {

 #ifdef FLASHATTENTION_DISABLE_BACKWARD
-TORCH_CHECK(false, "This flash dynamic mask attention build does not support backward.");
+TORCH_CHECK(false, "This flash sparse attention build does not support backward.");
 #endif

 // Otherwise the kernel will be launched from cuda:0 device
 at::cuda::CUDAGuard device_guard{q.device()};
 auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
 bool is_sm8x_min = cc_major >= 8;
-TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer.");
+TORCH_CHECK(is_sm8x_min, "FlashSparseAttention only supports Ampere GPUs or newer.");

 auto stream = at::cuda::getCurrentCUDAStream().stream();

 auto q_dtype = q.dtype();
-TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
+TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashSparseAttention only support fp16 and bf16 data type");
 TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
 TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
 TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
@@ -881,7 +881,7 @@ mha_bwd(

 TORCH_CHECK(batch_size > 0, "batch size must be positive");
 TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
-TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256");
+TORCH_CHECK(head_size <= 256, "FlashSparseAttention backward only supports head dimension at most 256");
 TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

 if (has_mask) {
@@ -1072,19 +1072,19 @@ mha_varlen_bwd(
 ) {

 #ifdef FLASHATTENTION_DISABLE_BACKWARD
-TORCH_CHECK(false, "This flash dynamic mask attention build does not support backward.");
+TORCH_CHECK(false, "This flash sparse attention build does not support backward.");
 #endif

 // Otherwise the kernel will be launched from cuda:0 device
 at::cuda::CUDAGuard device_guard{q.device()};
 auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
 bool is_sm8x_min = cc_major >= 8;
-TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer.");
+TORCH_CHECK(is_sm8x_min, "FlashSparseAttention only supports Ampere GPUs or newer.");

 auto stream = at::cuda::getCurrentCUDAStream().stream();

 auto q_dtype = q.dtype();
-TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
+TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashSparseAttention only support fp16 and bf16 data type");
 TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
 TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
 TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
@@ -1124,7 +1124,7 @@ mha_varlen_bwd(
 const int num_heads_bias = has_bias ? bias.size(1) : 1;
 TORCH_CHECK(batch_size > 0, "batch size must be positive");
 TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
-TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256");
+TORCH_CHECK(head_size <= 256, "FlashSparseAttention backward only supports head dimension at most 256");
 TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

 auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
@@ -1268,7 +1268,7 @@ mha_varlen_bwd(
 } // namespace FLASH_NAMESPACE

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-m.doc() = "FlashDynamicMaskAttention";
+m.doc() = "FlashSparseAttention";
 m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass");
 m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length");
 m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass");
File renamed without changes.
File renamed without changes.

csrc/flash_dmattn/src/flash_bwd_launch_template.h renamed to csrc/flash_sparse_attn/src/flash_bwd_launch_template.h

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ namespace FLASH_NAMESPACE {
 #endif

 // Define a macro for unsupported architecture handling to centralize the error message
-#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashDynamicMaskAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
+#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashSparseAttention requires building with sm version sm80-sm90, but was built for < 8.0!");

 // Use a macro to clean up kernel definitions
 #define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \

csrc/flash_dmattn/src/flash_fwd_launch_template.h renamed to csrc/flash_sparse_attn/src/flash_fwd_launch_template.h

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ namespace FLASH_NAMESPACE {
 #endif

 // Define a macro for unsupported architecture handling to centralize the error message
-#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashDynamicMaskAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
+#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashSparseAttention requires building with sm version sm80-sm90, but was built for < 8.0!");

 // Use a macro to clean up kernel definitions
 #define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
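
Note: both launch templates keep the same FLASH_UNSUPPORTED_ARCH guard, only with the new product name in the message. Below is a minimal sketch of how such a guard macro is typically consumed inside a kernel body; the kernel name and the __CUDA_ARCH__ threshold are assumptions for illustration, not lines from this commit.

// Sketch only: a toy kernel showing the usual pattern for an
// "unsupported architecture" macro. Everything except the macro text
// is hypothetical.
#include <cstdio>

#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashSparseAttention requires building with sm version sm80-sm90, but was built for < 8.0!");

__global__ void example_fwd_kernel() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // On sm80+ builds the real kernel body would execute here.
#else
    // On older targets the kernel still compiles but reports a fatal
    // message at runtime instead of computing attention.
    FLASH_UNSUPPORTED_ARCH
#endif
}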
File renamed without changes.
