Skip to content

Conversation

@LoserCheems
Copy link
Collaborator

Summary

  • This update simplifies the multi-head attention and dynamic mask attention mechanisms by removing unnecessary components and aligning with flash attention expectations.

Root Cause

  • The previous implementation included redundant rotary embedding utilities and complex cache handling that were not needed for the intended functionality.

Changes

  • Removed rotary embedding utilities and associated parameters.
  • Streamlined cache handling to use simple tuple concatenation.
  • Introduced new gate/delta projections and a sigmoid-based bias for mask creation.

Reproduction

  • No specific bug was identified, but the changes improve the overall efficiency and clarity of the attention mechanisms.

Tests

  • Existing tests were validated to ensure no regressions occurred with the simplified implementations.

Compatibility

  • The changes maintain backward compatibility with the existing attention mechanisms.

Checklist

  • Linked issue provided
  • Adds or updates tests
  • Updates docs if needed
  • No perf regressions

Removes rotary embedding utilities and associated parameters so the module only consumes projections required by flash attention.
Switches cache handling to simple tuple concatenation and aligns tensor reshapes with flash_attn expectations while dropping unused attention weights.
Removes RoPE-specific helpers and cache plumbing so the module works with plain tensors.
Introduces new gate/delta projections and sigmoid-based bias to streamline mask creation.
Drops the unused attention weight return to match flash sparse expectations.
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR aims to simplify the multi-head attention and dynamic mask attention mechanisms by removing rotary position embeddings and streamlining cache handling. However, the implementation contains several critical issues that need to be addressed.

Key Changes:

  • Removed rotary embedding utility functions and position embedding parameters
  • Simplified cache handling from Cache objects to simple tuple concatenation
  • Renamed projection layers (a_proj → g_proj, dt_proj → d_proj) and changed attention bias calculation from mean + softplus to sigmoid

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 9 comments.

File Description
flash_sparse_attn/modules/multi_head_attention.py Removed rotary embeddings, simplified cache handling to tuples, and removed attn_weights return value
flash_sparse_attn/modules/dynamic_mask_attention.py Same changes as multi_head_attention.py plus renamed projection layers and simplified attention bias calculation

Critical Issues Found:

  1. API Breaking Changes: Both attention modules now return only the attention output instead of the documented tuple of three values, contradicting the PR's backward compatibility claim
  2. Missing Parameters: The create_mask function call is missing required parameters (attention_mask, batch_size, key_len, window_size, min_dtype, block_size)
  3. Dimension Mismatch: The attention bias calculation attempts to multiply tensors with incompatible dimensions (seq_len vs key_len)
  4. Incorrect Indexing: Using query_states.shape[2] after reshaping returns the number of heads instead of query_len
  5. Broken Caching: The updated past_key_values tuple is not returned, breaking KV caching for generation scenarios

These issues will cause runtime errors and prevent the code from functioning correctly. The PR needs significant revisions before it can be merged.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +1 to 4
from typing import Optional, Tuple
import torch
import torch.nn as nn

Copy link

Copilot AI Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Tuple type has been imported but the Cache import from transformers.cache_utils is still present. Since the Cache object is no longer used in the simplified implementation (past_key_values now uses a simple Tuple), the Cache import should be removed.

Copilot uses AI. Check for mistakes.
attention_bias=attn_bias,
attention_mask=attention_mask,
batch_size=query_states.shape[0],
query_len=query_states.shape[2],
Copy link

Copilot AI Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The indexing query_states.shape[2] is incorrect. At this point in the code, query_states has been reshaped at line 72 to have dimensions (bsz, seq_len, num_heads, head_dim), so shape[2] refers to the number of heads, not query_len. The query_len should be obtained from shape[1] which is seq_len, or simply use the seq_len variable that's already defined.

Suggested change
query_len=query_states.shape[2],
query_len=seq_len,

Copilot uses AI. Check for mistakes.
Comment on lines 55 to 95

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# sampling a_states from query_states
a_states = self.a_proj(
query_states.transpose(1, 2).reshape(
query_states.shape[0], query_states.shape[-2], -1
)
) # [batch_size, query_len, num_key_value_heads]
# sampling dt_states from value_states
dt_states = self.dt_proj(
value_states.transpose(1, 2).reshape(
value_states.shape[0], value_states.shape[-2], -1
)
) # [batch_size, key_len, num_key_value_heads]
# original formula is exp(A * softplus(delta V)), but for numerical stability, it is changed to A * softplus(delta V)
past_key, past_value = past_key_values
key_states = torch.cat([past_key, key_states], dim=1)
value_states = torch.cat([past_value, value_states], dim=1)
key_len = key_states.size(1)

gate_states = self.g_proj(query_states)
delta_states = self.d_proj(value_states)
attn_bias = (
(torch.mean(a_states, dim=1, keepdim=True) * F.softplus(dt_states))
.transpose(-1, -2)
.unsqueeze(-2)
) # [batch_size, num_key_value_heads, 1, key_len]
(torch.sigmoid(gate_states) * delta_states).transpose(-1, -2).unsqueeze(-2)
)

