Commit 14a7f1c

Adds helper for top-k attention indices
Introduces reusable top-k extraction on the bias tensor to simplify downstream mask logic.
1 parent c4401e0 commit 14a7f1c

1 file changed: +27 −0 lines changed


flash_dmattn/utils/mask.py

Lines changed: 27 additions & 0 deletions
@@ -17,6 +17,33 @@
 import torch
 
 
+def topk_indices(
+    attention_bias: torch.Tensor,
+    window_size: int,
+    **kwargs,
+) -> torch.Tensor:
+    r"""
+    This function generates top-k indices based on the attention bias.
+
+    Args:
+        attention_bias (torch.Tensor): The attention bias tensor of shape
+            (batch_size, num_kv_heads, key_len).
+        window_size (int): The number of top elements to consider for the mask.
+        **kwargs: Additional keyword arguments.
+
+    Returns:
+        topk_indices (Tensor): The top-k indices tensor of shape
+            (batch_size, num_kv_heads, window_size).
+    """
+    attention_bias = attention_bias.detach()
+    topk_indices = torch.topk(
+        attention_bias,
+        window_size, dim=-1, largest=True, sorted=False
+    ).indices
+    topk_indices = torch.sort(topk_indices, dim=-1).values
+    return topk_indices
+
+
 def dynamic_mask(
     attention_bias: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
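
A minimal usage sketch of the new helper (the import path follows the file touched in this commit; the tensor sizes are illustrative, not taken from the source):

import torch

from flash_dmattn.utils.mask import topk_indices

# Illustrative shapes: batch_size=2, num_kv_heads=4, key_len=128
attention_bias = torch.randn(2, 4, 128)

# Select the 32 highest-bias key positions per head; the helper returns their
# indices sorted in ascending order along the key dimension.
indices = topk_indices(attention_bias, window_size=32)
print(indices.shape)  # torch.Size([2, 4, 32])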
