Skip to content

Commit ac95f25

Browse files
committed
Aligns Chinese doc with sparse attention
Updates terminology to reflect the flash sparse attention rebranding so readers follow accurate package names, imports, and integration guidance.
1 parent 554e7e0 commit ac95f25

File tree

1 file changed

+41
-42
lines changed

1 file changed

+41
-42
lines changed

docs/api_reference_zh.md

Lines changed: 41 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
# Flash Dynamic Mask Attention API 参考文档
1+
# Flash Sparse Attention API 参考文档
22

33

44
## 概述
55

6-
Flash Dynamic Mask Attention 是一个高性能注意力实现,结合了 Flash Attention 的内存效率和 Dynamic Mask Attention 的稀疏计算优势。它支持 CUDA、Triton 和 Flex Attention 后端,并支持超长序列的动态掩码。
6+
Flash Sparse Attention 是一个高性能注意力实现,结合了 Flash Attention 的内存效率和 Dynamic Mask Attention 的稀疏计算优势。它支持 CUDA、Triton 和 Flex Attention 后端,并支持超长序列的动态掩码。
77

88

99
## 目录
@@ -12,37 +12,36 @@ Flash Dynamic Mask Attention 是一个高性能注意力实现,结合了 Flash
1212
2. [快速开始](#快速开始)
1313
3. [后端选择与比较](#后端选择与比较)
1414
4. [接口函数详解](#接口函数详解)
15-
- [CUDA 后端:flash_dmattn_func](#flash_dmattn_func-cuda-后端)
16-
- [Triton 后端:triton_dmattn_func](#triton_dmattn_func-triton-后端)
17-
- [Flex 后端:flex_dmattn_func](#flex_dmattn_func-flex-后端)
15+
- [CUDA 后端:flash_sparse_attn_func](#flash_sparse_attn_func-cuda-后端)
16+
- [Triton 后端:triton_sparse_attn_func](#triton_sparse_attn_func-triton-后端)
17+
- [Flex 后端:flex_sparse_attn_func](#flex_sparse_attn_func-flex-后端)
1818
5. [集成](#集成)
1919
- [Transformers 集成](#transformers-集成)
2020
6. [常见问题与解决方案](#常见问题与解决方案)
2121

2222

2323
## 安装
2424

25-
请参考 [README](https://github.com/SmallDoges/flash-dmattn/blob/main/README_zh.md#%E5%AE%89%E8%A3%85-1) 以获取详细的安装说明和依赖项。
25+
请参考 [README](https://github.com/SmallDoges/flash-sparse-attention/blob/main/README_zh.md#%E5%AE%89%E8%A3%85-1) 以获取详细的安装说明和依赖项。
2626

2727
```bash
2828
# 使用 CUDA 后端
29-
pip install flash-dmattn
30-
29+
pip install flash-sparse-attn
3130
# 或从源码安装
3231
pip install -e .
3332

3433
# 仅使用 Triton/Flex 后端
35-
FLASH_DMATTN_SKIP_CUDA_BUILD=1 pip install -e .
34+
FLASH_SPARSE_ATTENTION_SKIP_CUDA_BUILD=1 pip install -e .
3635
```
3736

3837

3938
## 快速开始
4039

41-
使用 `flash_dmattn_func_auto` 可以自动选择最佳可用后端,无需手动判断。
40+
使用 `flash_sparse_attn_func_auto` 可以自动选择最佳可用后端,无需手动判断。
4241

4342
```python
4443
import torch
45-
from flash_dmattn import flash_dmattn_func_auto
44+
from flash_sparse_attn import flash_sparse_attn_func_auto
4645

4746
# 准备输入张量
4847
batch, seqlen, num_heads, head_dim = 2, 1024, 8, 64
@@ -51,27 +50,27 @@ k = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device
5150
v = torch.randn(batch, seqlen, num_heads, head_dim, dtype=torch.bfloat16, device='cuda')
5251

5352
# 获取注意力函数(自动选择后端,优先级: cuda > triton > flex)
54-
attn_func = flash_dmattn_func_auto()
53+
attn_func = flash_sparse_attn_func_auto()
5554

5655
# 调用注意力计算
5756
output = attn_func(q, k, v, is_causal=True)
5857
print(f"输出形状: {output.shape}") # (2, 1024, 8, 64)
5958

6059
# 也可以强制使用特定后端
61-
attn_func = flash_dmattn_func_auto(backend="cuda") # 或 "triton", "flex"
60+
attn_func = flash_sparse_attn_func_auto(backend="cuda") # 或 "triton", "flex"
6261
output = attn_func(q, k, v, is_causal=True)
6362
```
6463

6564
> [!NOTE]
66-
> `flash_dmattn_func_auto` 返回一个可调用的注意力函数,而不是注意力输出。
65+
> `flash_sparse_attn_func_auto` 返回一个可调用的注意力函数,而不是注意力输出。
6766
6867

6968
## 后端选择与比较
7069

7170
### 可用后端检查
7271

7372
```python
74-
from flash_dmattn import get_available_backends, CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE
73+
from flash_sparse_attn import get_available_backends, CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE
7574

7675
# 查看所有可用后端
7776
print(get_available_backends()) # 例如:["cuda", "triton", "flex"]
@@ -101,19 +100,19 @@ print(f"CUDA: {CUDA_AVAILABLE}, Triton: {TRITON_AVAILABLE}, Flex: {FLEX_AVAILABL
101100
102101
### 何时使用各个后端
103102

104-
**CUDA 后端** ([详细说明](#flash_dmattn_func-cuda-后端))
103+
**CUDA 后端** ([详细说明](#flash_sparse_attn_func-cuda-后端))
105104
- ✅ 完整梯度支持的训练工作负载
106105
- ✅ 最大性能生产推理
107106
- ✅ 需要确定性行为的应用
108107
- ❌ 避免:无法构建自定义 CUDA 扩展时
109108

110-
**Triton 后端** ([详细说明](#triton_dmattn_func-triton-后端))
109+
**Triton 后端** ([详细说明](#triton_sparse_attn_func-triton-后端))
111110
- ✅ CUDA 扩展不可用时的训练工作负载
112111
- ✅ 开发和原型设计
113112
- ✅ 跨平台兼容性需求
114113
- ✅ 性能和易安装性的良好平衡
115114

116-
**Flex 后端** ([详细说明](#flex_dmattn_func-flex-后端))
115+
**Flex 后端** ([详细说明](#flex_sparse_attn_func-flex-后端))
117116
- ✅ 仅推理应用
118117
- ✅ 使用最新 PyTorch 特性的研究
119118
- ✅ 无需自定义构建的快速实验
@@ -123,15 +122,15 @@ print(f"CUDA: {CUDA_AVAILABLE}, Triton: {TRITON_AVAILABLE}, Flex: {FLEX_AVAILABL
123122
### 导入可用函数
124123

125124
```python
126-
from flash_dmattn import (
125+
from flash_sparse_attn import (
127126
# 自动后端选择
128127
get_available_backends,
129-
flash_dmattn_func_auto,
128+
flash_sparse_attn_func_auto,
130129

131130
# 后端特定函数
132-
flash_dmattn_func, # CUDA 后端
133-
triton_dmattn_func, # Triton 后端
134-
flex_dmattn_func, # Flex 后端
131+
flash_sparse_attn_func, # CUDA 后端
132+
triton_sparse_attn_func, # Triton 后端
133+
flex_sparse_attn_func, # Flex 后端
135134

136135
# 后端可用性标志
137136
CUDA_AVAILABLE,
@@ -140,20 +139,20 @@ from flash_dmattn import (
140139
)
141140

142141
# Transformers 集成
143-
from flash_dmattn.integrations.flash_dynamic_mask_attention import (
144-
flash_dynamic_mask_attention_forward
142+
from flash_sparse_attn.integrations.flash_sparse_attention import (
143+
flash_sparse_attention_forward
145144
)
146145
```
147146

148147

149148
## 接口函数详解
150149

151-
### flash_dmattn_func (CUDA 后端)
150+
### flash_sparse_attn_func (CUDA 后端)
152151

153152
主要的注意力函数。支持多头注意力和分组查询注意力(当 KV 头数少于 Q 头数时)。需要 CUDA 扩展已构建并可用。
154153

155154
```python
156-
def flash_dmattn_func(
155+
def flash_sparse_attn_func(
157156
query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim)
158157
key: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim)
159158
value: torch.Tensor, # (batch, seqlen_k, num_kv_heads, head_dim)
@@ -182,12 +181,12 @@ def flash_dmattn_func(
182181

183182
- output: (B, Q, H, D)
184183

185-
### triton_dmattn_func (Triton 后端)
184+
### triton_sparse_attn_func (Triton 后端)
186185

187186
基于 Triton 的实现,无需自定义 CUDA 内核即可提供良好性能。
188187

189188
```python
190-
def triton_dmattn_func(
189+
def triton_sparse_attn_func(
191190
query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim)
192191
key: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim)
193192
value: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim)
@@ -198,12 +197,12 @@ def triton_dmattn_func(
198197
) -> torch.Tensor
199198
```
200199

201-
### flex_dmattn_func (Flex Attention 后端)
200+
### flex_sparse_attn_func (Flex Attention 后端)
202201

203202
基于 Flex Attention 的实现,使用 PyTorch 原生 flex attention 并支持动态掩码。
204203

205204
```python
206-
def flex_dmattn_func(
205+
def flex_sparse_attn_func(
207206
query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim)
208207
key: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim)
209208
value: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim)
@@ -219,14 +218,14 @@ def flex_dmattn_func(
219218

220219
### Transformers 集成
221220

222-
为 HuggingFace Transformers 模型提供的集成函数,提供无缝的 flash dynamic mask attention 支持。
221+
为 HuggingFace Transformers 模型提供的集成函数,提供无缝的 flash sparse attention 支持。
223222

224-
#### flash_dynamic_mask_attention_forward
223+
#### flash_sparse_attention_forward
225224

226225
```python
227-
from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward
226+
from flash_sparse_attn.integrations.flash_sparse_attention import flash_sparse_attention_forward
228227

229-
def flash_dynamic_mask_attention_forward(
228+
def flash_sparse_attention_forward(
230229
module: torch.nn.Module, # 注意力模块
231230
query: torch.Tensor, # (batch_size, num_heads, query_len, head_dim)
232231
key: torch.Tensor, # (batch_size, num_kv_heads, key_len, head_dim)
@@ -253,7 +252,7 @@ def flash_dynamic_mask_attention_forward(
253252
- is_causal: 是否应用因果掩码
254253
- window_size: 保持的窗口大小
255254
- layer_idx: 用于日志的层索引
256-
- implementation: 使用的实现("flash_dmattn"None
255+
- implementation: 使用的实现("flash_sparse_attn"None
257256

258257
#### 返回值
259258

@@ -267,7 +266,7 @@ import torch.nn as nn
267266
import torch.nn.functional as F
268267
from typing import Optional, Callable, tuple
269268
from transformers.cache_utils import Cache
270-
from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward
269+
from flash_sparse_attn.integrations.flash_sparse_attention import flash_sparse_attention_forward
271270

272271
class DynamicMaskAttention(nn.Module):
273272
def __init__(self, config, layer_idx: Optional[int] = None):
@@ -331,7 +330,7 @@ class DynamicMaskAttention(nn.Module):
331330
attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype)
332331

333332
# 选择注意力实现
334-
attention_interface: Callable = flash_dynamic_mask_attention_forward
333+
attention_interface: Callable = flash_sparse_attention_forward
335334

336335
attn_output, attn_weights = attention_interface(
337336
self,
@@ -361,7 +360,7 @@ class DynamicMaskAttention(nn.Module):
361360

362361
```python
363362
try:
364-
from flash_dmattn import flash_dmattn_func_auto, get_available_backends
363+
from flash_sparse_attn import flash_sparse_attn_func_auto, get_available_backends
365364
print("✅ 导入成功", get_available_backends())
366365
except ImportError as e:
367366
print(f"❌ 导入失败: {e}")
@@ -384,10 +383,10 @@ except ImportError as e:
384383

385384
```python
386385
import torch
387-
from flash_dmattn import flash_dmattn_func_auto
386+
from flash_sparse_attn import flash_sparse_attn_func_auto
388387

389388
torch.autograd.set_detect_anomaly(True)
390-
attn = flash_dmattn_func_auto()
389+
attn = flash_sparse_attn_func_auto()
391390
output = attn(q, k, v, attn_mask=attn_mask, attn_bias=attn_bias, is_causal=True)
392391
if torch.isnan(output).any():
393392
print("⚠️ 注意力输出中检测到 NaN")
@@ -403,7 +402,7 @@ def print_memory_stats():
403402
print(f"最大分配: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
404403

405404
print_memory_stats()
406-
attn = flash_dmattn_func_auto()
405+
attn = flash_sparse_attn_func_auto()
407406
output = attn(q, k, v)
408407
print_memory_stats()
409408
```

0 commit comments

Comments
 (0)