diff --git a/flash_dmattn/utils/mask.py b/flash_dmattn/utils/mask.py
index a5b33155..dd730190 100644
--- a/flash_dmattn/utils/mask.py
+++ b/flash_dmattn/utils/mask.py
@@ -17,12 +17,13 @@
 import torch
 
 
-def dynamic_mask(
+def topk_mask(
     attention_bias: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
     window_size: int,
     min_dtype: float,
     block_size: Optional[int] = None,
+    **kwargs,
 ):
     r"""
     This function generates a dynamic mask based on the top-k attention bias.
@@ -45,10 +46,10 @@ def dynamic_mask(
         if int(block_size) != block_size or block_size <= 0:
             raise ValueError(f"block_size must be a positive integer, got {block_size}.")
         block_size = int(block_size)
-
+    attention_bias = attention_bias.detach()
     attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
     topk_values, topk_indices = torch.topk(
-        attention_bias.detach(),
+        attention_bias.to(torch.float),
         window_size, dim=-1, largest=True, sorted=False
     )
     attention_mask = torch.zeros_like(
@@ -77,6 +78,61 @@ def dynamic_mask(
     return attention_mask
 
 
+def relu_mask(
+    attention_bias: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    min_dtype: float,
+    block_size: Optional[int] = None,
+    **kwargs,
+):
+    r"""
+    This function generates a dynamic mask based on the ReLU of the attention bias.
+
+    Args:
+        attention_bias (torch.Tensor): The attention bias tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+        min_dtype (float): The minimum value to use for masking.
+        block_size (Optional[int]): Optional size of aggregation blocks to smooth the
+            resulting mask along the key dimension.
+
+    Returns:
+        attention_mask (Tensor): The attention mask tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+    """
+    if block_size is not None:
+        if int(block_size) != block_size or block_size <= 0:
+            raise ValueError(f"block_size must be a positive integer, got {block_size}.")
+        block_size = int(block_size)
+
+    attention_bias = attention_bias.detach()
+    attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
+    attention_mask = attention_bias > 0
+
+    if block_size is not None and block_size > 1:
+        key_len = attention_mask.shape[-1]
+        full_len = (key_len // block_size) * block_size
+
+        if full_len:
+            block_view = attention_mask[..., :full_len]
+            block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
+            blocks = block_view.view(*block_shape)
+            block_counts = blocks.sum(dim=-1).to(torch.int32)
+            block_keep = (block_counts * 2) > block_size
+            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
+
+        if key_len > full_len:
+            tail_slice = attention_mask[..., full_len:]
+            tail_len = tail_slice.shape[-1]
+            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
+            tail_keep = (tail_counts * 2) > tail_len
+            tail_slice.copy_(tail_keep.expand_as(tail_slice))
+
+    return attention_mask
+
+
 def create_mask(
     attention_bias: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
@@ -86,6 +142,7 @@ def create_mask(
     window_size: int,
     min_dtype: float,
     block_size: Optional[int] = None,
+    type: str = "topk",
 ) -> torch.Tensor:
     r"""
     This function creates a mask tensor for Flash Dynamic Mask Attention.
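The block smoothing that `relu_mask` adds above (mirroring the existing top-k path) keeps or drops whole groups of `block_size` keys by majority vote: a block stays active only if more than half of its positions are active, and any ragged tail shorter than `block_size` is decided the same way. The sketch below replays that aggregation on a toy boolean mask; the values and `block_size` are hypothetical and only the logic follows the patch.

```python
import torch

# Hypothetical toy mask over key_len = 7 keys; block_size chosen for illustration only.
mask = torch.tensor([[True, True, False, False, True, False, True]])
block_size = 3

key_len = mask.shape[-1]
full_len = (key_len // block_size) * block_size  # 6 keys tile evenly into blocks of 3

smoothed = mask.clone()
if full_len:
    # View the evenly tiled prefix as (..., num_blocks, block_size) and take a strict majority vote per block.
    blocks = smoothed[..., :full_len].view(*smoothed.shape[:-1], full_len // block_size, block_size)
    block_keep = (blocks.sum(dim=-1) * 2) > block_size
    blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))  # write the per-block decision back in place

if key_len > full_len:
    # The ragged tail gets its own majority vote over however many keys remain.
    tail = smoothed[..., full_len:]
    tail_keep = (tail.sum(dim=-1, keepdim=True) * 2) > tail.shape[-1]
    tail.copy_(tail_keep.expand_as(tail))

print(smoothed)  # tensor([[ True,  True,  True, False, False, False,  True]])
```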
@@ -103,6 +160,7 @@ def create_mask(
         window_size (int): The number of top elements to consider for the attention mask.
         min_dtype (float): The minimum value to use for masking.
         block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.
+        type (str): The type of mask to create. Options are "topk" and "relu".
 
     Returns:
         attention (Tensor): The attention mask tensor of shape
@@ -133,12 +191,23 @@ def create_mask(
         )
 
     # Generate dynamic mask based on attention_bias and attention_mask
-    attention_mask = dynamic_mask(
-        attention_bias,
-        attention_mask,
-        window_size,
-        min_dtype,
-        block_size=block_size,
-    )
+    if type == "topk":
+        attention_mask = topk_mask(
+            attention_bias=attention_bias,
+            attention_mask=attention_mask,
+            window_size=window_size,
+            min_dtype=min_dtype,
+            block_size=block_size,
+        )
+    elif type == "relu":
+        attention_mask = relu_mask(
+            attention_bias=attention_bias,
+            attention_mask=attention_mask,
+            window_size=window_size,
+            min_dtype=min_dtype,
+            block_size=block_size,
+        )
+    else:
+        raise ValueError(f"Unsupported mask type: {type}. Supported types are 'topk' and 'relu'.")
 
     return attention_mask
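For context, a hedged usage sketch that calls the two helpers directly with the signatures visible in this diff. The shapes, the broadcastable padding mask, and the `window_size`/`block_size` values are illustrative assumptions rather than recommended settings.

```python
import torch

from flash_dmattn.utils.mask import topk_mask, relu_mask  # helpers introduced/renamed by this patch

batch, heads, q_len, k_len = 1, 2, 4, 16                     # illustrative shapes
bias = torch.randn(batch, heads, q_len, k_len)
padding = torch.ones(batch, 1, 1, k_len, dtype=torch.bool)   # broadcastable boolean key-padding mask
min_dtype = torch.finfo(bias.dtype).min

# Top-k variant: keep the window_size highest-bias keys per query, then smooth over blocks of 4.
keep_topk = topk_mask(bias, padding, window_size=8, min_dtype=min_dtype, block_size=4)

# ReLU variant: keep every key whose (padding-masked) bias is positive, smoothed the same way.
keep_relu = relu_mask(bias, padding, min_dtype=min_dtype, block_size=4)

print(keep_topk.shape, keep_relu.shape)  # both follow attention_bias: (1, 2, 4, 16)
```

Through `create_mask`, the same choice is made with `type="topk"` (the default) or `type="relu"`; any other value raises the `ValueError` added in the dispatch, and `relu_mask` absorbs the unused `window_size` keyword through `**kwargs`.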