Simplify attention mechanisms #217
Conversation
Removes rotary embedding utilities and associated parameters so the module only consumes projections required by flash attention. Switches cache handling to simple tuple concatenation and aligns tensor reshapes with flash_attn expectations while dropping unused attention weights.
Removes RoPE-specific helpers and cache plumbing so the module works with plain tensors. Introduces new gate/delta projections and sigmoid-based bias to streamline mask creation. Drops the unused attention weight return to match flash sparse expectations.
Pull request overview
This PR aims to simplify the multi-head attention and dynamic mask attention mechanisms by removing rotary position embeddings and streamlining cache handling. However, the implementation contains several critical issues that need to be addressed.
Key Changes:
- Removed rotary embedding utility functions and position embedding parameters
- Simplified cache handling from Cache objects to simple tuple concatenation
- Renamed projection layers (a_proj → g_proj, dt_proj → d_proj) and changed attention bias calculation from mean + softplus to sigmoid
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| flash_sparse_attn/modules/multi_head_attention.py | Removed rotary embeddings, simplified cache handling to tuples, and removed attn_weights return value |
| flash_sparse_attn/modules/dynamic_mask_attention.py | Same changes as multi_head_attention.py plus renamed projection layers and simplified attention bias calculation |
Critical Issues Found:
- API Breaking Changes: Both attention modules now return only the attention output instead of the documented tuple of three values, contradicting the PR's backward compatibility claim
- Missing Parameters: The create_mask function call is missing required parameters (attention_mask, batch_size, key_len, window_size, min_dtype, block_size)
- Dimension Mismatch: The attention bias calculation attempts to multiply tensors with incompatible dimensions (seq_len vs key_len)
- Incorrect Indexing: Using query_states.shape[2] after reshaping returns the number of heads instead of query_len
- Broken Caching: The updated past_key_values tuple is not returned, breaking KV caching for generation scenarios
These issues will cause runtime errors and prevent the code from functioning correctly. The PR needs significant revisions before it can be merged.
    from typing import Optional, Tuple
    import torch
    import torch.nn as nn
Copilot AI · Dec 19, 2025
The Tuple type has been imported but the Cache import from transformers.cache_utils is still present. Since the Cache object is no longer used in the simplified implementation (past_key_values now uses a simple Tuple), the Cache import should be removed.
        attention_bias=attn_bias,
        attention_mask=attention_mask,
        batch_size=query_states.shape[0],
        query_len=query_states.shape[2],
Copilot AI · Dec 19, 2025
The indexing query_states.shape[2] is incorrect. At this point in the code, query_states has already been reshaped (at line 72) to (bsz, seq_len, num_heads, head_dim), so shape[2] refers to the number of heads, not query_len. The query length should be taken from shape[1], which is seq_len, or simply from the seq_len variable that is already defined.
Suggested change:
    - query_len=query_states.shape[2],
    + query_len=seq_len,
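To make the layout concrete, here is a small standalone shape check (hypothetical sizes, not the repository's code) illustrating why shape[2] picks up the head count after the reshape described above:

```python
import torch

# After reshaping to (bsz, seq_len, num_heads, head_dim), dim 1 is the query length
# and dim 2 is the number of heads, which is what the comment above points out.
bsz, seq_len, num_heads, head_dim = 2, 5, 8, 64
query_states = torch.randn(bsz, seq_len, num_heads, head_dim)

assert query_states.shape[1] == seq_len    # query_len lives at dim 1
assert query_states.shape[2] == num_heads  # shape[2] is the head count, not query_len
```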
    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    if past_key_values is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_values.update(
            key_states, value_states, self.layer_idx, cache_kwargs
        )
    # sampling a_states from query_states
    a_states = self.a_proj(
        query_states.transpose(1, 2).reshape(
            query_states.shape[0], query_states.shape[-2], -1
        )
    )  # [batch_size, query_len, num_key_value_heads]
    # sampling dt_states from value_states
    dt_states = self.dt_proj(
        value_states.transpose(1, 2).reshape(
            value_states.shape[0], value_states.shape[-2], -1
        )
    )  # [batch_size, key_len, num_key_value_heads]
    # original formula is exp(A * softplus(delta V)), but for numerical stability, it is changed to A * softplus(delta V)

    past_key, past_value = past_key_values
    key_states = torch.cat([past_key, key_states], dim=1)
    value_states = torch.cat([past_value, value_states], dim=1)
    key_len = key_states.size(1)
    gate_states = self.g_proj(query_states)
    delta_states = self.d_proj(value_states)
    attn_bias = (
        (torch.mean(a_states, dim=1, keepdim=True) * F.softplus(dt_states))
        .transpose(-1, -2)
        .unsqueeze(-2)
    )  # [batch_size, num_key_value_heads, 1, key_len]
        (torch.sigmoid(gate_states) * delta_states).transpose(-1, -2).unsqueeze(-2)
    )
    query_states = query_states.view(bsz, seq_len, -1, self.head_dim)
    key_states = key_states.view(bsz, key_len, -1, self.head_dim)
    value_states = value_states.view(bsz, key_len, -1, self.head_dim)

    attn_mask = create_mask(
        attention_bias=attn_bias,
        attention_mask=attention_mask,
        batch_size=query_states.shape[0],
        query_len=query_states.shape[2],
        key_len=key_states.shape[2],
        type="relu",
    )
    attn_output, attn_weights = flash_sparse_attn_func(
        self,
        query_states.transpose(1, 2).contiguous(),
        key_states.transpose(1, 2).contiguous(),
        value_states.transpose(1, 2).contiguous(),
    attn_output = flash_sparse_attn_func(
        query_states,
        key_states,
        value_states,
        attn_mask,
        attn_bias,
        softmax_scale=self.scaling,
        is_causal=self.is_causal,
    )
    attn_output = attn_output.reshape(*input_shape, -1).contiguous()

    attn_output = attn_output.reshape(bsz, seq_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights

    return attn_output
Copilot AI · Dec 19, 2025
The forward method signature and return statement are inconsistent. The return type annotation indicates the method returns a tuple of three elements: tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]], but the actual return statement at line 95 only returns a single tensor (attn_output). This is a breaking API change that contradicts the PR description's claim of maintaining backward compatibility. Either update the return type annotation to match the new single return value, or maintain the original tuple return format for compatibility.
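As a rough sketch (not the repository's actual module), one backward-compatible option is to keep the documented three-element tuple, returning None in the attention-weights slot since the flash path does not materialize weights, and the plain (key, value) cache in the last slot. The helper name and arguments below are illustrative only:

```python
from typing import Optional, Tuple
import torch

# Hedged sketch of a backward-compatible return contract.
def finish_forward(
    attn_output: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    # The flash kernel does not return attention weights, so that slot stays None.
    return attn_output, None, (key_states, value_states)
```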
    from typing import Optional, Tuple
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
Copilot AI · Dec 19, 2025
The Tuple type has been imported but the Cache import from transformers.cache_utils is still present. Since the Cache object is no longer used in the simplified implementation (past_key_values now uses a simple Tuple), the Cache import should be removed.
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights

    return attn_output
Copilot AI · Dec 19, 2025
The function does not return the updated past_key_values tuple. For proper KV caching in generation scenarios, the method should return the concatenated (key_states, value_states) tuple so that callers can pass it as past_key_values in subsequent forward passes. Without this, the caching mechanism will not work correctly.
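For context, here is a toy caller-side sketch of how the concatenated cache would be threaded through incremental decoding if the module returned it. toy_attention below is a hypothetical stand-in, not the repository's API:

```python
import torch

def toy_attention(hidden_states, past_key_values=None):
    # Stand-in for the module: keys/values are just the inputs here.
    key_states = value_states = hidden_states
    if past_key_values is not None:
        past_key, past_value = past_key_values
        key_states = torch.cat([past_key, key_states], dim=1)
        value_states = torch.cat([past_value, value_states], dim=1)
    attn_output = hidden_states  # placeholder for the real attention computation
    # Returning the concatenated pair lets the caller feed it back on the next step.
    return attn_output, (key_states, value_states)

past_key_values = None
for _ in range(3):
    step = torch.randn(1, 1, 64)  # one new token per decoding step
    out, past_key_values = toy_attention(step, past_key_values=past_key_values)

print(past_key_values[0].shape)  # torch.Size([1, 3, 64]): the cache grows each step
```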
        attention_bias=attn_bias,
        attention_mask=attention_mask,
        batch_size=query_states.shape[0],
        query_len=query_states.shape[2],
Copilot AI · Dec 19, 2025
The create_mask function call is missing several required parameters. According to the function signature, create_mask requires: attention_bias, attention_mask, batch_size, query_len, key_len, window_size, min_dtype, and block_size. The current call only provides attention_bias, query_len, and type. You need to add the missing parameters: attention_mask, batch_size (bsz), key_len, window_size (from config or None), min_dtype (from config or None), and block_size (from config or None).
Suggested change:
    - query_len=query_states.shape[2],
    + attention_mask=attention_mask,
    + batch_size=bsz,
    + query_len=seq_len,
    + key_len=key_len,
    + window_size=getattr(self.config, "window_size", None),
    + min_dtype=getattr(self.config, "min_dtype", None),
    + block_size=getattr(self.config, "block_size", None),
    attn_bias = (
        (torch.mean(a_states, dim=1, keepdim=True) * F.softplus(dt_states))
        .transpose(-1, -2)
        .unsqueeze(-2)
    )  # [batch_size, num_key_value_heads, 1, key_len]
        (torch.sigmoid(gate_states) * delta_states).transpose(-1, -2).unsqueeze(-2)
    )
Copilot AI · Dec 19, 2025
The calculation of attn_bias has a dimension mismatch. The g_proj takes query_states with shape (bsz, seq_len, num_attention_heads * head_dim) and outputs (bsz, seq_len, num_key_value_heads). The d_proj takes value_states which after cache concatenation has shape (bsz, key_len, num_key_value_heads * head_dim) and outputs (bsz, key_len, num_key_value_heads). When you multiply gate_states (bsz, seq_len, num_key_value_heads) with delta_states (bsz, key_len, num_key_value_heads), this will fail due to incompatible dimensions at dim=1 (seq_len vs key_len). This needs to be an outer product operation or the dimensions need to be broadcasted correctly to produce a (bsz, num_key_value_heads, seq_len, key_len) tensor.
Suggested change:
    - attn_bias = (
    -     (torch.mean(a_states, dim=1, keepdim=True) * F.softplus(dt_states))
    -     .transpose(-1, -2)
    -     .unsqueeze(-2)
    - )  # [batch_size, num_key_value_heads, 1, key_len]
    -     (torch.sigmoid(gate_states) * delta_states).transpose(-1, -2).unsqueeze(-2)
    - )
    + gate_states = self.g_proj(query_states)  # (bsz, seq_len, num_key_value_heads)
    + delta_states = self.d_proj(value_states)  # (bsz, key_len, num_key_value_heads)
    + gate = torch.sigmoid(gate_states).unsqueeze(2)  # (bsz, seq_len, 1, num_key_value_heads)
    + delta = delta_states.unsqueeze(1)  # (bsz, 1, key_len, num_key_value_heads)
    + attn_bias = (gate * delta).permute(0, 3, 1, 2)  # (bsz, num_key_value_heads, seq_len, key_len)
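A quick standalone shape check (hypothetical sizes) confirming that the broadcasted outer product in the suggestion yields the (bsz, num_key_value_heads, seq_len, key_len) bias the comment asks for:

```python
import torch

bsz, seq_len, key_len, num_kv_heads = 2, 4, 6, 3
gate_states = torch.randn(bsz, seq_len, num_kv_heads)   # stands in for g_proj output
delta_states = torch.randn(bsz, key_len, num_kv_heads)  # stands in for d_proj output

gate = torch.sigmoid(gate_states).unsqueeze(2)  # (bsz, seq_len, 1, num_kv_heads)
delta = delta_states.unsqueeze(1)               # (bsz, 1, key_len, num_kv_heads)
attn_bias = (gate * delta).permute(0, 3, 1, 2)  # (bsz, num_kv_heads, seq_len, key_len)

assert attn_bias.shape == (bsz, num_kv_heads, seq_len, key_len)
```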