Commit b9260c0

Updates docs to use mask utility in examples
Replaces manual top-k mask construction in the examples with a utility-based dynamic sparse mask, reducing complexity and aligning with the current API. Unifies variable names and updates example usage and gradient printouts across the English and Chinese guides.
1 parent 0dbd673 commit b9260c0

File tree: README.md, README_zh.md

2 files changed (+36, -38 lines)


README.md

Lines changed: 18 additions & 19 deletions

````diff
@@ -153,6 +153,7 @@ pip install . --no-build-isolation
 ```python
 import torch
 from flash_dmattn import flash_dmattn_func_auto
+from flash_dmattn.utils.mask import create_mask
 import math
 
 # Setup
@@ -167,22 +168,20 @@ query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dty
 key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
 value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
 
-# Create mask and bias for sparse attention
-attention_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
-attention_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
+# Create bias for sparse attention
+attn_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
 
-# Generate sparse mask based on bias
+# Generate dynamic mask based on bias
 if seq_len > window_size:
-    # Select top-k most important keys for each query
-    topk_values, topk_indices = torch.topk(
-        attention_bias, window_size, dim=-1,
-        largest=True, sorted=False
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=attn_mask,
+        batch_size=batch_size,
+        query_len=seq_len,
+        key_len=seq_len,
+        window_size=window_size,
+        min_dtype=min_dtype,
     )
-    # Generate valid top-k mask
-    valid_topk = (topk_values != min_dtype).to(dtype)
-    attention_mask = torch.zeros_like(attention_bias, dtype=dtype, device=attention_bias.device)
-    attention_mask = attention_mask.scatter(-1, topk_indices, valid_topk)
-    attention_bias = attention_bias.masked_fill(attention_mask == 0.0, min_dtype)
 
 # Select FDMA kernel
 flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
@@ -192,8 +191,8 @@ output = flash_dmattn_func(
     query=query,
     key=key,
     value=value,
-    attn_mask=attention_mask,
-    attn_bias=attention_bias,
+    attn_mask=attn_mask,
+    attn_bias=attn_bias,
     is_causal=True,
     softmax_scale=1.0/math.sqrt(head_dim),
 )
@@ -208,13 +207,13 @@ print(f"Output shape: {output.shape}") # [1, 256, 2, 64]
 query.requires_grad_(True)
 key.requires_grad_(True)
 value.requires_grad_(True)
-attention_bias.requires_grad_(True)
+attn_bias.requires_grad_(True)
 
 # Forward pass
 output = flash_dmattn_func(
     query=query, key=key, value=value,
-    attn_mask=attention_mask,
-    attn_bias=attention_bias,
+    attn_mask=attn_mask,
+    attn_bias=attn_bias,
     is_causal=True,
     softmax_scale=1.0/math.sqrt(head_dim)
 )
@@ -226,7 +225,7 @@ loss.backward()
 print(f"Query gradient shape: {query.grad.shape}")
 print(f"Key gradient shape: {key.grad.shape}")
 print(f"Value gradient shape: {value.grad.shape}")
-print(f"Bias gradient shape: {attention_bias.grad.shape}")
+print(f"Bias gradient shape: {attn_bias.grad.shape}")
 ```
````
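
For readers following along, the updated README.md example assembles into something like the self-contained sketch below. The hunks above show only the changed regions, so the setup values here (the head counts behind the [1, 256, 2, 64] output shape, window_size, device, dtype, and min_dtype) and the joint initialization of attn_mask and attn_bias are assumptions filled in for completeness, not the verbatim README contents.

```python
import math
import torch
from flash_dmattn import flash_dmattn_func_auto
from flash_dmattn.utils.mask import create_mask

# Setup (assumed values; these lines are not shown in the diff hunks)
batch_size, seq_len, num_heads, num_kv_heads, head_dim = 1, 256, 2, 1, 64
window_size = 128
device = torch.device("cuda")
dtype = torch.bfloat16
min_dtype = torch.finfo(dtype).min  # most negative representable value, used for masked-out bias

# Input tensors
query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)

# Mask and bias for dynamic sparse attention
attn_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
attn_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)

# Build the dynamic mask with the utility instead of manual top-k construction
if seq_len > window_size:
    attn_mask = create_mask(
        attention_bias=attn_bias,
        attention_mask=attn_mask,
        batch_size=batch_size,
        query_len=seq_len,
        key_len=seq_len,
        window_size=window_size,
        min_dtype=min_dtype,
    )

# Select the FDMA kernel and run attention
flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
output = flash_dmattn_func(
    query=query,
    key=key,
    value=value,
    attn_mask=attn_mask,
    attn_bias=attn_bias,
    is_causal=True,
    softmax_scale=1.0 / math.sqrt(head_dim),
)
print(f"Output shape: {output.shape}")  # expected [1, 256, 2, 64]
```

As in the updated gradient printout, calling attn_bias.requires_grad_(True) before the forward pass makes attn_bias.grad available after loss.backward().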

README_zh.md

Lines changed: 18 additions & 19 deletions

````diff
@@ -153,6 +153,7 @@ pip install . --no-build-isolation
 ```python
 import torch
 from flash_dmattn import flash_dmattn_func_auto
+from flash_dmattn.utils.mask import create_mask
 import math
 
 # Setup
@@ -167,22 +168,20 @@ query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dty
 key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
 value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
 
-# Create mask and bias for sparse attention
-attention_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
-attention_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
+# Create bias for sparse attention
+attn_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
 
-# Generate sparse mask based on bias
+# Generate dynamic mask based on bias
 if seq_len > window_size:
-    # Select top-k most important keys for each query
-    topk_values, topk_indices = torch.topk(
-        attention_bias, window_size, dim=-1,
-        largest=True, sorted=False
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=attn_mask,
+        batch_size=batch_size,
+        query_len=seq_len,
+        key_len=seq_len,
+        window_size=window_size,
+        min_dtype=min_dtype,
     )
-    # Generate valid top-k mask
-    valid_topk = (topk_values != min_dtype).to(dtype)
-    attention_mask = torch.zeros_like(attention_bias, dtype=dtype, device=attention_bias.device)
-    attention_mask = attention_mask.scatter(-1, topk_indices, valid_topk)
-    attention_bias = attention_bias.masked_fill(attention_mask == 0.0, min_dtype)
 
 # Select FDMA kernel
 flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
@@ -192,8 +191,8 @@ output = flash_dmattn_func(
     query=query,
     key=key,
     value=value,
-    attn_mask=attention_mask,
-    attn_bias=attention_bias,
+    attn_mask=attn_mask,
+    attn_bias=attn_bias,
     is_causal=True,
     softmax_scale=1.0/math.sqrt(head_dim),
 )
@@ -208,13 +207,13 @@ print(f"Output shape: {output.shape}") # [1, 256, 2, 64]
 query.requires_grad_(True)
 key.requires_grad_(True)
 value.requires_grad_(True)
-attention_bias.requires_grad_(True)
+attn_bias.requires_grad_(True)
 
 # Forward pass
 output = flash_dmattn_func(
     query=query, key=key, value=value,
-    attn_mask=attention_mask,
-    attn_bias=attention_bias,
+    attn_mask=attn_mask,
+    attn_bias=attn_bias,
     is_causal=True,
     softmax_scale=1.0/math.sqrt(head_dim)
 )
@@ -226,7 +225,7 @@ loss.backward()
 print(f"Query gradient shape: {query.grad.shape}")
 print(f"Key gradient shape: {key.grad.shape}")
 print(f"Value gradient shape: {value.grad.shape}")
-print(f"Bias gradient shape: {attention_bias.grad.shape}")
+print(f"Bias gradient shape: {attn_bias.grad.shape}")
 ```
````
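
For context on what the utility call replaces: the removed lines in both guides built the sparse mask by hand with a top-k over the bias. The helper below is a sketch reconstructed from that deleted example code; the name manual_topk_mask is made up here, and it is only a guess at the kind of logic create_mask wraps, not the utility's actual implementation.

```python
import torch

def manual_topk_mask(attention_bias: torch.Tensor, window_size: int, min_dtype: float):
    """Sketch of the manual masking the old examples inlined (hypothetical helper name).

    For each query, keep the window_size keys with the largest bias, zero the rest
    out in the mask, and push masked-out bias entries down to min_dtype.
    """
    dtype = attention_bias.dtype
    # Select the top-k most important keys for each query
    topk_values, topk_indices = torch.topk(
        attention_bias, window_size, dim=-1, largest=True, sorted=False
    )
    # Keep only top-k entries that are not already at the masked-out sentinel value
    valid_topk = (topk_values != min_dtype).to(dtype)
    attention_mask = torch.zeros_like(attention_bias)
    attention_mask = attention_mask.scatter(-1, topk_indices, valid_topk)
    # Positions outside the top-k window get the most negative representable bias
    attention_bias = attention_bias.masked_fill(attention_mask == 0.0, min_dtype)
    return attention_mask, attention_bias
```

The updated examples instead pass the bias, an existing mask, the shapes, window_size, and min_dtype to create_mask and receive the dynamic mask in a single call, which is the complexity reduction the commit message refers to.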
