1- # Flash Dynamic Mask Attention API 参考文档
1+ # Flash Sparse Attention API 参考文档
22
33
44## 概述
55
6- Flash Dynamic Mask Attention 是一个高性能注意力实现,结合了 Flash Attention 的内存效率和 Dynamic Mask Attention 的稀疏计算优势。它支持 CUDA、Triton 和 Flex Attention 后端,并支持超长序列的动态掩码。
6+ Flash Sparse Attention 是一个高性能注意力实现,结合了 Flash Attention 的内存效率和 Dynamic Mask Attention 的稀疏计算优势。它支持 CUDA、Triton 和 Flex Attention 后端,并支持超长序列的动态掩码。
77
88
99## 目录
@@ -12,37 +12,36 @@ Flash Dynamic Mask Attention 是一个高性能注意力实现,结合了 Flash
12122 . [ 快速开始] ( #快速开始 )
13133 . [ 后端选择与比较] ( #后端选择与比较 )
14144 . [ 接口函数详解] ( #接口函数详解 )
15- - [ CUDA 后端:flash_dmattn_func ] ( #flash_dmattn_func -cuda-后端 )
16- - [ Triton 后端:triton_dmattn_func ] ( #triton_dmattn_func -triton-后端 )
17- - [ Flex 后端:flex_dmattn_func ] ( #flex_dmattn_func -flex-后端 )
15+ - [ CUDA 后端:flash_sparse_attn_func ] ( #flash_sparse_attn_func -cuda-后端 )
16+ - [ Triton 后端:triton_sparse_attn_func ] ( #triton_sparse_attn_func -triton-后端 )
17+ - [ Flex 后端:flex_sparse_attn_func ] ( #flex_sparse_attn_func -flex-后端 )
18185 . [ 集成] ( #集成 )
1919 - [ Transformers 集成] ( #transformers-集成 )
20206 . [ 常见问题与解决方案] ( #常见问题与解决方案 )
2121
2222
2323## 安装
2424
25- 请参考 [ README] ( https://github.com/SmallDoges/flash-dmattn /blob/main/README_zh.md#%E5%AE%89%E8%A3%85-1 ) 以获取详细的安装说明和依赖项。
25+ 请参考 [ README] ( https://github.com/SmallDoges/flash-sparse-attention /blob/main/README_zh.md#%E5%AE%89%E8%A3%85-1 ) 以获取详细的安装说明和依赖项。
2626
2727``` bash
2828# 使用 CUDA 后端
29- pip install flash-dmattn
30-
29+ pip install flash-sparse-attn
3130# 或从源码安装
3231pip install -e .
3332
3433# 仅使用 Triton/Flex 后端
35- FLASH_DMATTN_SKIP_CUDA_BUILD =1 pip install -e .
34+ FLASH_SPARSE_ATTENTION_SKIP_CUDA_BUILD =1 pip install -e .
3635```
3736
3837
3938## 快速开始
4039
41- 使用 ` flash_dmattn_func_auto ` 可以自动选择最佳可用后端,无需手动判断。
40+ 使用 ` flash_sparse_attn_func_auto ` 可以自动选择最佳可用后端,无需手动判断。
4241
4342``` python
4443import torch
45- from flash_dmattn import flash_dmattn_func_auto
44+ from flash_sparse_attn import flash_sparse_attn_func_auto
4645
4746# 准备输入张量
4847batch, seqlen, num_heads, head_dim = 2 , 1024 , 8 , 64
@@ -51,27 +50,27 @@ k = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device
5150v = torch.randn(batch, seqlen, num_heads, head_dim, dtype = torch.bfloat16, device = ' cuda' )
5251
5352# 获取注意力函数(自动选择后端,优先级: cuda > triton > flex)
54- attn_func = flash_dmattn_func_auto ()
53+ attn_func = flash_sparse_attn_func_auto ()
5554
5655# 调用注意力计算
5756output = attn_func(q, k, v, is_causal = True )
5857print (f " 输出形状: { output.shape} " ) # (2, 1024, 8, 64)
5958
6059# 也可以强制使用特定后端
61- attn_func = flash_dmattn_func_auto (backend = " cuda" ) # 或 "triton", "flex"
60+ attn_func = flash_sparse_attn_func_auto (backend = " cuda" ) # 或 "triton", "flex"
6261output = attn_func(q, k, v, is_causal = True )
6362```
6463
6564> [ !NOTE]
66- > ` flash_dmattn_func_auto ` 返回一个可调用的注意力函数,而不是注意力输出。
65+ > ` flash_sparse_attn_func_auto ` 返回一个可调用的注意力函数,而不是注意力输出。
6766
6867
6968## 后端选择与比较
7069
7170### 可用后端检查
7271
7372``` python
74- from flash_dmattn import get_available_backends, CUDA_AVAILABLE , TRITON_AVAILABLE , FLEX_AVAILABLE
73+ from flash_sparse_attn import get_available_backends, CUDA_AVAILABLE , TRITON_AVAILABLE , FLEX_AVAILABLE
7574
7675# 查看所有可用后端
7776print (get_available_backends()) # 例如:["cuda", "triton", "flex"]
@@ -101,19 +100,19 @@ print(f"CUDA: {CUDA_AVAILABLE}, Triton: {TRITON_AVAILABLE}, Flex: {FLEX_AVAILABL
101100
102101### 何时使用各个后端
103102
104- ** CUDA 后端** ([ 详细说明] ( #flash_dmattn_func -cuda-后端 ) )
103+ ** CUDA 后端** ([ 详细说明] ( #flash_sparse_attn_func -cuda-后端 ) )
105104- ✅ 完整梯度支持的训练工作负载
106105- ✅ 最大性能生产推理
107106- ✅ 需要确定性行为的应用
108107- ❌ 避免:无法构建自定义 CUDA 扩展时
109108
110- ** Triton 后端** ([ 详细说明] ( #triton_dmattn_func -triton-后端 ) )
109+ ** Triton 后端** ([ 详细说明] ( #triton_sparse_attn_func -triton-后端 ) )
111110- ✅ CUDA 扩展不可用时的训练工作负载
112111- ✅ 开发和原型设计
113112- ✅ 跨平台兼容性需求
114113- ✅ 性能和易安装性的良好平衡
115114
116- ** Flex 后端** ([ 详细说明] ( #flex_dmattn_func -flex-后端 ) )
115+ ** Flex 后端** ([ 详细说明] ( #flex_sparse_attn_func -flex-后端 ) )
117116- ✅ 仅推理应用
118117- ✅ 使用最新 PyTorch 特性的研究
119118- ✅ 无需自定义构建的快速实验
@@ -123,15 +122,15 @@ print(f"CUDA: {CUDA_AVAILABLE}, Triton: {TRITON_AVAILABLE}, Flex: {FLEX_AVAILABL
123122### 导入可用函数
124123
125124``` python
126- from flash_dmattn import (
125+ from flash_sparse_attn import (
127126 # 自动后端选择
128127 get_available_backends,
129- flash_dmattn_func_auto ,
128+ flash_sparse_attn_func_auto ,
130129
131130 # 后端特定函数
132- flash_dmattn_func , # CUDA 后端
133- triton_dmattn_func , # Triton 后端
134- flex_dmattn_func , # Flex 后端
131+ flash_sparse_attn_func , # CUDA 后端
132+ triton_sparse_attn_func , # Triton 后端
133+ flex_sparse_attn_func , # Flex 后端
135134
136135 # 后端可用性标志
137136 CUDA_AVAILABLE ,
@@ -140,20 +139,20 @@ from flash_dmattn import (
140139)
141140
142141# Transformers 集成
143- from flash_dmattn .integrations.flash_dynamic_mask_attention import (
144- flash_dynamic_mask_attention_forward
142+ from flash_sparse_attn .integrations.flash_sparse_attention import (
143+ flash_sparse_attention_forward
145144)
146145```
147146
148147
149148## 接口函数详解
150149
151- ### flash_dmattn_func (CUDA 后端)
150+ ### flash_sparse_attn_func (CUDA 后端)
152151
153152主要的注意力函数。支持多头注意力和分组查询注意力(当 KV 头数少于 Q 头数时)。需要 CUDA 扩展已构建并可用。
154153
155154``` python
156- def flash_dmattn_func (
155+ def flash_sparse_attn_func (
157156 query : torch.Tensor, # (batch, seqlen_q, num_heads, head_dim)
158157 key : torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim)
159158 value : torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim)
@@ -182,12 +181,12 @@ def flash_dmattn_func(
182181
183182- output: (B, Q, H, D)
184183
185- # ## triton_dmattn_func (Triton 后端)
184+ # ## triton_sparse_attn_func (Triton 后端)
186185
187186基于 Triton 的实现,无需自定义 CUDA 内核即可提供良好性能。
188187
189188```python
190- def triton_dmattn_func (
189+ def triton_sparse_attn_func (
191190 query : torch.Tensor, # (batch, seqlen_q, num_heads, head_dim)
192191 key : torch.Tensor, # (batch, seqlen_k, num_heads, head_dim)
193192 value : torch.Tensor, # (batch, seqlen_k, num_heads, head_dim)
@@ -198,12 +197,12 @@ def triton_dmattn_func(
198197) -> torch.Tensor
199198```
200199
201- # ## flex_dmattn_func (Flex Attention 后端)
200+ # ## flex_sparse_attn_func (Flex Attention 后端)
202201
203202基于 Flex Attention 的实现,使用 PyTorch 原生 flex attention 并支持动态掩码。
204203
205204```python
206- def flex_dmattn_func (
205+ def flex_sparse_attn_func (
207206 query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim)
208207 key: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim)
209208 value: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim)
@@ -219,14 +218,14 @@ def flex_dmattn_func(
219218
220219# ## Transformers 集成
221220
222- 为 HuggingFace Transformers 模型提供的集成函数,提供无缝的 flash dynamic mask attention 支持。
221+ 为 HuggingFace Transformers 模型提供的集成函数,提供无缝的 flash sparse attention 支持。
223222
224- # ### flash_dynamic_mask_attention_forward
223+ # ### flash_sparse_attention_forward
225224
226225```python
227- from flash_dmattn .integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward
226+ from flash_sparse_attn .integrations.flash_sparse_attention import flash_sparse_attention_forward
228227
229- def flash_dynamic_mask_attention_forward (
228+ def flash_sparse_attention_forward (
230229 module: torch.nn.Module, # 注意力模块
231230 query: torch.Tensor, # (batch_size, num_heads, query_len, head_dim)
232231 key: torch.Tensor, # (batch_size, num_kv_heads, key_len, head_dim)
@@ -253,7 +252,7 @@ def flash_dynamic_mask_attention_forward(
253252 - is_causal: 是否应用因果掩码
254253 - window_size: 保持的窗口大小
255254 - layer_idx: 用于日志的层索引
256- - implementation: 使用的实现(" flash_dmattn " 或 None )
255+ - implementation: 使用的实现(" flash_sparse_attn " 或 None )
257256
258257# ### 返回值
259258
@@ -267,7 +266,7 @@ import torch.nn as nn
267266import torch.nn.functional as F
268267from typing import Optional, Callable, tuple
269268from transformers.cache_utils import Cache
270- from flash_dmattn .integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward
269+ from flash_sparse_attn .integrations.flash_sparse_attention import flash_sparse_attention_forward
271270
272271class DynamicMaskAttention (nn .Module ):
273272 def __init__ (self , config , layer_idx : Optional[int ] = None ):
@@ -331,7 +330,7 @@ class DynamicMaskAttention(nn.Module):
331330 attn_bias = torch.exp(self .A * F.softplus(dt_states)).transpose(- 1 , - 2 ).to(hidden_states.dtype)
332331
333332 # 选择注意力实现
334- attention_interface: Callable = flash_dynamic_mask_attention_forward
333+ attention_interface: Callable = flash_sparse_attention_forward
335334
336335 attn_output, attn_weights = attention_interface(
337336 self ,
@@ -361,7 +360,7 @@ class DynamicMaskAttention(nn.Module):
361360
362361``` python
363362try :
364- from flash_dmattn import flash_dmattn_func_auto , get_available_backends
363+ from flash_sparse_attn import flash_sparse_attn_func_auto , get_available_backends
365364 print (" ✅ 导入成功" , get_available_backends())
366365except ImportError as e:
367366 print (f " ❌ 导入失败: { e} " )
@@ -384,10 +383,10 @@ except ImportError as e:
384383
385384``` python
386385import torch
387- from flash_dmattn import flash_dmattn_func_auto
386+ from flash_sparse_attn import flash_sparse_attn_func_auto
388387
389388torch.autograd.set_detect_anomaly(True )
390- attn = flash_dmattn_func_auto ()
389+ attn = flash_sparse_attn_func_auto ()
391390output = attn(q, k, v, attn_mask = attn_mask, attn_bias = attn_bias, is_causal = True )
392391if torch.isnan(output).any():
393392 print (" ⚠️ 注意力输出中检测到 NaN" )
@@ -403,7 +402,7 @@ def print_memory_stats():
403402 print (f " 最大分配: { torch.cuda.max_memory_allocated() / 1e9 :.2f } GB " )
404403
405404print_memory_stats()
406- attn = flash_dmattn_func_auto ()
405+ attn = flash_sparse_attn_func_auto ()
407406output = attn(q, k, v)
408407print_memory_stats()
409408```
0 commit comments