diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py
index a10da1e..e50a35a 100644
--- a/benchmarks/backward_equivalence.py
+++ b/benchmarks/backward_equivalence.py
@@ -19,6 +19,8 @@
 import gc
 import sys
 
+from flash_sparse_attn.utils.mask import create_mask
+
 # Import the compiled CUDA extension
 try:
     from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func
@@ -65,42 +67,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-def prepare_mask(
-    hidden_states: torch.Tensor,
-    attn_bias: torch.Tensor,
-    causal_mask: torch.Tensor = None,
-    window_size: int = None,
-):
-    """
-    Args:
-        hidden_states: Input hidden states to determine dtype minimum value
-        attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length)
-        causal_mask: Optional causal mask to apply
-        window_size: Window size of tokens not masked
-
-    Returns:
-        tuple: (attn_bias, attn_mask)
-    """
-    dtype = hidden_states.dtype
-    min_dtype = torch.finfo(dtype).min
-
-    if attn_bias.shape[-1] > window_size:
-        if causal_mask is not None:
-            topk_values, topk_indices = torch.topk(
-                attn_bias.masked_fill(~causal_mask, min_dtype).detach(),
-                window_size, dim=-1, largest=True, sorted=False
-            )
-        else:
-            topk_values, topk_indices = torch.topk(
-                attn_bias,
-                window_size, dim=-1, largest=True, sorted=False
-            )
-        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype)
-    else:
-        attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
-    return attn_bias, attn_mask
-
-
 def dynamic_mask_attention_python(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
@@ -127,32 +93,38 @@
     Returns:
         tuple: (attn_outputs, dq, dk, dv, dbias)
     """
-    _, num_heads, _, _ = query_states.shape
-    _, num_kv_heads, _, _ = key_states.shape
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
     num_queries_per_kv = num_heads // num_kv_heads
 
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
+    )
+
     query_states_leaf = query_states
     key_states_leaf = key_states
     value_states_leaf = value_states
-
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
-    )
     attn_bias_leaf = attn_bias
     attn_bias_leaf.retain_grad()
 
     key_states = repeat_kv(key_states, num_queries_per_kv)
     value_states = repeat_kv(value_states, num_queries_per_kv)
-    attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
+    attn_mask = repeat_kv(attn_mask, num_queries_per_kv) if attn_mask is not None else None
    attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv)
 
     # Sparse attention weight calculation
     attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1))  # Dot product weights
     attn_weights = attn_weights * scaling + attn_bias  # Apply scaling and bias
-    attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf'))  # Apply mask
+    if attn_mask is not None:
+        attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf'))  # Apply mask
     attn_weights = F.softmax(attn_weights, dim=-1)  # Softmax normalization
     attn_outputs = torch.matmul(attn_weights, value_states)  # Weighted sum of values
     attn_outputs = attn_outputs.transpose(1, 2).contiguous()  # Transpose to [batch, query_len, num_heads, head_dim]
@@ -192,16 +164,25 @@
     if flash_sparse_attn_func is None:
         raise ImportError("CUDA implementation not available")
 
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
+    num_queries_per_kv = num_heads // num_kv_heads
+
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
+    )
+
     query_states_leaf = query_states
     key_states_leaf = key_states
     value_states_leaf = value_states
-
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
-    )
     attn_bias_leaf = attn_bias
     attn_bias_leaf.retain_grad()
 
@@ -259,29 +240,28 @@
     if triton_sparse_attn_func is None:
         raise RuntimeError("Triton implementation not available")
 
-    _, num_heads, _, _ = query_states.shape
-    _, num_kv_heads, _, _ = key_states.shape
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
     num_queries_per_kv = num_heads // num_kv_heads
 
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
+    )
+
     query_states_leaf = query_states
     key_states_leaf = key_states
     value_states_leaf = value_states
-
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
-    )
     attn_bias_leaf = attn_bias
     attn_bias_leaf.retain_grad()
 
-    # Repeat KV for multi-head attention (GQA support)
-    key_states = repeat_kv(key_states, num_queries_per_kv)
-    value_states = repeat_kv(value_states, num_queries_per_kv)
-    attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
-    attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv)
-
     # Ensure correct data types and memory layout for Triton function
     query_states = query_states.transpose(1, 2)  # [batch, query_len, num_heads, head_dim]
     key_states = key_states.transpose(1, 2)  # [batch, key_len, num_heads, head_dim]
@@ -333,30 +313,38 @@
     if flex_sparse_attn_func is None:
         raise RuntimeError("Flex Attention implementation not available")
 
-    _, num_heads, _, _ = query_states.shape
-    _, num_kv_heads, _, _ = key_states.shape
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
     num_queries_per_kv = num_heads // num_kv_heads
 
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
     )
-    attn_bias.retain_grad()
+
+    query_states_leaf = query_states
+    key_states_leaf = key_states
+    value_states_leaf = value_states
+    attn_bias_leaf = attn_bias
+    attn_bias_leaf.retain_grad()
 
     # Repeat KV for multi-head attention (GQA support)
     key_states = repeat_kv(key_states, num_queries_per_kv)
     value_states = repeat_kv(value_states, num_queries_per_kv)
-    attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
+    attn_mask = repeat_kv(attn_mask, num_queries_per_kv) if attn_mask is not None else None
     attn_bias = repeat_kv(attn_bias, num_queries_per_kv)
 
     # Ensure correct data types and memory layout for Flex function
     query_states = query_states.transpose(1, 2).contiguous()  # [batch, query_len, num_heads, head_dim]
     key_states = key_states.transpose(1, 2).contiguous()  # [batch, key_len, num_heads, head_dim]
     value_states = value_states.transpose(1, 2).contiguous()  # [batch, key_len, num_heads, head_dim]
-    attn_mask = attn_mask.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
-    attn_bias = attn_bias.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
 
     # Call the Flex Attention implementation
     attn_outputs = flex_sparse_attn_func(
@@ -372,7 +360,7 @@
     # Backward pass
     attn_outputs.sum().backward()
 
-    return attn_outputs, query_states.grad, key_states.grad, value_states.grad, attn_bias.grad
+    return attn_outputs, query_states_leaf.grad, key_states_leaf.grad, value_states_leaf.grad, attn_bias_leaf.grad
 
 
 def analyze_differences(original_result, cuda_result, accuracy_threshold=0.95):
@@ -609,7 +597,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
             device=device, dtype=dtype, requires_grad=True
         )
         attn_bias = torch.randn(
-            batch_size, num_kv_heads, query_len, key_len,
+            batch_size, num_kv_heads, 1, key_len,
             device=device, dtype=torch.bfloat16
         )
         cache_position = torch.arange(key_len - query_len, key_len, device=device)
@@ -843,7 +831,7 @@ def test_triton_backward_equivalence(accuracy_threshold=0.95):
             device=device, dtype=dtype, requires_grad=True
         )
         attn_bias = torch.randn(
-            batch_size, num_kv_heads, query_len, key_len,
+            batch_size, num_kv_heads, 1, key_len,
             device=device, dtype=torch.bfloat16
         )
         cache_position = torch.arange(key_len - query_len, key_len, device=device)
diff --git a/benchmarks/backward_performance.py b/benchmarks/backward_performance.py
index 59daf16..68aebbc 100644
--- a/benchmarks/backward_performance.py
+++ b/benchmarks/backward_performance.py
@@ -26,6 +26,8 @@
 import gc
 import sys
 
+from flash_sparse_attn.utils.mask import create_mask
+
 # Import the compiled CUDA extension
 try:
     from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func
@@ -72,42 +74,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-def prepare_mask(
-    hidden_states: torch.Tensor,
-    attn_bias: torch.Tensor,
-    causal_mask: torch.Tensor = None,
-    window_size: int = None,
-):
-    """
-    Args:
-        hidden_states: Input hidden states to determine dtype minimum value
-        attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length)
-        causal_mask: Optional causal mask to apply
-        window_size: Window size of tokens not masked
-
-    Returns:
-        tuple: (attn_bias, attn_mask)
-    """
-    dtype = hidden_states.dtype
-    min_dtype = torch.finfo(dtype).min
-
-    if attn_bias.shape[-1] > window_size:
-        if causal_mask is not None:
-            topk_values, topk_indices = torch.topk(
-                attn_bias.masked_fill(~causal_mask, min_dtype).detach(),
-                window_size, dim=-1, largest=True, sorted=False
-            )
-        else:
-            topk_values, topk_indices = torch.topk(
-                attn_bias,
-                window_size, dim=-1, largest=True, sorted=False
-            )
-        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype)
-    else:
-        attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
-    return attn_bias, attn_mask
-
-
 def scaled_dot_product_attention_backward_cuda(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
@@ -133,15 +99,20 @@
     Returns:
         tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0)
     """
-    _, num_heads, _, _ = query_states.shape
-    _, num_kv_heads, _, _ = key_states.shape
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
     num_queries_per_kv = num_heads // num_kv_heads
 
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
     )
 
     # Repeat KV for multi-head attention (GQA support)
@@ -210,11 +181,20 @@
     if flash_sparse_attn_func is None:
         return "Not Available", 0
 
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
+    num_queries_per_kv = num_heads // num_kv_heads
+
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
     )
 
     # Ensure correct data types and memory layout for CUDA function
@@ -280,29 +260,27 @@
     if triton_sparse_attn_func is None:
         return "Not Available", 0
 
-    _, num_heads, _, _ = query_states.shape
-    _, num_kv_heads, _, _ = key_states.shape
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
     num_queries_per_kv = num_heads // num_kv_heads
 
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        block_size=64,
+        type="topk"
     )
 
-    # Repeat KV for multi-head attention (GQA support)
-    key_states = repeat_kv(key_states, num_queries_per_kv)
-    value_states = repeat_kv(value_states, num_queries_per_kv)
-    attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
-    attn_bias = repeat_kv(attn_bias, num_queries_per_kv)
-
     # Ensure correct data types and memory layout for Triton function
     query_states = query_states.transpose(1, 2).contiguous()  # [batch, query_len, num_heads, head_dim]
     key_states = key_states.transpose(1, 2).contiguous()  # [batch, key_len, num_heads, head_dim]
     value_states = value_states.transpose(1, 2).contiguous()  # [batch, key_len, num_heads, head_dim]
-    attn_mask = attn_mask.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
-    attn_bias = attn_bias.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
 
     try:
         attn_outputs = triton_sparse_attn_func(
@@ -359,29 +337,32 @@ def dynamic_mask_attention_backward_flex(
     if flex_sparse_attn_func is None:
         return "Not Available", 0
 
-    _, num_heads, _, _ = query_states.shape
-    _, num_kv_heads, _, _ = key_states.shape
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
     num_queries_per_kv = num_heads // num_kv_heads
 
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
     )
 
     # Repeat KV for multi-head attention (GQA support)
     key_states = repeat_kv(key_states, num_queries_per_kv)
     value_states = repeat_kv(value_states, num_queries_per_kv)
-    attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
+    attn_mask = repeat_kv(attn_mask, num_queries_per_kv) if attn_mask is not None else None
     attn_bias = repeat_kv(attn_bias, num_queries_per_kv)
 
     # Ensure correct data types and memory layout for Flex function
     query_states = query_states.transpose(1, 2).contiguous()  # [batch, query_len, num_heads, head_dim]
     key_states = key_states.transpose(1, 2).contiguous()  # [batch, key_len, num_heads, head_dim]
     value_states = value_states.transpose(1, 2).contiguous()  # [batch, key_len, num_heads, head_dim]
-    attn_mask = attn_mask.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
-    attn_bias = attn_bias.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
 
     try:
         attn_outputs = flex_sparse_attn_func(
@@ -453,7 +434,7 @@ def benchmark_backward_attention_performance(config, test_type='all', num_runs=5
             device=device, dtype=torch.bfloat16, requires_grad=True
         )
         attn_bias = torch.randn(
-            batch_size, num_kv_heads, query_len, key_len,
+            batch_size, num_kv_heads, 1, key_len,
             device=device, dtype=torch.bfloat16
        )
         cache_position = torch.arange(key_len - query_len, key_len, device=device)
diff --git a/benchmarks/forward_equivalence.py b/benchmarks/forward_equivalence.py
index 9b05ba3..aebb2de 100644
--- a/benchmarks/forward_equivalence.py
+++ b/benchmarks/forward_equivalence.py
@@ -19,6 +19,8 @@
 import gc
 import sys
 
+from flash_sparse_attn.utils.mask import create_mask
+
 # Import the compiled CUDA extension
 try:
     from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func
@@ -65,42 +67,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-def prepare_mask(
-    hidden_states: torch.Tensor,
-    attn_bias: torch.Tensor,
-    causal_mask: torch.Tensor = None,
-    window_size: int = None,
-):
-    """
-    Args:
-        hidden_states: Input hidden states to determine dtype minimum value
-        attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length)
-        causal_mask: Optional causal mask to apply
-        window_size: Window size of tokens not masked
-
-    Returns:
-        tuple: (attn_bias, attn_mask)
-    """
-    dtype = hidden_states.dtype
-    min_dtype = torch.finfo(dtype).min
-
-    if attn_bias.shape[-1] > window_size:
-        if causal_mask is not None:
-            topk_values, topk_indices = torch.topk(
-                attn_bias.masked_fill(~causal_mask, min_dtype).detach(),
-                window_size, dim=-1, largest=True, sorted=False
-            )
-        else:
-            topk_values, topk_indices = torch.topk(
-                attn_bias,
-                window_size, dim=-1, largest=True, sorted=False
-            )
-        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype)
-    else:
-        attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
-    return attn_bias, attn_mask
-
-
 def dynamic_mask_attention_python(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
@@ -127,27 +93,32 @@
     Returns:
         attn_outputs: [batch_size, query_len, num_heads, head_dim]
     """
-    _, num_heads, _, _ = query_states.shape
-    _, num_kv_heads, _, _ = key_states.shape
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
     num_queries_per_kv = num_heads // num_kv_heads
 
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
     )
 
     key_states = repeat_kv(key_states, num_queries_per_kv)
     value_states = repeat_kv(value_states, num_queries_per_kv)
     attn_bias = repeat_kv(attn_bias, num_queries_per_kv)
-    attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
+    attn_mask = repeat_kv(attn_mask, num_queries_per_kv) if attn_mask is not None else None
 
     # Sparse attention weight calculation
     attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1))  # Dot product weights
     attn_weights = attn_weights * scaling + attn_bias  # Apply scaling and bias
-    attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf'))  # Apply mask
+    if attn_mask is not None:
+        attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf'))  # Apply mask
     attn_weights = F.softmax(attn_weights, dim=-1)  # Softmax normalization
     attn_outputs = torch.matmul(attn_weights, value_states)  # Weighted sum of values
     attn_outputs = attn_outputs.transpose(1, 2).contiguous()  # Transpose to [batch, query_len, num_heads, head_dim]
@@ -184,11 +155,20 @@
     if flash_sparse_attn_func is None:
         raise RuntimeError("flash_sparse_attn_func not available")
 
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
+    num_queries_per_kv = num_heads // num_kv_heads
+
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
     )
 
     # Ensure correct data types and memory layout for CUDA function
@@ -242,15 +222,20 @@
     if triton_sparse_attn_func is None:
         raise RuntimeError("Triton implementation not available")
 
-    _, num_heads, _, _ = query_states.shape
-    _, num_kv_heads, _, _ = key_states.shape
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
     num_queries_per_kv = num_heads // num_kv_heads
 
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
     )
 
     # Repeat KV for multi-head attention (GQA support)
@@ -309,15 +294,20 @@
     if flex_sparse_attn_func is None:
         raise RuntimeError("Flex Attention implementation not available")
 
-    _, num_heads, _, _ = query_states.shape
-    _, num_kv_heads, _, _ = key_states.shape
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
     num_queries_per_kv = num_heads // num_kv_heads
 
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
     )
 
     # Repeat KV for multi-head attention (GQA support)
@@ -580,7 +570,7 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
             device=device, dtype=torch.bfloat16
         )
         attn_bias = torch.randn(
-            batch_size, num_kv_heads, query_len, key_len,
+            batch_size, num_kv_heads, 1, key_len,
             device=device, dtype=torch.bfloat16
         )
         cache_position = torch.arange(key_len - query_len, key_len, device=device)
@@ -768,7 +758,7 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95):
             device=device, dtype=torch.bfloat16
        )
         attn_bias = torch.randn(
-            batch_size, num_kv_heads, query_len, key_len,
+            batch_size, num_kv_heads, 1, key_len,
             device=device, dtype=torch.bfloat16
         )
         cache_position = torch.arange(key_len - query_len, key_len, device=device)
@@ -973,7 +963,7 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95):
             device=device, dtype=torch.bfloat16
         )
         attn_bias = torch.randn(
-            batch_size, num_kv_heads, query_len, key_len,
+            batch_size, num_kv_heads, 1, key_len,
             device=device, dtype=torch.bfloat16
         )
         cache_position = torch.arange(key_len - query_len, key_len, device=device)
diff --git a/benchmarks/forward_performance.py b/benchmarks/forward_performance.py
index 05e75c4..f93537b 100644
--- a/benchmarks/forward_performance.py
+++ b/benchmarks/forward_performance.py
@@ -26,6 +26,8 @@
 import time
 import gc
 
+from flash_sparse_attn.utils.mask import create_mask
+
 # Import the compiled CUDA extension
 try:
     from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func
@@ -72,42 +74,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-def prepare_mask(
-    hidden_states: torch.Tensor,
-    attn_bias: torch.Tensor,
-    causal_mask: torch.Tensor = None,
-    window_size: int = None,
-):
-    """
-    Args:
-        hidden_states: Input hidden states to determine dtype minimum value
-        attn_bias: Attention bias of shape (batch_size, num_heads, query_length, key_length)
-        causal_mask: Optional causal mask to apply
-        window_size: Window size of tokens not masked
-
-    Returns:
-        tuple: (attn_bias, attn_mask)
-    """
-    dtype = hidden_states.dtype
-    min_dtype = torch.finfo(dtype).min
-
-    if attn_bias.shape[-1] > window_size:
-        if causal_mask is not None:
-            topk_values, topk_indices = torch.topk(
-                attn_bias.masked_fill(~causal_mask, min_dtype).detach(),
-                window_size, dim=-1, largest=True, sorted=False
-            )
-        else:
-            topk_values, topk_indices = torch.topk(
-                attn_bias,
-                window_size, dim=-1, largest=True, sorted=False
-            )
-        attn_mask = torch.zeros_like(attn_bias, dtype=torch.bool, device=attn_bias.device).scatter_(-1, topk_indices, topk_values != min_dtype)
-    else:
-        attn_mask = causal_mask.expand_as(attn_bias) if causal_mask is not None else torch.ones_like(attn_bias, dtype=torch.bool, device=attn_bias.device)
-    return attn_bias, attn_mask
-
-
 def scaled_dot_product_attention_cuda(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
@@ -134,19 +100,24 @@
     Returns:
         tuple: (output_tensor, timing_ms) or ("OOM", 0) or ("Not Available", 0)
     """
-    _, num_heads, _, _ = query_states.shape
-    _, num_kv_heads, _, _ = key_states.shape
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
     num_queries_per_kv = num_heads // num_kv_heads
-
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
+
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
     )
 
     # Repeat KV for multi-head attention (GQA support)
-    attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
+    attn_mask = repeat_kv(attn_mask, num_queries_per_kv) if attn_mask is not None else None
     attn_bias = repeat_kv(attn_bias, num_queries_per_kv)
 
     query_states = query_states.contiguous()
@@ -167,7 +138,7 @@
             # is_causal=is_causal,
             enable_gqa=True,
         )
-        
+
         torch.cuda.synchronize()
         end_time = time.time()
@@ -206,11 +177,21 @@
     if flash_sparse_attn_func is None:
         return "Not Available", 0
 
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
+    num_queries_per_kv = num_heads // num_kv_heads
+
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        block_size=64,
+        type="topk"
     )
 
     # Ensure correct data types and memory layout for CUDA function
@@ -272,15 +253,20 @@
     if triton_sparse_attn_func is None:
         return "Not Available", 0
 
-    _, num_heads, _, _ = query_states.shape
-    _, num_kv_heads, _, _ = key_states.shape
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
     num_queries_per_kv = num_heads // num_kv_heads
 
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
     )
 
     # Repeat KV for multi-head attention (GQA support)
@@ -347,15 +333,20 @@
     if flex_sparse_attn_func is None:
         return "Not Available", 0
 
-    _, num_heads, _, _ = query_states.shape
-    _, num_kv_heads, _, _ = key_states.shape
+    batch_size, num_heads, query_len, _ = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+
     num_queries_per_kv = num_heads // num_kv_heads
 
-    attn_bias, attn_mask = prepare_mask(
-        query_states,
-        attn_bias,
-        causal_mask if is_causal else None,
-        window_size,
+    attn_mask = create_mask(
+        attention_bias=attn_bias,
+        attention_mask=causal_mask if is_causal else None,
+        batch_size=batch_size,
+        query_len=query_len,
+        key_len=key_len,
+        window_size=window_size,
+        min_dtype=torch.finfo(query_states.dtype).min,
+        type="topk"
     )
 
     # Repeat KV for multi-head attention (GQA support)
@@ -437,7 +428,7 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_
             device=device, dtype=torch.bfloat16
         )
         attn_bias = torch.randn(
-            batch_size, num_kv_heads, query_len, key_len,
+            batch_size, num_kv_heads, 1, key_len,
             device=device, dtype=torch.bfloat16
         )
         cache_position = torch.arange(key_len - query_len, key_len, device=device)
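For reference, the snippet below condenses the call pattern that the four benchmark files share after this refactor. It is a sketch rather than part of the patch: the `create_mask` import path and keyword arguments are copied from the diff above, while the wrapper name `reference_sparse_attention`, the inlined `repeat_kv` helper, and the assumption that `create_mask` returns either a boolean mask of shape `[batch, num_kv_heads, query_len, key_len]` or `None` are illustrative only.

```python
import torch
import torch.nn.functional as F

from flash_sparse_attn.utils.mask import create_mask


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand [batch, num_kv_heads, seq_len, head_dim] to [batch, num_kv_heads * n_rep, seq_len, head_dim].
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)


def reference_sparse_attention(query, key, value, attn_bias, causal_mask, scaling, window_size, is_causal=True):
    # Hypothetical condensed version of the Python reference path used in the benchmarks.
    batch_size, num_heads, query_len, _ = query.shape
    _, num_kv_heads, key_len, _ = key.shape
    num_queries_per_kv = num_heads // num_kv_heads

    # Build the boolean top-k mask with the shared helper; assume it may return None
    # when nothing needs to be masked out.
    attn_mask = create_mask(
        attention_bias=attn_bias,
        attention_mask=causal_mask if is_causal else None,
        batch_size=batch_size,
        query_len=query_len,
        key_len=key_len,
        window_size=window_size,
        min_dtype=torch.finfo(query.dtype).min,
        type="topk",
    )

    # GQA: repeat K/V (and the per-KV-head bias/mask) across the query heads.
    key = repeat_kv(key, num_queries_per_kv)
    value = repeat_kv(value, num_queries_per_kv)
    attn_bias = repeat_kv(attn_bias, num_queries_per_kv)
    attn_mask = repeat_kv(attn_mask, num_queries_per_kv) if attn_mask is not None else None

    # Scaled dot-product attention with bias and optional boolean mask.
    attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scaling + attn_bias
    if attn_mask is not None:
        attn_weights = attn_weights.masked_fill(~attn_mask, float("-inf"))
    attn_weights = F.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
```

Because `attn_bias` is now generated as `[batch, num_kv_heads, 1, key_len]`, the same bias row broadcasts across every query position, which is why the test tensors above drop the `query_len` dimension.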