Simplify attention mechanisms #217
Diff 1: DynamicMaskAttention example (flash_sparse_attn)

@@ -1,48 +1,12 @@
-from typing import Optional
+from typing import Optional, Tuple
 import torch
 import torch.nn as nn
-import torch.nn.functional as F

 from transformers.cache_utils import Cache

 from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func
 from flash_sparse_attn.utils.mask import create_mask


-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
 class DynamicMaskAttention(nn.Module):
     def __init__(self, config, layer_idx: Optional[int] = None):
         super().__init__()

@@ -66,12 +30,12 @@ def __init__(self, config, layer_idx: Optional[int] = None):
         self.v_proj = nn.Linear(
             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
         )
-        self.a_proj = nn.Linear(
+        self.g_proj = nn.Linear(
             config.num_attention_heads * self.head_dim,
             config.num_key_value_heads,
             bias=False,
         )
-        self.dt_proj = nn.Linear(
+        self.d_proj = nn.Linear(
             config.num_key_value_heads * self.head_dim,
             config.num_key_value_heads,
             bias=False,

@@ -83,63 +47,49 @@ def __init__(self, config, layer_idx: Optional[int] = None):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings: tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Cache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
-        input_shape = hidden_states.shape[:-1]
-        hidden_shape = (*input_shape, -1, self.head_dim)
-        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(
-            query_states, key_states, cos, sin
-        )
+        bsz, seq_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)

         if past_key_values is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_values.update(
-                key_states, value_states, self.layer_idx, cache_kwargs
-            )
-        # sampling a_states from query_states
-        a_states = self.a_proj(
-            query_states.transpose(1, 2).reshape(
-                query_states.shape[0], query_states.shape[-2], -1
-            )
-        )  # [batch_size, query_len, num_key_value_heads]
-        # sampling dt_states from value_states
-        dt_states = self.dt_proj(
-            value_states.transpose(1, 2).reshape(
-                value_states.shape[0], value_states.shape[-2], -1
-            )
-        )  # [batch_size, key_len, num_key_value_heads]
-        # original formula is exp(A * softplus(delta V)), but for numerical stability, it is changed to A * softplus(delta V)
+            past_key, past_value = past_key_values
+            key_states = torch.cat([past_key, key_states], dim=1)
+            value_states = torch.cat([past_value, value_states], dim=1)
+        key_len = key_states.size(1)
+
+        gate_states = self.g_proj(query_states)
+        delta_states = self.d_proj(value_states)
         attn_bias = (
-            (torch.mean(a_states, dim=1, keepdim=True) * F.softplus(dt_states))
-            .transpose(-1, -2)
-            .unsqueeze(-2)
-        )  # [batch_size, num_key_value_heads, 1, key_len]
+            (torch.sigmoid(gate_states) * delta_states).transpose(-1, -2).unsqueeze(-2)
+        )
Suggested change (for the attn_bias computation above):

        gate_states = self.g_proj(query_states)  # (bsz, seq_len, num_key_value_heads)
        delta_states = self.d_proj(value_states)  # (bsz, key_len, num_key_value_heads)
        gate = torch.sigmoid(gate_states).unsqueeze(2)  # (bsz, seq_len, 1, num_key_value_heads)
        delta = delta_states.unsqueeze(1)  # (bsz, 1, key_len, num_key_value_heads)
        attn_bias = (gate * delta).permute(0, 3, 1, 2)  # (bsz, num_key_value_heads, seq_len, key_len)
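A quick, standalone shape check (illustration only, not part of the PR; all sizes are made up) confirming that the suggested broadcasting yields an attn_bias of shape (bsz, num_key_value_heads, seq_len, key_len):

```python
import torch

# Hypothetical sizes for illustration only.
bsz, seq_len, key_len, num_key_value_heads = 2, 4, 6, 3

gate_states = torch.randn(bsz, seq_len, num_key_value_heads)   # stand-in for self.g_proj(query_states)
delta_states = torch.randn(bsz, key_len, num_key_value_heads)  # stand-in for self.d_proj(value_states)

gate = torch.sigmoid(gate_states).unsqueeze(2)   # (bsz, seq_len, 1, num_key_value_heads)
delta = delta_states.unsqueeze(1)                # (bsz, 1, key_len, num_key_value_heads)
attn_bias = (gate * delta).permute(0, 3, 1, 2)   # (bsz, num_key_value_heads, seq_len, key_len)

print(attn_bias.shape)  # torch.Size([2, 3, 4, 6])
```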
Copilot AI (Dec 19, 2025)

The indexing query_states.shape[2] is incorrect. At this point in the code, query_states has been reshaped at line 72 to have dimensions (bsz, seq_len, num_heads, head_dim), so shape[2] refers to the number of heads, not query_len. The query_len should be obtained from shape[1], which is seq_len, or simply use the seq_len variable that is already defined.

Suggested change:
-                query_len=query_states.shape[2],
+                query_len=seq_len,
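A minimal sketch (not from the PR; sizes are invented) of the shape issue described above, once query_states has been viewed as (bsz, seq_len, num_heads, head_dim):

```python
import torch

bsz, seq_len, num_heads, head_dim = 2, 16, 8, 64  # hypothetical sizes
query_states = torch.randn(bsz, seq_len, num_heads * head_dim)
query_states = query_states.view(bsz, seq_len, -1, head_dim)

print(query_states.shape[2])  # 8  -> num_heads, not the query length
print(query_states.shape[1])  # 16 -> seq_len, i.e. the intended query_len
```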
Copilot AI (Dec 19, 2025)

The create_mask function call is missing several required parameters. According to the function signature, create_mask requires: attention_bias, attention_mask, batch_size, query_len, key_len, window_size, min_dtype, and block_size. The current call only provides attention_bias, query_len, and type. You need to add the missing parameters: attention_mask, batch_size (bsz), key_len, window_size (from config or None), min_dtype (from config or None), and block_size (from config or None).

Suggested change:
-                query_len=query_states.shape[2],
+                attention_mask=attention_mask,
+                batch_size=bsz,
+                query_len=seq_len,
+                key_len=key_len,
+                window_size=getattr(self.config, "window_size", None),
+                min_dtype=getattr(self.config, "min_dtype", None),
+                block_size=getattr(self.config, "block_size", None),
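Putting the suggestion together, the full call might look roughly like the sketch below. This assumes the parameter list quoted in the comment; the result variable name and the config attribute fallbacks are assumptions, the type argument is the PR's existing argument whose value is not shown in the diff, and the exact signature should be checked against flash_sparse_attn.utils.mask.create_mask.

```python
# Sketch only: not the confirmed create_mask signature.
attn_mask = create_mask(
    attention_bias=attn_bias,
    attention_mask=attention_mask,
    batch_size=bsz,
    query_len=seq_len,
    key_len=key_len,
    window_size=getattr(self.config, "window_size", None),
    min_dtype=getattr(self.config, "min_dtype", None),
    block_size=getattr(self.config, "block_size", None),
    type=...,  # existing `type` argument from the PR; its value is not shown in the diff
)
```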
Copilot AI (Dec 19, 2025, marked Outdated)

The forward method signature and return statement are inconsistent. The return type annotation indicates the method returns a tuple of three elements, tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]], but the actual return statement at line 95 only returns a single tensor (attn_output). This is a breaking API change that contradicts the PR description's claim of maintaining backward compatibility. Either update the return type annotation to match the new single return value, or maintain the original tuple return format for compatibility.
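Two ways the mismatch could be resolved, sketched under the assumption that the simplified path only computes attn_output (names follow the diff; neither is the PR's actual fix):

```python
# Option A (sketch): align the annotation with the new single-tensor return.
def forward(self, hidden_states, attention_mask=None, past_key_values=None, **kwargs) -> torch.Tensor:
    ...
    return attn_output

# Option B (sketch): keep a tuple-shaped return for callers written against the old API,
# returning None where attention weights used to be.
def forward(self, hidden_states, attention_mask=None, past_key_values=None, **kwargs):
    ...
    return attn_output, None
```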
Copilot AI (Dec 19, 2025, marked Outdated)

The function does not return the updated past_key_values tuple. For proper KV caching in generation scenarios, the method should return the concatenated (key_states, value_states) tuple so that callers can pass it as past_key_values in subsequent forward passes. Without this, the caching mechanism will not work correctly.
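A minimal sketch of what this comment asks for, assuming the concatenated key_states/value_states from the diff are the tensors to hand back:

```python
        # Sketch (tail of forward): return the updated cache alongside the output so the
        # caller can feed it back in as past_key_values on the next decoding step.
        present_key_values = (key_states, value_states)
        return attn_output, present_key_values

        # Hypothetical caller-side usage:
        #   output, past = attn(hidden_states, past_key_values=past)
```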
Diff 2: MultiHeadAttention example (flash_attn)

@@ -1,46 +1,11 @@
-from typing import Optional
+from typing import Optional, Tuple
 import torch
 import torch.nn as nn

 from transformers.cache_utils import Cache

 from flash_attn.flash_attn_interface import flash_attn_func


-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
 class MultiHeadAttention(nn.Module):
     def __init__(self, config, layer_idx: Optional[int] = None):
         super().__init__()

(Copilot AI commented on lines +1 to 4 of this file; see the review comment after the diff.)

@@ -71,40 +36,35 @@ def __init__(self, config, layer_idx: Optional[int] = None):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings: tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Cache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         **kwargs,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
-        input_shape = hidden_states.shape[:-1]
-        hidden_shape = (*input_shape, -1, self.head_dim)
-
-        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
-        cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(
-            query_states, key_states, cos, sin
-        )
+        bsz, seq_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)

         if past_key_values is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_values.update(
-                key_states, value_states, self.layer_idx, cache_kwargs
-            )
-
-        attn_output, attn_weights = flash_attn_func(
-            query_states.transpose(1, 2).contiguous(),
-            key_states.transpose(1, 2).contiguous(),
-            value_states.transpose(1, 2).contiguous(),
+            past_key, past_value = past_key_values
+            key_states = torch.cat([past_key, key_states], dim=1)
+            value_states = torch.cat([past_value, value_states], dim=1)
+        key_len = key_states.size(1)
+
+        query_states = query_states.view(bsz, seq_len, -1, self.head_dim)
+        key_states = key_states.view(bsz, key_len, -1, self.head_dim)
+        value_states = value_states.view(bsz, key_len, -1, self.head_dim)
+
+        attn_output = flash_attn_func(
+            query_states,
+            key_states,
+            value_states,
             softmax_scale=self.scaling,
             causal=self.is_causal,
         )

-        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = attn_output.reshape(bsz, seq_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)

-        return attn_output, attn_weights
+        return attn_output
Review comment on lines +1 to 4:

The Tuple type has been imported but the Cache import from transformers.cache_utils is still present. Since the Cache object is no longer used in the simplified implementation (past_key_values now uses a simple Tuple), the Cache import should be removed.
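If the Cache import is dropped as suggested, the example's import block would presumably reduce to something like this (sketch; module paths as shown in the diff):

```python
from typing import Optional, Tuple
import torch
import torch.nn as nn

from flash_attn.flash_attn_interface import flash_attn_func
```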