
Commit e3bcf48

Merge pull request #205 from SmallDoges/add-mask-type
Refactor attention block smoothing for consistency
2 parents 65c06cd + 109c7ad

File tree

1 file changed: +41 -41 lines changed

flash_dmattn/utils/mask.py

Lines changed: 41 additions & 41 deletions
@@ -17,6 +17,35 @@
 import torch


+def block_smooth(
+    attention_mask: torch.Tensor,
+    key_len: int,
+    block_size: int,
+):
+    if block_size <= 0:
+        raise ValueError(f"block_size must be a positive integer, got {block_size}.")
+
+    if block_size > 1:
+        full_len = (key_len // block_size) * block_size
+
+        if full_len:
+            block_view = attention_mask[..., :full_len]
+            block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
+            blocks = block_view.view(*block_shape)
+            block_counts = blocks.sum(dim=-1).to(torch.int64)
+            block_keep = (block_counts * 2) > block_size
+            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
+
+        if key_len > full_len:
+            tail_slice = attention_mask[..., full_len:]
+            tail_len = tail_slice.shape[-1]
+            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int64)
+            tail_keep = (tail_counts * 2) > tail_len
+            tail_slice.copy_(tail_keep.expand_as(tail_slice))
+
+    return attention_mask
+
+
 def topk_mask(
     attention_bias: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
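
For readers skimming the diff: block_smooth applies a per-block majority vote to a boolean attention mask. Each full block of block_size key positions is kept (set entirely True) only when strictly more than half of its positions are active, and a ragged tail shorter than block_size is voted on separately against its own length. A minimal usage sketch, assuming the file lands importable as flash_dmattn.utils.mask:

    import torch
    from flash_dmattn.utils.mask import block_smooth  # import path assumed from this commit's file

    # 10 key positions, block_size=4:
    #   block [T, T, F, T] -> 3 of 4 active, 3*2 > 4 -> kept
    #   block [F, F, T, F] -> 1 of 4 active, 1*2 > 4 -> dropped
    #   tail  [T, F]       -> 1 of 2 active, 1*2 > 2 -> dropped (exact ties are dropped)
    mask = torch.tensor([[True, True, False, True,
                          False, False, True, False,
                          True, False]])
    smoothed = block_smooth(mask, key_len=mask.shape[-1], block_size=4)
    print(smoothed.int())  # tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0]])
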
@@ -42,10 +71,7 @@ def topk_mask
         attention_mask (Tensor): The attention mask tensor of shape
             ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
     """
-    if block_size is not None:
-        if int(block_size) != block_size or block_size <= 0:
-            raise ValueError(f"block_size must be a positive integer, got {block_size}.")
-        block_size = int(block_size)
+
     attention_bias = attention_bias.detach()
     attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
     topk_values, topk_indices = torch.topk(
@@ -58,22 +84,11 @@ def topk_mask

     if block_size is not None and block_size > 1:
         key_len = attention_mask.shape[-1]
-        full_len = (key_len // block_size) * block_size
-
-        if full_len:
-            block_view = attention_mask[..., :full_len]
-            block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
-            blocks = block_view.view(*block_shape)
-            block_counts = blocks.sum(dim=-1).to(torch.int32)
-            block_keep = (block_counts * 2) > block_size
-            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
-
-        if key_len > full_len:
-            tail_slice = attention_mask[..., full_len:]
-            tail_len = tail_slice.shape[-1]
-            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
-            tail_keep = (tail_counts * 2) > tail_len
-            tail_slice.copy_(tail_keep.expand_as(tail_slice))
+        attention_mask = block_smooth(
+            attention_mask=attention_mask,
+            key_len=key_len,
+            block_size=block_size
+        )

     return attention_mask

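One detail worth noting at this call site: block_smooth writes its result through views of the input (Tensor.copy_ on a reshaped slice), so the mask passed in is mutated in place and the function also returns it. The reassignment above is therefore for readability rather than necessity; callers that need the original mask should clone it first. A small sketch, under the same import assumption as above:

    import torch
    from flash_dmattn.utils.mask import block_smooth  # import path assumed from this commit's file

    mask = torch.tensor([[True, True, True, False]])
    out = block_smooth(mask, key_len=4, block_size=4)
    print(out is mask)  # True: the helper returns its own (mutated) input
    print(mask.int())   # tensor([[1, 1, 1, 1]]): 3 of 4 active -> whole block kept
    # To keep the original, pass a copy instead: block_smooth(mask.clone(), ...)
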

@@ -101,33 +116,18 @@ def relu_mask(
         attention_mask (Tensor): The attention mask tensor of shape
             ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
     """
-    if block_size is not None:
-        if int(block_size) != block_size or block_size <= 0:
-            raise ValueError(f"block_size must be a positive integer, got {block_size}.")
-        block_size = int(block_size)
-
+
     attention_bias = attention_bias.detach()
     attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
     attention_mask = attention_bias > 0

     if block_size is not None and block_size > 1:
         key_len = attention_mask.shape[-1]
-        full_len = (key_len // block_size) * block_size
-
-        if full_len:
-            block_view = attention_mask[..., :full_len]
-            block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
-            blocks = block_view.view(*block_shape)
-            block_counts = blocks.sum(dim=-1).to(torch.int32)
-            block_keep = (block_counts * 2) > block_size
-            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
-
-        if key_len > full_len:
-            tail_slice = attention_mask[..., full_len:]
-            tail_len = tail_slice.shape[-1]
-            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
-            tail_keep = (tail_counts * 2) > tail_len
-            tail_slice.copy_(tail_keep.expand_as(tail_slice))
+        attention_mask = block_smooth(
+            attention_mask=attention_mask,
+            key_len=key_len,
+            block_size=block_size
+        )

     return attention_mask
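Functionally, the refactor is a pure extraction, with one incidental change: the block and tail counts are now accumulated in int64 instead of int32. Since those counts only feed a comparison against block_size (or the tail length), the resulting masks are identical for any practical sequence length. A hedged equivalence check against the removed inline logic, again assuming the import path from this commit's file:

    import torch
    from flash_dmattn.utils.mask import block_smooth  # import path assumed from this commit's file

    def inline_smooth(mask, key_len, block_size):
        # Pre-refactor inline logic (int32 counts), reproduced from the lines removed above.
        full_len = (key_len // block_size) * block_size
        if full_len:
            block_view = mask[..., :full_len]
            blocks = block_view.view(*block_view.shape[:-1], full_len // block_size, block_size)
            block_counts = blocks.sum(dim=-1).to(torch.int32)
            blocks.copy_(((block_counts * 2) > block_size).unsqueeze(-1).expand_as(blocks))
        if key_len > full_len:
            tail_slice = mask[..., full_len:]
            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
            tail_slice.copy_(((tail_counts * 2) > tail_slice.shape[-1]).expand_as(tail_slice))
        return mask

    torch.manual_seed(0)
    for block_size in (2, 3, 8, 64):
        m = torch.rand(2, 1, 4, 37) > 0.5  # random boolean mask, ragged for most block sizes
        assert torch.equal(
            block_smooth(m.clone(), key_len=37, block_size=block_size),
            inline_smooth(m.clone(), key_len=37, block_size=block_size),
        )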
