
Commit 623b75d

Remove redundant dimension checks for attention_mask and attention_bias in _flash_dynamic_mask_attention_forward
1 parent 42e118d commit 623b75d

File tree

1 file changed: 0 additions & 5 deletions


flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py

Lines changed: 0 additions & 5 deletions
@@ -626,11 +626,6 @@ def _flash_dynamic_mask_attention_forward(
     ):
         min_dtype = torch.finfo(query_states.dtype).min
         if attention_mask is not None:
-            if attention_mask.dim() == 4 and attention_bias.dim() == 3:
-                attention_bias = attention_bias.unsqueeze(-2).expand(-1, -1, query_length, -1)
-            if attention_mask.dim() == 3 and attention_bias.dim() == 4:
-                attention_mask = attention_mask.unsqueeze(-2).expand(-1, -1, query_length, -1)
-
             topk_values, topk_indices = torch.topk(
                 attention_bias.masked_fill(~attention_mask, min_dtype).detach(),
                 window_size, dim=-1, largest=True, sorted=False
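
For context, a minimal sketch of the top-k selection that remains after this commit. The shapes and the window_size value below are illustrative assumptions, not values taken from the repository; the sketch only shows that once attention_mask and attention_bias arrive with matching (batch, heads, query_len, key_len) shapes, the masked top-k works without any dimension fix-up.

# Minimal sketch (assumed shapes and window_size, not from the repository).
import torch

batch, heads, query_length, key_length, window_size = 1, 2, 4, 16, 8

# Boolean keep-mask and float bias with matching shapes (assumption stated above).
attention_mask = torch.rand(batch, heads, query_length, key_length) > 0.2
attention_bias = torch.randn(batch, heads, query_length, key_length)

# Masked-out positions are filled with the dtype minimum so top-k never selects them.
min_dtype = torch.finfo(attention_bias.dtype).min
topk_values, topk_indices = torch.topk(
    attention_bias.masked_fill(~attention_mask, min_dtype).detach(),
    window_size, dim=-1, largest=True, sorted=False,
)

print(topk_indices.shape)  # torch.Size([1, 2, 4, 8])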