query_states = query_states.view(bsz, seq_len, -1, self.head_dim)
key_states = key_states.view(bsz, key_len, -1, self.head_dim)
value_states = value_states.view(bsz, key_len, -1, self.head_dim)

attn_mask = create_mask(
attention_bias=attn_bias,
attention_mask=attention_mask,
batch_size=query_states.shape[0],
query_len=query_states.shape[2],
key_len=key_states.shape[2],
type="relu",
)
attn_output, attn_weights = flash_sparse_attn_func(
self,
query_states.transpose(1, 2).contiguous(),
key_states.transpose(1, 2).contiguous(),
value_states.transpose(1, 2).contiguous(),

attn_output = flash_sparse_attn_func(
query_states,
key_states,
value_states,
attn_mask,
attn_bias,
softmax_scale=self.scaling,
is_causal=self.is_causal,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()

attn_output = attn_output.reshape(bsz, seq_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights

return attn_output
Copy link

Copilot AI Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The forward method signature and return statement are inconsistent. The return type annotation indicates the method returns a tuple of three elements: tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]], but the actual return statement at line 95 only returns a single tensor (attn_output). This is a breaking API change that contradicts the PR description's claim of maintaining backward compatibility. Either update the return type annotation to match the new single return value, or maintain the original tuple return format for compatibility.

Copilot uses AI. Check for mistakes.
Comment on lines +1 to 4
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

Copy link

Copilot AI Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Tuple type has been imported but the Cache import from transformers.cache_utils is still present. Since the Cache object is no longer used in the simplified implementation (past_key_values now uses a simple Tuple), the Cache import should be removed.

Copilot uses AI. Check for mistakes.
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights

return attn_output
Copy link

Copilot AI Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function does not return the updated past_key_values tuple. For proper KV caching in generation scenarios, the method should return the concatenated (key_states, value_states) tuple so that callers can pass it as past_key_values in subsequent forward passes. Without this, the caching mechanism will not work correctly.

Copilot uses AI. Check for mistakes.
attention_bias=attn_bias,
attention_mask=attention_mask,
batch_size=query_states.shape[0],
query_len=query_states.shape[2],
Copy link

Copilot AI Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The create_mask function call is missing several required parameters. According to the function signature, create_mask requires: attention_bias, attention_mask, batch_size, query_len, key_len, window_size, min_dtype, and block_size. The current call only provides attention_bias, query_len, and type. You need to add the missing parameters: attention_mask, batch_size (bsz), key_len, window_size (from config or None), min_dtype (from config or None), and block_size (from config or None).

Suggested change
query_len=query_states.shape[2],
attention_mask=attention_mask,
batch_size=bsz,
query_len=seq_len,
key_len=key_len,
window_size=getattr(self.config, "window_size", None),
min_dtype=getattr(self.config, "min_dtype", None),
block_size=getattr(self.config, "block_size", None),

Copilot uses AI. Check for mistakes.
Comment on lines 68 to 70
attn_bias = (
(torch.mean(a_states, dim=1, keepdim=True) * F.softplus(dt_states))
.transpose(-1, -2)
.unsqueeze(-2)
) # [batch_size, num_key_value_heads, 1, key_len]
(torch.sigmoid(gate_states) * delta_states).transpose(-1, -2).unsqueeze(-2)
)
Copy link

Copilot AI Dec 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The calculation of attn_bias has a dimension mismatch. The g_proj takes query_states with shape (bsz, seq_len, num_attention_heads * head_dim) and outputs (bsz, seq_len, num_key_value_heads). The d_proj takes value_states which after cache concatenation has shape (bsz, key_len, num_key_value_heads * head_dim) and outputs (bsz, key_len, num_key_value_heads). When you multiply gate_states (bsz, seq_len, num_key_value_heads) with delta_states (bsz, key_len, num_key_value_heads), this will fail due to incompatible dimensions at dim=1 (seq_len vs key_len). This needs to be an outer product operation or the dimensions need to be broadcasted correctly to produce a (bsz, num_key_value_heads, seq_len, key_len) tensor.

Suggested change
attn_bias = (
(torch.mean(a_states, dim=1, keepdim=True) * F.softplus(dt_states))
.transpose(-1, -2)
.unsqueeze(-2)
) # [batch_size, num_key_value_heads, 1, key_len]
(torch.sigmoid(gate_states) * delta_states).transpose(-1, -2).unsqueeze(-2)
)
gate_states = self.g_proj(query_states) # (bsz, seq_len, num_key_value_heads)
delta_states = self.d_proj(value_states) # (bsz, key_len, num_key_value_heads)
gate = torch.sigmoid(gate_states).unsqueeze(2) # (bsz, seq_len, 1, num_key_value_heads)
delta = delta_states.unsqueeze(1) # (bsz, 1, key_len, num_key_value_heads)
attn_bias = (gate * delta).permute(0, 3, 1, 2) # (bsz, num_key_value_heads, seq_len, key_len)

Copilot uses AI. Check for mistakes.
@LoserCheems LoserCheems merged commit 337f552 into main Dec 19, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants