116 changes: 33 additions & 83 deletions flash_sparse_attn/modules/dynamic_mask_attention.py
@@ -1,48 +1,12 @@
from typing import Optional
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

Comment on lines +1 to 4

Copilot AI Dec 19, 2025

The Tuple type has been imported but the Cache import from transformers.cache_utils is still present. Since the Cache object is no longer used in the simplified implementation (past_key_values now uses a simple Tuple), the Cache import should be removed.

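For reference, a minimal sketch of what the import block could look like with the unused Cache import dropped (an assumption: nothing else in this file references Cache):

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func
from flash_sparse_attn.utils.mask import create_mask
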
from transformers.cache_utils import Cache

from flash_sparse_attn.flash_sparse_attn_interface import flash_sparse_attn_func
from flash_sparse_attn.utils.mask import create_mask


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed


class DynamicMaskAttention(nn.Module):
def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__()
@@ -66,12 +30,12 @@ def __init__(self, config, layer_idx: Optional[int] = None):
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
)
self.a_proj = nn.Linear(
self.g_proj = nn.Linear(
config.num_attention_heads * self.head_dim,
config.num_key_value_heads,
bias=False,
)
self.dt_proj = nn.Linear(
self.d_proj = nn.Linear(
config.num_key_value_heads * self.head_dim,
config.num_key_value_heads,
bias=False,
@@ -83,63 +47,49 @@ def forward(
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
bsz, seq_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# sampling a_states from query_states
a_states = self.a_proj(
query_states.transpose(1, 2).reshape(
query_states.shape[0], query_states.shape[-2], -1
)
) # [batch_size, query_len, num_key_value_heads]
# sampling dt_states from value_states
dt_states = self.dt_proj(
value_states.transpose(1, 2).reshape(
value_states.shape[0], value_states.shape[-2], -1
)
) # [batch_size, key_len, num_key_value_heads]
# original formula is exp(A * softplus(delta V)), but for numerical stability, it is changed to A * softplus(delta V)
past_key, past_value = past_key_values
key_states = torch.cat([past_key, key_states], dim=1)
value_states = torch.cat([past_value, value_states], dim=1)
key_len = key_states.size(1)

gate_states = self.g_proj(query_states)
delta_states = self.d_proj(value_states)
attn_bias = (
(torch.mean(a_states, dim=1, keepdim=True) * F.softplus(dt_states))
.transpose(-1, -2)
.unsqueeze(-2)
) # [batch_size, num_key_value_heads, 1, key_len]
(torch.sigmoid(gate_states) * delta_states).transpose(-1, -2).unsqueeze(-2)
)

Copilot AI Dec 19, 2025

The calculation of attn_bias has a dimension mismatch. The g_proj takes query_states with shape (bsz, seq_len, num_attention_heads * head_dim) and outputs (bsz, seq_len, num_key_value_heads). The d_proj takes value_states which after cache concatenation has shape (bsz, key_len, num_key_value_heads * head_dim) and outputs (bsz, key_len, num_key_value_heads). When you multiply gate_states (bsz, seq_len, num_key_value_heads) with delta_states (bsz, key_len, num_key_value_heads), this will fail due to incompatible dimensions at dim=1 (seq_len vs key_len). This needs to be an outer product operation or the dimensions need to be broadcasted correctly to produce a (bsz, num_key_value_heads, seq_len, key_len) tensor.

Suggested change
attn_bias = (
(torch.sigmoid(gate_states) * delta_states).transpose(-1, -2).unsqueeze(-2)
)
gate_states = self.g_proj(query_states) # (bsz, seq_len, num_key_value_heads)
delta_states = self.d_proj(value_states) # (bsz, key_len, num_key_value_heads)
gate = torch.sigmoid(gate_states).unsqueeze(2) # (bsz, seq_len, 1, num_key_value_heads)
delta = delta_states.unsqueeze(1) # (bsz, 1, key_len, num_key_value_heads)
attn_bias = (gate * delta).permute(0, 3, 1, 2) # (bsz, num_key_value_heads, seq_len, key_len)

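As a quick sanity check of the proposed broadcast, here is a standalone sketch with made-up sizes (not part of the PR):

import torch

bsz, seq_len, key_len, num_key_value_heads = 2, 4, 6, 3
gate_states = torch.randn(bsz, seq_len, num_key_value_heads)
delta_states = torch.randn(bsz, key_len, num_key_value_heads)

# Broadcast a per-head (seq_len, 1) gate against a (1, key_len) delta, then move
# the head dimension forward into the layout the review comment asks for.
gate = torch.sigmoid(gate_states).unsqueeze(2)   # (bsz, seq_len, 1, num_key_value_heads)
delta = delta_states.unsqueeze(1)                # (bsz, 1, key_len, num_key_value_heads)
attn_bias = (gate * delta).permute(0, 3, 1, 2)   # (bsz, num_key_value_heads, seq_len, key_len)
assert attn_bias.shape == (bsz, num_key_value_heads, seq_len, key_len)
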

query_states = query_states.view(bsz, seq_len, -1, self.head_dim)
key_states = key_states.view(bsz, key_len, -1, self.head_dim)
value_states = value_states.view(bsz, key_len, -1, self.head_dim)

attn_mask = create_mask(
attention_bias=attn_bias,
attention_mask=attention_mask,
batch_size=query_states.shape[0],
query_len=query_states.shape[2],

Copilot AI Dec 19, 2025

The indexing query_states.shape[2] is incorrect. At this point in the code, query_states has been reshaped at line 72 to have dimensions (bsz, seq_len, num_heads, head_dim), so shape[2] refers to the number of heads, not query_len. The query_len should be obtained from shape[1] which is seq_len, or simply use the seq_len variable that's already defined.

Suggested change
query_len=query_states.shape[2],
query_len=seq_len,


Copilot AI Dec 19, 2025

The create_mask function call is missing several required parameters. According to the function signature, create_mask requires: attention_bias, attention_mask, batch_size, query_len, key_len, window_size, min_dtype, and block_size. The current call only provides attention_bias, query_len, and type. You need to add the missing parameters: attention_mask, batch_size (bsz), key_len, window_size (from config or None), min_dtype (from config or None), and block_size (from config or None).

Suggested change
query_len=query_states.shape[2],
attention_mask=attention_mask,
batch_size=bsz,
query_len=seq_len,
key_len=key_len,
window_size=getattr(self.config, "window_size", None),
min_dtype=getattr(self.config, "min_dtype", None),
block_size=getattr(self.config, "block_size", None),

key_len=key_states.shape[2],
type="relu",
)
attn_output, attn_weights = flash_sparse_attn_func(
self,
query_states.transpose(1, 2).contiguous(),
key_states.transpose(1, 2).contiguous(),
value_states.transpose(1, 2).contiguous(),

attn_output = flash_sparse_attn_func(
query_states,
key_states,
value_states,
attn_mask,
attn_bias,
softmax_scale=self.scaling,
is_causal=self.is_causal,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()

attn_output = attn_output.reshape(bsz, seq_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights

return attn_output

Copilot AI Dec 19, 2025

The forward method signature and return statement are inconsistent. The return type annotation indicates the method returns a tuple of three elements: tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]], but the actual return statement at line 95 only returns a single tensor (attn_output). This is a breaking API change that contradicts the PR description's claim of maintaining backward compatibility. Either update the return type annotation to match the new single return value, or maintain the original tuple return format for compatibility.

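One way to resolve this, sketched as an assumption rather than the PR's stated intent, is to align the annotation with the new single-tensor return:

def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    **kwargs,
) -> torch.Tensor:  # annotation now matches the single attn_output return
    ...

The alternative is to keep returning the original three-element tuple so existing callers continue to work unchanged.
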

Copilot AI Dec 19, 2025

The function does not return the updated past_key_values tuple. For proper KV caching in generation scenarios, the method should return the concatenated (key_states, value_states) tuple so that callers can pass it as past_key_values in subsequent forward passes. Without this, the caching mechanism will not work correctly.

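A sketch of one way to do that (the present_key_values name is hypothetical): capture the concatenated tensors and hand them back alongside the output.

    # near the end of forward(), after key_states / value_states have been
    # concatenated with the cache:
    present_key_values = (key_states, value_states)  # updated cache for the next step
    attn_output = self.o_proj(attn_output)
    return attn_output, present_key_values
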
82 changes: 21 additions & 61 deletions flash_sparse_attn/modules/multi_head_attention.py
@@ -1,46 +1,11 @@
from typing import Optional
from typing import Optional, Tuple
import torch
import torch.nn as nn

Comment on lines +1 to 4

Copilot AI Dec 19, 2025

The Tuple type has been imported but the Cache import from transformers.cache_utils is still present. Since the Cache object is no longer used in the simplified implementation (past_key_values now uses a simple Tuple), the Cache import should be removed.

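As with the other module, a minimal sketch of the trimmed import block (again assuming Cache is not referenced elsewhere in this file):

from typing import Optional, Tuple

import torch
import torch.nn as nn

from flash_attn.flash_attn_interface import flash_attn_func
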
from transformers.cache_utils import Cache

from flash_attn.flash_attn_interface import flash_attn_func


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.

Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed


class MultiHeadAttention(nn.Module):
def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__()
@@ -71,40 +36,35 @@ def __init__(self, config, layer_idx: Optional[int] = None):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
bsz, seq_len, _ = hidden_states.size()

cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx, cache_kwargs
)

attn_output, attn_weights = flash_attn_func(
query_states.transpose(1, 2).contiguous(),
key_states.transpose(1, 2).contiguous(),
value_states.transpose(1, 2).contiguous(),
past_key, past_value = past_key_values
key_states = torch.cat([past_key, key_states], dim=1)
value_states = torch.cat([past_value, value_states], dim=1)
key_len = key_states.size(1)

query_states = query_states.view(bsz, seq_len, -1, self.head_dim)
key_states = key_states.view(bsz, key_len, -1, self.head_dim)
value_states = value_states.view(bsz, key_len, -1, self.head_dim)

attn_output = flash_attn_func(
query_states,
key_states,
value_states,
softmax_scale=self.scaling,
causal=self.is_causal,
)

attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = attn_output.reshape(bsz, seq_len, -1).contiguous()
attn_output = self.o_proj(attn_output)

return attn_output, attn_weights
return attn_output