@@ -163,14 +163,12 @@ def relu_mask(
163163
164164def create_mask (
165165 attention_bias : torch .Tensor ,
166- attention_mask : Optional [torch .Tensor ],
167- batch_size : int ,
168166 query_len : int ,
169- key_len : int ,
170- window_size : Optional [int ],
171- min_dtype : Optional [float ],
172- block_size : Optional [int ],
173167 type : str = "topk" ,
168+ attention_mask : Optional [torch .Tensor ] = None ,
169+ window_size : Optional [int ] = None ,
170+ min_dtype : Optional [float ] = None ,
171+ block_size : Optional [int ] = None ,
174172) -> torch .Tensor :
175173 r"""
176174 This function creates a mask tensor for Flash Sparse Attention.
@@ -180,15 +178,13 @@ def create_mask(
180178 Args:
181179 attention_bias (torch.Tensor): The attention bias tensor of shape
182180 ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
181+ query_len (int): The sequence length of the query.
182+ type (str): The type of mask to create. Options are "topk" and "relu".
183183 attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
184184 (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
185- batch_size (int): The batch size.
186- query_len (int): The sequence length of the query.
187- key_len (int): The sequence length of the key.
188185 window_size (Optional[int]): The number of top elements to consider for the attention mask.
189186 min_dtype (Optional[float]): The minimum value to use for masking.
190187 block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.
191- type (str): The type of mask to create. Options are "topk" and "relu".
192188
193189 Returns:
194190 attention (Tensor): The attention mask tensor of shape
@@ -200,6 +196,7 @@ def create_mask(
200196
201197 # If attention_mask is of shape (batch_size, seq_len), reshape it to (batch_size, 1, 1, key_len)
202198 if attention_mask is not None and attention_mask .dim () == 2 :
199+ batch_size , key_len = attention_bias .shape [0 ], attention_bias .shape [- 1 ]
203200 if attention_mask .shape [- 1 ] == key_len :
204201 attention_mask = attention_mask .view (batch_size , 1 , 1 , key_len )
205202 elif attention_mask .shape [- 1 ] == query_len :
0 commit comments