 import torch


-def dynamic_mask(
+def topk_mask(
     attention_bias: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
     window_size: int,
     min_dtype: float,
     block_size: Optional[int] = None,
+    **kwargs,
 ):
     r"""
     This function generates a dynamic mask based on the top-k attention bias.
@@ -45,10 +46,10 @@ def dynamic_mask(
         if int(block_size) != block_size or block_size <= 0:
             raise ValueError(f"block_size must be a positive integer, got {block_size}.")
         block_size = int(block_size)
-
+    attention_bias = attention_bias.detach()
     attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
     topk_values, topk_indices = torch.topk(
-        attention_bias.detach(),
+        attention_bias.to(torch.float),
         window_size, dim=-1, largest=True, sorted=False
     )
     attention_mask = torch.zeros_like(
@@ -77,6 +78,61 @@ def dynamic_mask(
     return attention_mask


+def relu_mask(
+    attention_bias: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    min_dtype: float,
+    block_size: Optional[int] = None,
+    **kwargs
+):
+    r"""
+    This function generates a dynamic mask based on the ReLU of the attention bias.
+
+    Args:
+        attention_bias (torch.Tensor): The attention bias tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+        min_dtype (float): The minimum value to use for masking.
+        block_size (Optional[int]): Optional size of aggregation blocks to smooth the
+            resulting mask along the key dimension.
+
+    Returns:
+        attention_mask (Tensor): The attention mask tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+    """
+    if block_size is not None:
+        if int(block_size) != block_size or block_size <= 0:
+            raise ValueError(f"block_size must be a positive integer, got {block_size}.")
+        block_size = int(block_size)
+
+    attention_bias = attention_bias.detach()
+    attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
+    attention_mask = attention_bias > 0
+
+    if block_size is not None and block_size > 1:
+        key_len = attention_mask.shape[-1]
+        full_len = (key_len // block_size) * block_size
+
+        if full_len:
+            block_view = attention_mask[..., :full_len]
+            block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
+            blocks = block_view.view(*block_shape)
+            block_counts = blocks.sum(dim=-1).to(torch.int32)
+            block_keep = (block_counts * 2) > block_size
+            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
+
+        if key_len > full_len:
+            tail_slice = attention_mask[..., full_len:]
+            tail_len = tail_slice.shape[-1]
+            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
+            tail_keep = (tail_counts * 2) > tail_len
+            tail_slice.copy_(tail_keep.expand_as(tail_slice))
+
+    return attention_mask
+
+
+
 def create_mask(
     attention_bias: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
@@ -86,6 +142,7 @@ def create_mask(
     window_size: int,
     min_dtype: float,
     block_size: Optional[int] = None,
+    type: str = "topk",
 ) -> torch.Tensor:
     r"""
     This function creates a mask tensor for Flash Dynamic Mask Attention.
@@ -103,6 +160,7 @@ def create_mask(
         window_size (int): The number of top elements to consider for the attention mask.
         min_dtype (float): The minimum value to use for masking.
        block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.
+        type (str): The type of mask to create. Options are "topk" and "relu".

     Returns:
         attention (Tensor): The attention mask tensor of shape
@@ -133,12 +191,23 @@ def create_mask(
     )

     # Generate dynamic mask based on attention_bias and attention_mask
-    attention_mask = dynamic_mask(
-        attention_bias,
-        attention_mask,
-        window_size,
-        min_dtype,
-        block_size=block_size,
-    )
+    if type == "topk":
+        attention_mask = topk_mask(
+            attention_bias=attention_bias,
+            attention_mask=attention_mask,
+            window_size=window_size,
+            min_dtype=min_dtype,
+            block_size=block_size,
+        )
+    elif type == "relu":
+        attention_mask = relu_mask(
+            attention_bias=attention_bias,
+            attention_mask=attention_mask,
+            window_size=window_size,
+            min_dtype=min_dtype,
+            block_size=block_size,
+        )
+    else:
+        raise ValueError(f"Unsupported mask type: {type}. Supported types are 'topk' and 'relu'.")

     return attention_mask
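
A minimal usage sketch of the two helpers touched by this diff. The import of topk_mask and relu_mask is omitted because the module path is not visible in these hunks; the tensor shapes, bias values, and the torch.finfo-based min_dtype below are illustrative assumptions, not values taken from the PR.

import torch

# Illustrative shapes: (batch, heads, query_len, key_len) = (1, 2, 4, 8).
attention_bias = torch.randn(1, 2, 4, 8)
attention_mask = torch.ones(1, 2, 4, 8, dtype=torch.bool)
min_dtype = torch.finfo(attention_bias.dtype).min  # large negative fill value

# Top-k variant: keep the `window_size` largest bias entries along the key dimension.
topk = topk_mask(
    attention_bias=attention_bias,
    attention_mask=attention_mask,
    window_size=3,
    min_dtype=min_dtype,
    block_size=4,
)

# ReLU variant: keep strictly positive bias entries, then smooth the result with a
# per-block majority vote of width `block_size` (4 here, so two blocks per row).
relu = relu_mask(
    attention_bias=attention_bias,
    attention_mask=attention_mask,
    min_dtype=min_dtype,
    block_size=4,
)

print(topk.shape, relu.shape)  # expected to match the bias shape per the docstrings above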