Commit 111e224

Adds ReLU mask to attention utils
Introduces selectable masking strategies to support both top-k and ReLU flows, enabling experimentation with bias-driven sparsity. Normalizes the top-k path to detach the attention bias and cast it to float before selection, and rejects unsupported mask types to avoid silent misuse.
1 parent 1dcc395 commit 111e224
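
For orientation, here is a minimal usage sketch of the two strategy functions this commit leaves in flash_dmattn/utils/mask.py, the renamed topk_mask and the new relu_mask shown in the diff below. The import path, tensor shapes, and dtype are illustrative assumptions, not taken from the repository's tests:

import torch
from flash_dmattn.utils.mask import topk_mask, relu_mask  # assumed import path

# Hypothetical bias and validity mask: batch=2, heads=4, query_len=1, key_len=128.
bias = torch.randn(2, 4, 1, 128, dtype=torch.bfloat16)
valid = torch.ones(2, 4, 1, 128, dtype=torch.bool)
min_dtype = torch.finfo(bias.dtype).min

# Top-k flow: keep the window_size largest bias entries per query.
topk_keep = topk_mask(bias, valid, window_size=32, min_dtype=min_dtype, block_size=16)

# ReLU flow: keep every position whose (masked) bias is strictly positive.
relu_keep = relu_mask(bias, valid, min_dtype=min_dtype, block_size=16)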

File tree

1 file changed: +79 -10 lines changed

flash_dmattn/utils/mask.py

Lines changed: 79 additions & 10 deletions
@@ -17,12 +17,13 @@
 import torch
 
 
-def dynamic_mask(
+def topk_mask(
     attention_bias: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
     window_size: int,
     min_dtype: float,
     block_size: Optional[int] = None,
+    **kwargs,
 ):
     r"""
     This function generates a dynamic mask based on the top-k attention bias.
@@ -45,10 +46,10 @@ def dynamic_mask(
         if int(block_size) != block_size or block_size <= 0:
             raise ValueError(f"block_size must be a positive integer, got {block_size}.")
         block_size = int(block_size)
-
+    attention_bias = attention_bias.detach()
     attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
     topk_values, topk_indices = torch.topk(
-        attention_bias.detach(),
+        attention_bias.to(torch.float),
         window_size, dim=-1, largest=True, sorted=False
     )
     attention_mask = torch.zeros_like(
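
The two changes in this hunk, detaching the bias once before masking and casting to float only for the top-k selection, can be seen in isolation in the sketch below. The scatter-based mask construction at the end is an assumption about what the truncated torch.zeros_like(...) context goes on to do, not a verbatim copy of the file:

import torch

# Hypothetical half-precision bias for one head over 8 key positions.
bias = torch.randn(1, 1, 1, 8, dtype=torch.float16, requires_grad=True)

# detach(): keep mask construction out of the autograd graph.
# to(torch.float): run the selection in float32 regardless of bias precision.
scores = bias.detach()
values, indices = torch.topk(scores.to(torch.float), 4, dim=-1, largest=True, sorted=False)

# One way to turn the selected indices into a boolean keep-mask.
keep = torch.zeros_like(scores, dtype=torch.bool).scatter(-1, indices, True)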
@@ -77,6 +78,61 @@
     return attention_mask
 
 
+def relu_mask(
+    attention_bias: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    min_dtype: float,
+    block_size: Optional[int] = None,
+    **kwargs
+):
+    r"""
+    This function generates a dynamic mask based on the ReLU of attention bias.
+
+    Args:
+        attention_bias (torch.Tensor): The attention bias tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+        min_dtype (float): The minimum value to use for masking.
+        block_size (Optional[int]): Optional size of aggregation blocks to smooth the
+            resulting mask along the key dimension.
+
+    Returns:
+        attention_mask (Tensor): The attention mask tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+    """
+    if block_size is not None:
+        if int(block_size) != block_size or block_size <= 0:
+            raise ValueError(f"block_size must be a positive integer, got {block_size}.")
+        block_size = int(block_size)
+
+    attention_bias = attention_bias.detach()
+    attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
+    attention_mask = attention_bias > 0
+
+    if block_size is not None and block_size > 1:
+        key_len = attention_mask.shape[-1]
+        full_len = (key_len // block_size) * block_size
+
+        if full_len:
+            block_view = attention_mask[..., :full_len]
+            block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
+            blocks = block_view.view(*block_shape)
+            block_counts = blocks.sum(dim=-1).to(torch.int32)
+            block_keep = (block_counts * 2) > block_size
+            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
+
+        if key_len > full_len:
+            tail_slice = attention_mask[..., full_len:]
+            tail_len = tail_slice.shape[-1]
+            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
+            tail_keep = (tail_counts * 2) > tail_len
+            tail_slice.copy_(tail_keep.expand_as(tail_slice))
+
+    return attention_mask
+
+
 def create_mask(
     attention_bias: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
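
The block aggregation in the new relu_mask is a per-block majority vote: a block of block_size keys stays active only if strictly more than half of its positions are active, and the same rule is applied to any ragged tail. A standalone toy replication of that logic, with values chosen purely for illustration:

import torch

# Toy 1-D key mask: 10 keys, block_size 4 -> two full blocks plus a tail of 2.
active = torch.tensor([True, True, True, False,    # 3/4 active -> block kept
                       False, True, False, False,  # 1/4 active -> block dropped
                       True, False])               # 1/2 active -> tail dropped (no strict majority)
block_size = 4
key_len = active.shape[-1]
full_len = (key_len // block_size) * block_size

blocks = active[:full_len].view(-1, block_size)
keep = (blocks.sum(dim=-1) * 2) > block_size        # strict per-block majority
blocks.copy_(keep.unsqueeze(-1).expand_as(blocks))  # write the vote back in place

tail = active[full_len:]
tail_keep = (tail.sum(dim=-1, keepdim=True) * 2) > tail.shape[-1]
tail.copy_(tail_keep.expand_as(tail))

print(active)
# tensor([ True,  True,  True,  True, False, False, False, False, False, False])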
@@ -86,6 +142,7 @@ def create_mask(
     window_size: int,
     min_dtype: float,
     block_size: Optional[int] = None,
+    type: str = "topk",
 ) -> torch.Tensor:
     r"""
     This function creates a mask tensor for Flash Dynamic Mask Attention.
@@ -103,6 +160,7 @@ def create_mask(
         window_size (int): The number of top elements to consider for the attention mask.
         min_dtype (float): The minimum value to use for masking.
         block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.
+        type (str): The type of mask to create. Options are "topk" and "relu".
 
     Returns:
         attention (Tensor): The attention mask tensor of shape
@@ -133,12 +191,23 @@ def create_mask(
     )
 
     # Generate dynamic mask based on attention_bias and attention_mask
-    attention_mask = dynamic_mask(
-        attention_bias,
-        attention_mask,
-        window_size,
-        min_dtype,
-        block_size=block_size,
-    )
+    if type == "topk":
+        attention_mask = topk_mask(
+            attention_bias=attention_bias,
+            attention_mask=attention_mask,
+            window_size=window_size,
+            min_dtype=min_dtype,
+            block_size=block_size,
+        )
+    elif type == "relu":
+        attention_mask = relu_mask(
+            attention_bias=attention_bias,
+            attention_mask=attention_mask,
+            window_size=window_size,
+            min_dtype=min_dtype,
+            block_size=block_size,
+        )
+    else:
+        raise ValueError(f"Unsupported mask type: {type}. Supported types are 'topk' and 'relu'.")
 
     return attention_mask
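
A side effect of both strategy functions accepting **kwargs is that create_mask can hand the same keyword set to either branch even though relu_mask has no use for window_size; the extra argument is simply swallowed instead of raising a TypeError. A minimal illustration of that calling pattern, with a placeholder bias tensor and an assumed import path:

import torch
from flash_dmattn.utils.mask import relu_mask  # assumed import path

bias = torch.randn(1, 1, 1, 16)

# window_size is irrelevant to the ReLU strategy; **kwargs lets the shared
# call site in create_mask pass it anyway.
keep = relu_mask(
    attention_bias=bias,
    attention_mask=None,
    window_size=8,
    min_dtype=torch.finfo(bias.dtype).min,
    block_size=None,
)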
