@@ -22,6 +22,7 @@ def dynamic_mask(
     attention_mask: Optional[torch.Tensor],
     window_size: int,
     min_dtype: float,
+    block_size: Optional[int] = None,
 ):
     r"""
     This function generates a dynamic mask based on the top-k attention bias.
@@ -33,11 +34,18 @@ def dynamic_mask(
             ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
         window_size (int): The number of top elements to consider for the mask.
         min_dtype (float): The minimum value to use for masking.
+        block_size (Optional[int]): Optional size of aggregation blocks to smooth the
+            resulting mask along the key dimension.
 
     Returns:
         attention_mask (Tensor): The attention mask tensor of shape
             ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
     """
+    if block_size is not None:
+        if int(block_size) != block_size or block_size <= 0:
+            raise ValueError(f"block_size must be a positive integer, got {block_size}.")
+        block_size = int(block_size)
+
     attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
     topk_values, topk_indices = torch.topk(
         attention_bias.detach(),
@@ -46,6 +54,26 @@ def dynamic_mask(
     attention_mask = torch.zeros_like(
         attention_bias, dtype=torch.bool, device=attention_bias.device
     ).scatter_(-1, topk_indices, topk_values != min_dtype)
+
+    if block_size is not None and block_size > 1:
+        key_len = attention_mask.shape[-1]
+        full_len = (key_len // block_size) * block_size
+
+        if full_len:
+            block_view = attention_mask[..., :full_len]
+            block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
+            blocks = block_view.view(*block_shape)
+            block_counts = blocks.sum(dim=-1).to(torch.int32)
+            block_keep = (block_counts * 2) > block_size
+            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
+
+        if key_len > full_len:
+            tail_slice = attention_mask[..., full_len:]
+            tail_len = tail_slice.shape[-1]
+            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
+            tail_keep = (tail_counts * 2) > tail_len
+            tail_slice.copy_(tail_keep.expand_as(tail_slice))
+
     return attention_mask


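For reference, a minimal, self-contained sketch of the block-smoothing step added above. The helper name `smooth_mask_blocks` is hypothetical and not part of this patch; it only mirrors the majority-vote logic: each full block of `block_size` keys is kept as a whole when more than half of its entries survived the top-k selection, and the partial tail block is judged against its own length.

```python
import torch

def smooth_mask_blocks(mask: torch.Tensor, block_size: int) -> torch.Tensor:
    # Majority vote per block: a block of `block_size` keys stays active only if
    # more than half of its entries are True; the shorter tail block (if any)
    # uses the same rule over its own length.
    mask = mask.clone()
    key_len = mask.shape[-1]
    full_len = (key_len // block_size) * block_size

    if full_len:
        blocks = mask[..., :full_len].view(*mask.shape[:-1], full_len // block_size, block_size)
        keep = blocks.sum(dim=-1).to(torch.int32) * 2 > block_size
        blocks.copy_(keep.unsqueeze(-1).expand_as(blocks))  # writes back into `mask`

    if key_len > full_len:
        tail = mask[..., full_len:]
        keep = tail.sum(dim=-1, keepdim=True).to(torch.int32) * 2 > tail.shape[-1]
        tail.copy_(keep.expand_as(tail))

    return mask

mask = torch.tensor([[True, False, False, False, True, True, False]])
print(smooth_mask_blocks(mask, block_size=4))
# -> [[False, False, False, False, True, True, True]]
# The first block (1 of 4 True) is dropped; the 3-key tail (2 of 3 True) is kept.
```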
@@ -57,6 +85,7 @@ def create_mask(
     key_len: int,
     window_size: int,
     min_dtype: float,
+    block_size: Optional[int] = None,
 ) -> torch.Tensor:
     r"""
     This function creates a mask tensor for Flash Dynamic Mask Attention.
@@ -73,6 +102,7 @@ def create_mask(
         key_len (int): The sequence length of the key.
         window_size (int): The number of top elements to consider for the attention mask.
         min_dtype (float): The minimum value to use for masking.
+        block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.
 
     Returns:
         attention (Tensor): The attention mask tensor of shape
@@ -103,6 +133,12 @@ def create_mask(
     )
 
     # Generate dynamic mask based on attention_bias and attention_mask
-    attention_mask = dynamic_mask(attention_bias, attention_mask, window_size, min_dtype)
+    attention_mask = dynamic_mask(
+        attention_bias,
+        attention_mask,
+        window_size,
+        min_dtype,
+        block_size=block_size,
+    )
 
     return attention_mask
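A call-site sketch of the new parameter, assuming `dynamic_mask` is importable from this module (the import path below is hypothetical): leaving `block_size` at `None` keeps the previous behaviour, while passing e.g. `block_size=8` snaps the top-k mask to 8-key blocks.

```python
import torch
# Hypothetical import path; point this at wherever dynamic_mask is defined in the repo.
from flash_dmattn.utils.mask import dynamic_mask

batch, heads, query_len, key_len = 1, 2, 4, 32
attention_bias = torch.randn(batch, heads, query_len, key_len)
min_dtype = torch.finfo(attention_bias.dtype).min

# Plain top-k mask (unchanged behaviour) vs. a mask smoothed over 8-key blocks.
topk_mask = dynamic_mask(attention_bias, None, window_size=8, min_dtype=min_dtype)
block_mask = dynamic_mask(attention_bias, None, window_size=8, min_dtype=min_dtype, block_size=8)
```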