Commit 9b40177
Optimizes attention mask handling with boolean dtype
Replaces float-based attention mask operations with boolean dtype for improved memory efficiency and cleaner logic. Removes unnecessary dtype conversion and simplifies mask creation by using boolean tensors directly instead of converting comparison results to float values.
1 parent a0d6ee5 commit 9b40177

File tree

1 file changed: +3 -5 lines changed
flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py

Lines changed: 3 additions & 5 deletions
@@ -93,17 +93,15 @@ def _flash_dynamic_mask_attention_forward(
             ~attention_mask,
             min_dtype
         )
-        attention_mask = attention_mask.to(dtype)

     if keep_window_size is not None:
         if key_length > keep_window_size:
             topk_values, topk_indices = torch.topk(
                 attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
             )
-            valid_topk = (topk_values != min_dtype).to(dtype)
-            attention_mask = torch.zeros_like(attention_bias, dtype=dtype, device=attention_bias.device)
-            attention_mask = attention_mask.scatter(-1, topk_indices, valid_topk)
-            attention_bias = attention_bias.masked_fill(attention_mask == 0.0, min_dtype)
+            attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device)
+            attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
+            attention_bias = attention_bias.masked_fill(attention_mask == False, min_dtype)

     out = flash_fn(
         query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal
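
For context, a minimal standalone sketch of the new boolean top-k masking step. The shapes, dtype, and surrounding setup below are illustrative assumptions for the sketch, not code taken from the repository:

import torch

# Illustrative shapes: (batch, heads, query_len, key_len); values are assumptions for the sketch
batch, heads, q_len, k_len = 1, 2, 4, 16
keep_window_size = 8
dtype = torch.float16
min_dtype = torch.finfo(dtype).min

# Additive attention bias; positions that are already masked out carry min_dtype
attention_bias = torch.randn(batch, heads, q_len, k_len, dtype=dtype)

if k_len > keep_window_size:
    # Keep only the keep_window_size largest bias values per query position
    topk_values, topk_indices = torch.topk(
        attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
    )
    # Boolean mask: True only at selected positions that are not already masked out
    attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool)
    attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
    # Push everything outside the kept window down to min_dtype
    attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype)

Compared with the previous float path, the mask stays torch.bool end to end (1 byte per element instead of 2 for float16), and the intermediate valid_topk float tensor as well as the attention_mask.to(dtype) conversion are no longer needed.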
