 import torch
 
 
+def block_smooth(
+    attention_mask: torch.Tensor,
+    key_len: int,
+    block_size: int,
+):
+    """Smooth a boolean attention mask block-wise along the key dimension.
+
+    Every full block of ``block_size`` keys becomes all True when strictly more
+    than half of its entries are True, and all False otherwise; a shorter tail
+    block at the end is voted on the same way. The mask is modified in place
+    and also returned.
+    """
+    if block_size <= 0:
+        raise ValueError(f"block_size must be a positive integer, got {block_size}.")
+
+    if block_size > 1:
+        full_len = (key_len // block_size) * block_size
+
+        if full_len:
+            # Majority vote within each complete block of `block_size` keys.
+            block_view = attention_mask[..., :full_len]
+            block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
+            blocks = block_view.view(*block_shape)
+            block_counts = blocks.sum(dim=-1).to(torch.int64)
+            block_keep = (block_counts * 2) > block_size
+            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
+
+        if key_len > full_len:
+            # Majority vote over the trailing partial block.
+            tail_slice = attention_mask[..., full_len:]
+            tail_len = tail_slice.shape[-1]
+            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int64)
+            tail_keep = (tail_counts * 2) > tail_len
+            tail_slice.copy_(tail_keep.expand_as(tail_slice))
+
+    return attention_mask
+
+
 def topk_mask(
     attention_bias: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
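A minimal sketch of the new helper's strict-majority behaviour on a toy mask, assuming block_size=4 (illustration only, not taken from the diff):

mask = torch.tensor([[True, True, True, False,    # 3/4 True  -> block kept
                      True, False, False, False,  # 1/4 True  -> block dropped
                      True, True]])               # 2/2 tail  -> kept
out = block_smooth(attention_mask=mask, key_len=mask.shape[-1], block_size=4)
# out: [[True, True, True, True, False, False, False, False, True, True]]

Note that an exact half-True block is dropped, since the keep rule requires count * 2 > block_size.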
@@ -42,10 +71,7 @@ def topk_mask(
         attention_mask (Tensor): The attention mask tensor of shape
             ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
     """
-    if block_size is not None:
-        if int(block_size) != block_size or block_size <= 0:
-            raise ValueError(f"block_size must be a positive integer, got {block_size}.")
-        block_size = int(block_size)
+
     attention_bias = attention_bias.detach()
     attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
     topk_values, topk_indices = torch.topk(
@@ -58,22 +84,11 @@ def topk_mask(
 
     if block_size is not None and block_size > 1:
         key_len = attention_mask.shape[-1]
-        full_len = (key_len // block_size) * block_size
-
-        if full_len:
-            block_view = attention_mask[..., :full_len]
-            block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
-            blocks = block_view.view(*block_shape)
-            block_counts = blocks.sum(dim=-1).to(torch.int32)
-            block_keep = (block_counts * 2) > block_size
-            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
-
-        if key_len > full_len:
-            tail_slice = attention_mask[..., full_len:]
-            tail_len = tail_slice.shape[-1]
-            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
-            tail_keep = (tail_counts * 2) > tail_len
-            tail_slice.copy_(tail_keep.expand_as(tail_slice))
+        attention_mask = block_smooth(
+            attention_mask=attention_mask,
+            key_len=key_len,
+            block_size=block_size,
+        )
 
     return attention_mask
 
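A minimal sketch of how the selection step composes with block_smooth; the exact torch.topk arguments of topk_mask are cut off in this diff, so the top-k count and the scatter-based mask construction below are assumptions for illustration:

bias = torch.randn(1, 1, 1, 10)                    # (batch, heads, query_len, key_len)
_, idx = torch.topk(bias, k=3, dim=-1)             # assumed top-k count
mask = torch.zeros_like(bias, dtype=torch.bool).scatter_(-1, idx, True)
mask = block_smooth(attention_mask=mask, key_len=mask.shape[-1], block_size=4)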
@@ -101,33 +116,18 @@ def relu_mask(
         attention_mask (Tensor): The attention mask tensor of shape
             ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
     """
-    if block_size is not None:
-        if int(block_size) != block_size or block_size <= 0:
-            raise ValueError(f"block_size must be a positive integer, got {block_size}.")
-        block_size = int(block_size)
-
+
     attention_bias = attention_bias.detach()
     attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
     attention_mask = attention_bias > 0
 
     if block_size is not None and block_size > 1:
         key_len = attention_mask.shape[-1]
-        full_len = (key_len // block_size) * block_size
-
-        if full_len:
-            block_view = attention_mask[..., :full_len]
-            block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
-            blocks = block_view.view(*block_shape)
-            block_counts = blocks.sum(dim=-1).to(torch.int32)
-            block_keep = (block_counts * 2) > block_size
-            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
-
-        if key_len > full_len:
-            tail_slice = attention_mask[..., full_len:]
-            tail_len = tail_slice.shape[-1]
-            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
-            tail_keep = (tail_counts * 2) > tail_len
-            tail_slice.copy_(tail_keep.expand_as(tail_slice))
+        attention_mask = block_smooth(
+            attention_mask=attention_mask,
+            key_len=key_len,
+            block_size=block_size,
+        )
 
     return attention_mask
 
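Similarly, a small sketch of relu_mask's thresholding rule with toy bias values (assumed for illustration): keys with a positive post-masking bias are kept, then smoothed block-wise.

bias = torch.tensor([[0.3, -0.1, 0.2, -0.5, 0.4, 0.1]])
mask = bias > 0                                    # [[T, F, T, F, T, T]]
mask = block_smooth(attention_mask=mask, key_len=mask.shape[-1], block_size=3)
# blocks of 3: [T, F, T] and [F, T, T] each have a 2/3 majority -> all kept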