1 file changed: 27 additions, 0 deletions

import torch

+def topk_indices(
+    attention_bias: torch.Tensor,
+    window_size: int,
+    **kwargs,
+) -> torch.Tensor:
+    r"""
+    Generate top-k indices based on the attention bias.
+
+    Args:
+        attention_bias (torch.Tensor): The attention bias tensor of shape
+            (batch_size, num_kv_heads, key_len).
+        window_size (int): The number of top elements to consider for the mask.
+        **kwargs: Additional keyword arguments.
+
+    Returns:
+        topk_indices (torch.Tensor): The top-k indices tensor of shape
+            (batch_size, num_kv_heads, window_size).
+    """
+    # Detach so the top-k selection does not participate in autograd.
+    attention_bias = attention_bias.detach()
+    topk_indices = torch.topk(
+        attention_bias, window_size, dim=-1, largest=True, sorted=False
+    ).indices
+    # Re-sort the selected indices so positions stay in ascending key order.
+    topk_indices = torch.sort(topk_indices, dim=-1).values
+    return topk_indices
+
+
def dynamic_mask(
    attention_bias: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
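
For reference, here is a minimal usage sketch of the new helper on a dummy attention bias. The tensor shapes and the trailing torch.gather step are illustrative assumptions, not part of this diff.

import torch

# Illustrative shapes (assumptions, not taken from the diff).
batch_size, num_kv_heads, key_len, window_size = 2, 4, 16, 8

# Dummy attention bias; in practice this comes from the model's bias computation.
attention_bias = torch.randn(batch_size, num_kv_heads, key_len)

indices = topk_indices(attention_bias, window_size)
print(indices.shape)  # torch.Size([2, 4, 8]); indices are sorted along the last dim

# Hypothetical follow-up: gather the bias values at the selected key positions.
selected_bias = torch.gather(attention_bias, dim=-1, index=indices)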