README.md (37 changes: 18 additions & 19 deletions)

@@ -153,6 +153,7 @@ pip install . --no-build-isolation
```python
import torch
from flash_dmattn import flash_dmattn_func_auto
+from flash_dmattn.utils.mask import create_mask
import math

# Setup
@@ -167,22 +168,20 @@ query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)

-# Create mask and bias for sparse attention
-attention_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
-attention_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
+# Create bias for sparse attention
+attn_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)

> Copilot AI (Oct 23, 2025): The comment on line 171 says 'Create bias for sparse attention' but the code creates attn_mask, not attn_bias. Additionally, attn_bias is used on line 177 but never defined. The line should create attn_bias instead of attn_mask, and the subsequent create_mask call should initialize attn_mask from None or an appropriate default.
>
> Suggested change:
> - attn_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
> + attn_bias = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
> + attn_mask = None
-# Generate sparse mask based on bias
+# Generate dynamic mask based on bias
if seq_len > window_size:
-    # Select top-k most important keys for each query
-    topk_values, topk_indices = torch.topk(
-        attention_bias, window_size, dim=-1,
-        largest=True, sorted=False
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=attn_mask,
+        batch_size=batch_size,
+        query_len=seq_len,
+        key_len=seq_len,
+        window_size=window_size,
+        min_dtype=min_dtype,
    )
-    # Generate valid top-k mask
-    valid_topk = (topk_values != min_dtype).to(dtype)
-    attention_mask = torch.zeros_like(attention_bias, dtype=dtype, device=attention_bias.device)
-    attention_mask = attention_mask.scatter(-1, topk_indices, valid_topk)
-    attention_bias = attention_bias.masked_fill(attention_mask == 0.0, min_dtype)

# Select FDMA kernel
flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
@@ -192,8 +191,8 @@ output = flash_dmattn_func(
    query=query,
    key=key,
    value=value,
-    attn_mask=attention_mask,
-    attn_bias=attention_bias,
+    attn_mask=attn_mask,
+    attn_bias=attn_bias,
    is_causal=True,
    softmax_scale=1.0/math.sqrt(head_dim),
)
@@ -208,13 +207,13 @@ print(f"Output shape: {output.shape}") # [1, 256, 2, 64]
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
-attention_bias.requires_grad_(True)
+attn_bias.requires_grad_(True)

# Forward pass
output = flash_dmattn_func(
    query=query, key=key, value=value,
-    attn_mask=attention_mask,
-    attn_bias=attention_bias,
+    attn_mask=attn_mask,
+    attn_bias=attn_bias,
    is_causal=True,
    softmax_scale=1.0/math.sqrt(head_dim)
)
@@ -226,7 +225,7 @@ loss.backward()
print(f"Query gradient shape: {query.grad.shape}")
print(f"Key gradient shape: {key.grad.shape}")
print(f"Value gradient shape: {value.grad.shape}")
print(f"Bias gradient shape: {attention_bias.grad.shape}")
print(f"Bias gradient shape: {attn_bias.grad.shape}")
```


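For reference, the following is a minimal sketch of how the README example might read once the review comment above is addressed: `attn_bias` is defined up front, `attn_mask` starts as `None`, and `create_mask` derives the mask from the bias. The `create_mask` keyword arguments mirror the diff; the tensor shapes follow the printed output shape `[1, 256, 2, 64]`, while `num_kv_heads`, `window_size`, `dtype`, and `min_dtype` are illustrative assumptions rather than values taken from the README's setup block.

```python
import math
import torch
from flash_dmattn import flash_dmattn_func_auto
from flash_dmattn.utils.mask import create_mask

# Setup (shapes chosen to match the printed output shape; num_kv_heads,
# window_size, dtype, and min_dtype are assumptions, not taken from the README)
batch_size, seq_len, num_heads, num_kv_heads, head_dim = 1, 256, 2, 1, 64
window_size = 128
device = torch.device("cuda")
dtype = torch.bfloat16
min_dtype = torch.finfo(dtype).min  # fill value for masked-out positions

query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)

# Create bias for sparse attention; the mask starts as None and is derived from the bias
attn_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
attn_mask = None

# Generate dynamic mask based on bias (keyword arguments as shown in the diff)
if seq_len > window_size:
    attn_mask = create_mask(
        attention_bias=attn_bias,
        attention_mask=attn_mask,
        batch_size=batch_size,
        query_len=seq_len,
        key_len=seq_len,
        window_size=window_size,
        min_dtype=min_dtype,
    )

# Select the FDMA kernel and run the forward pass
flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
output = flash_dmattn_func(
    query=query,
    key=key,
    value=value,
    attn_mask=attn_mask,
    attn_bias=attn_bias,
    is_causal=True,
    softmax_scale=1.0 / math.sqrt(head_dim),
)
print(f"Output shape: {output.shape}")  # expected [1, 256, 2, 64]
```

Whether `create_mask` accepts `attention_mask=None` is an assumption drawn from the reviewer's suggestion; the actual utility may instead require a pre-filled mask tensor.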
README_zh.md (37 changes: 18 additions & 19 deletions)

@@ -153,6 +153,7 @@ pip install . --no-build-isolation
```python
import torch
from flash_dmattn import flash_dmattn_func_auto
+from flash_dmattn.utils.mask import create_mask
import math

# 设置
@@ -167,22 +168,20 @@ query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype)

-# 为稀疏注意力创建 mask 和 bias
-attention_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
-attention_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
+# 为稀疏注意力创建 bias
+attn_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
> Copilot AI (Oct 23, 2025): The variable attn_mask is passed to create_mask on line 178 but is never defined before this usage. Either initialize attn_mask before the conditional block or pass None if the utility function supports it.

-# 基于 bias 生成稀疏 mask
+# 基于 bias 生成动态 mask
if seq_len > window_size:
-    # 为每个查询选择 top-k 最重要的键
-    topk_values, topk_indices = torch.topk(
-        attention_bias, window_size, dim=-1,
-        largest=True, sorted=False
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=attn_mask,
+        batch_size=batch_size,
+        query_len=seq_len,
+        key_len=seq_len,
+        window_size=window_size,
+        min_dtype=min_dtype,
    )
-    # 生成有效的 top-k mask
-    valid_topk = (topk_values != min_dtype).to(dtype)
-    attention_mask = torch.zeros_like(attention_bias, dtype=dtype, device=attention_bias.device)
-    attention_mask = attention_mask.scatter(-1, topk_indices, valid_topk)
-    attention_bias = attention_bias.masked_fill(attention_mask == 0.0, min_dtype)

# 选择 FDMA 内核
flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
@@ -192,8 +191,8 @@ output = flash_dmattn_func(
    query=query,
    key=key,
    value=value,
-    attn_mask=attention_mask,
-    attn_bias=attention_bias,
+    attn_mask=attn_mask,
+    attn_bias=attn_bias,
    is_causal=True,
    softmax_scale=1.0/math.sqrt(head_dim),
)
@@ -208,13 +207,13 @@ print(f"输出形状: {output.shape}") # [1, 256, 2, 64]
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
-attention_bias.requires_grad_(True)
+attn_bias.requires_grad_(True)

# 前向传播
output = flash_dmattn_func(
    query=query, key=key, value=value,
-    attn_mask=attention_mask,
-    attn_bias=attention_bias,
+    attn_mask=attn_mask,
+    attn_bias=attn_bias,
    is_causal=True,
    softmax_scale=1.0/math.sqrt(head_dim)
)
@@ -226,7 +225,7 @@ loss.backward()
print(f"Query 梯度形状: {query.grad.shape}")
print(f"Key 梯度形状: {key.grad.shape}")
print(f"Value 梯度形状: {value.grad.shape}")
print(f"Bias 梯度形状: {attention_bias.grad.shape}")
print(f"Bias 梯度形状: {attn_bias.grad.shape}")
```


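The README_zh.md snippet has the mirror-image problem: it defines `attn_bias` but never `attn_mask`. A minimal sketch of the fix the reviewer suggests, assuming `create_mask` accepts `attention_mask=None`, is to initialise the mask before the conditional block:

```python
# Create bias for sparse attention; initialise the mask to None before the conditional block
attn_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype)
attn_mask = None  # assumption: create_mask treats None as "no pre-existing mask"

# Generate dynamic mask based on bias
if seq_len > window_size:
    attn_mask = create_mask(
        attention_bias=attn_bias,
        attention_mask=attn_mask,
        batch_size=batch_size,
        query_len=seq_len,
        key_len=seq_len,
        window_size=window_size,
        min_dtype=min_dtype,
    )
```

The rest of that example (the `flash_dmattn_func` call and the gradient check) already uses the `attn_mask`/`attn_bias` names, so no further changes should be needed there.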