@@ -126,7 +126,7 @@ void set_params_fprop(
 
     // Set the different scale values.
 #ifdef FLASHATTENTION_DISABLE_SOFTCAP
-    TORCH_CHECK(softcap <= 0.0, "This flash dynamic mask attention build does not support softcap.");
+    TORCH_CHECK(softcap <= 0.0, "This flash sparse attention build does not support softcap.");
 #endif
     if (softcap > 0.0) {
         params.softcap = softmax_scale / softcap;
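
(Aside, not part of the diff: the `params.softcap = softmax_scale / softcap` context line above folds the softmax scale into the pre-tanh factor used for soft-capping. A minimal standalone sketch of the usual tanh soft-cap this scaling supports, with illustrative names only, not code from this repository:)

    #include <cmath>

    // Illustrative only: tanh soft-capping of a single attention logit.
    // Storing softmax_scale / softcap (as params.softcap does above) lets a kernel
    // apply one multiply before tanh and scale by softcap afterwards.
    float softcapped_logit(float qk, float softmax_scale, float softcap) {
        if (softcap <= 0.0f) {
            return qk * softmax_scale;                    // no capping: plain scaled logit
        }
        float pre_tanh_scale = softmax_scale / softcap;   // what params.softcap stores
        return softcap * std::tanh(qk * pre_tanh_scale);  // bounded to (-softcap, softcap)
    }
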
@@ -145,7 +145,7 @@ void set_params_fprop(
     params.is_seqlens_k_cumulative = true;
 
 #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
-    TORCH_CHECK(d == d_rounded, "This flash dynamic mask attention build does not support headdim not being a multiple of 32.");
+    TORCH_CHECK(d == d_rounded, "This flash sparse attention build does not support headdim not being a multiple of 32.");
 #endif
 
     params.unpadded_lse = unpadded_lse;
@@ -366,10 +366,10 @@ mha_fwd(
     at::cuda::CUDAGuard device_guard{q.device()};
     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
     bool is_sm8x_min = cc_major >= 8;
-    TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer.");
+    TORCH_CHECK(is_sm8x_min, "FlashSparseAttention only supports Ampere GPUs or newer.");
 
     auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashSparseAttention only support fp16 and bf16 data type");
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
 
@@ -420,7 +420,7 @@ mha_fwd(
     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
-    TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256");
+    TORCH_CHECK(head_size <= 256, "FlashSparseAttention forward only supports head dimension at most 256");
     TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
@@ -577,10 +577,10 @@ mha_varlen_fwd(
     at::cuda::CUDAGuard device_guard{q.device()};
     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
     bool is_sm8x_min = cc_major >= 8;
-    TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer.");
+    TORCH_CHECK(is_sm8x_min, "FlashSparseAttention only supports Ampere GPUs or newer.");
 
     auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashSparseAttention only support fp16 and bf16 data type");
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
@@ -644,7 +644,7 @@ mha_varlen_fwd(
     const int total_q = q.sizes()[0];
 
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
-    TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256");
+    TORCH_CHECK(head_size <= 256, "FlashSparseAttention forward only supports head dimension at most 256");
     TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
@@ -810,19 +810,19 @@ mha_bwd(
 ) {
 
 #ifdef FLASHATTENTION_DISABLE_BACKWARD
-    TORCH_CHECK(false, "This flash dynamic mask attention build does not support backward.");
+    TORCH_CHECK(false, "This flash sparse attention build does not support backward.");
 #endif
 
     // Otherwise the kernel will be launched from cuda:0 device
     at::cuda::CUDAGuard device_guard{q.device()};
     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
     bool is_sm8x_min = cc_major >= 8;
-    TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer.");
+    TORCH_CHECK(is_sm8x_min, "FlashSparseAttention only supports Ampere GPUs or newer.");
 
     auto stream = at::cuda::getCurrentCUDAStream().stream();
 
     auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashSparseAttention only support fp16 and bf16 data type");
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
     TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
@@ -881,7 +881,7 @@ mha_bwd(
 
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
-    TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256");
+    TORCH_CHECK(head_size <= 256, "FlashSparseAttention backward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
     if (has_mask) {
@@ -1072,19 +1072,19 @@ mha_varlen_bwd(
 ) {
 
 #ifdef FLASHATTENTION_DISABLE_BACKWARD
-    TORCH_CHECK(false, "This flash dynamic mask attention build does not support backward.");
+    TORCH_CHECK(false, "This flash sparse attention build does not support backward.");
 #endif
 
     // Otherwise the kernel will be launched from cuda:0 device
     at::cuda::CUDAGuard device_guard{q.device()};
     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
     bool is_sm8x_min = cc_major >= 8;
-    TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer.");
+    TORCH_CHECK(is_sm8x_min, "FlashSparseAttention only supports Ampere GPUs or newer.");
 
     auto stream = at::cuda::getCurrentCUDAStream().stream();
 
     auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashSparseAttention only support fp16 and bf16 data type");
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
     TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
@@ -1124,7 +1124,7 @@ mha_varlen_bwd(
     const int num_heads_bias = has_bias ? bias.size(1) : 1;
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
-    TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention backward only supports head dimension at most 256");
+    TORCH_CHECK(head_size <= 256, "FlashSparseAttention backward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
@@ -1268,7 +1268,7 @@ mha_varlen_bwd(
 } // namespace FLASH_NAMESPACE
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.doc() = "FlashDynamicMaskAttention";
+    m.doc() = "FlashSparseAttention";
     m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass");
     m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length");
     m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass");
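
(Aside, not part of the diff: every change above is a user-facing string passed to `TORCH_CHECK` or to the pybind11 module docs, so only error/doc text changes, not behavior. `TORCH_CHECK(cond, msg)` throws a `c10::Error` carrying `msg` when `cond` is false. A minimal standalone illustration with a hypothetical helper, not code from this repository:)

    #include <torch/extension.h>

    // Hypothetical helper, for illustration only: mirrors the validation pattern
    // used throughout the file. On failure, TORCH_CHECK throws a c10::Error whose
    // what() contains the renamed message.
    void check_head_size(int head_size) {
        TORCH_CHECK(head_size <= 256,
                    "FlashSparseAttention forward only supports head dimension at most 256");
    }
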