-
Notifications
You must be signed in to change notification settings - Fork 40
Update documentation to use mask utility in examples #198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -153,6 +153,7 @@ pip install . --no-build-isolation | |
| ```python | ||
| import torch | ||
| from flash_dmattn import flash_dmattn_func_auto | ||
| from flash_dmattn.utils.mask import create_mask | ||
| import math | ||
|
|
||
| # 设置 | ||
|
|
@@ -167,22 +168,20 @@ query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dty | |
| key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype) | ||
| value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype) | ||
|
|
||
| # 为稀疏注意力创建 mask 和 bias | ||
| attention_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype) | ||
| attention_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype) | ||
| # 为稀疏注意力创建 bias | ||
| attn_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype) | ||
|
||
|
|
||
| # 基于 bias 生成稀疏 mask | ||
| # 基于 bias 生成动态 mask | ||
| if seq_len > window_size: | ||
| # 为每个查询选择 top-k 最重要的键 | ||
| topk_values, topk_indices = torch.topk( | ||
| attention_bias, window_size, dim=-1, | ||
| largest=True, sorted=False | ||
| attn_mask = create_mask( | ||
| attention_bias=attn_bias, | ||
| attention_mask=attn_mask, | ||
| batch_size=batch_size, | ||
| query_len=seq_len, | ||
| key_len=seq_len, | ||
| window_size=window_size, | ||
| min_dtype=min_dtype, | ||
| ) | ||
| # 生成有效的 top-k mask | ||
| valid_topk = (topk_values != min_dtype).to(dtype) | ||
| attention_mask = torch.zeros_like(attention_bias, dtype=dtype, device=attention_bias.device) | ||
| attention_mask = attention_mask.scatter(-1, topk_indices, valid_topk) | ||
| attention_bias = attention_bias.masked_fill(attention_mask == 0.0, min_dtype) | ||
|
|
||
| # 选择 FDMA 内核 | ||
| flash_dmattn_func = flash_dmattn_func_auto(backend="cuda") | ||
|
|
@@ -192,8 +191,8 @@ output = flash_dmattn_func( | |
| query=query, | ||
| key=key, | ||
| value=value, | ||
| attn_mask=attention_mask, | ||
| attn_bias=attention_bias, | ||
| attn_mask=attn_mask, | ||
| attn_bias=attn_bias, | ||
| is_causal=True, | ||
| softmax_scale=1.0/math.sqrt(head_dim), | ||
| ) | ||
|
|
@@ -208,13 +207,13 @@ print(f"输出形状: {output.shape}") # [1, 256, 2, 64] | |
| query.requires_grad_(True) | ||
| key.requires_grad_(True) | ||
| value.requires_grad_(True) | ||
| attention_bias.requires_grad_(True) | ||
| attn_bias.requires_grad_(True) | ||
|
|
||
| # 前向传播 | ||
| output = flash_dmattn_func( | ||
| query=query, key=key, value=value, | ||
| attn_mask=attention_mask, | ||
| attn_bias=attention_bias, | ||
| attn_mask=attn_mask, | ||
| attn_bias=attn_bias, | ||
| is_causal=True, | ||
| softmax_scale=1.0/math.sqrt(head_dim) | ||
| ) | ||
|
|
@@ -226,7 +225,7 @@ loss.backward() | |
| print(f"Query 梯度形状: {query.grad.shape}") | ||
| print(f"Key 梯度形状: {key.grad.shape}") | ||
| print(f"Value 梯度形状: {value.grad.shape}") | ||
| print(f"Bias 梯度形状: {attention_bias.grad.shape}") | ||
| print(f"Bias 梯度形状: {attn_bias.grad.shape}") | ||
| ``` | ||
|
|
||
|
|
||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment on line 171 says 'Create bias for sparse attention' but the code creates
attn_mask, notattn_bias. Additionally,attn_biasis used on line 177 but never defined. The line should createattn_biasinstead ofattn_mask, and the subsequentcreate_maskcall should initializeattn_maskfromNoneor an appropriate default.