Lines 126 to 176 in e8b4821
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], attn_impl='naive'):
    """
    Forward pass for the Multi-Headed Attention Layer (MLA).

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
        start_pos (int): Starting position in the sequence for caching.
        freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
        mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.

    Returns:
        torch.Tensor: Output tensor with the same shape as the input.
    """
    bsz, seqlen, embed_dim = x.size()
    end_pos = start_pos + seqlen
    if self.q_lora_rank == 0:
        q = self.wq(x, return_space=True)
    else:
        q = self.wq_b(self.q_norm(self.wq_a(x, return_space=True), space_only=True), return_space=True)
    q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim - 1)  # space-like
    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim - 1], dim=-1)  # space-like
    q_pe = apply_rotary_emb(q_pe, freqs_cis)  # space-like
    kv = self.wkv_a(x, return_space=True)
    kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim - 1], dim=-1)  # space-like
    k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)  # space-like
    q = torch.cat([q_nope, q_pe], dim=-1)  # space-like
    kv = self.wkv_b(self.kv_norm(kv, space_only=True), return_space=True)  # space-like
    kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim - 1)
    k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim - 1], dim=-1)
    k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
    # self.k_cache[:bsz, start_pos:end_pos] = k
    # self.v_cache[:bsz, start_pos:end_pos] = v
    # MLA based on hyperbolic distance
    qs = self.project(q)
    ks = self.project(k)
    scores = 2 * self.manifold.c + 2 * self.manifold.cinner(qs.transpose(1, 2), ks.transpose(1, 2))  # [B, S, N, N]
    scores = scores / self.softmax_scale + self.bias
    if mask is not None:
        mask = self.shape_mask(mask, bsz, self.n_local_heads, seqlen)
        scores = scores.masked_fill(mask, -1e18)
    scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
    # vs = self.project(self.v_cache[:bsz, :end_pos])
    vs = self.project(v)
    x = self.manifold.lorentzian_centroid(vs.transpose(1, 2), scores).transpose(1, 2)  # [B, S, H, N]
    x = self.wo(x.flatten(2))
    return x
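For reference, the commented-out k_cache / v_cache lines hint at the usual incremental-decoding pattern: write the new keys/values into pre-allocated buffers, then attend over everything up to end_pos. Below is a minimal generic sketch of that pattern with plain dot-product attention (my own toy helper cached_attention_step and made-up buffer sizes, not the repo's hyperbolic code), just to show the shape of the path I expected:

import torch


def cached_attention_step(q, k_new, v_new, k_cache, v_cache, start_pos):
    """Toy single-step cached attention (plain dot-product, not hyperbolic).

    q, k_new, v_new: (bsz, seqlen, n_heads, head_dim) for the current chunk only.
    k_cache, v_cache: pre-allocated (max_bsz, max_seqlen, n_heads, head_dim) buffers.
    With seqlen > 1 a causal mask inside the chunk would also be needed.
    """
    bsz, seqlen, n_heads, head_dim = q.shape
    end_pos = start_pos + seqlen

    # Write the new keys/values into the cache (what the commented-out lines would do).
    k_cache[:bsz, start_pos:end_pos] = k_new
    v_cache[:bsz, start_pos:end_pos] = v_new

    # Read back ALL positions up to end_pos, so decoding attends to the past
    # without recomputing it.
    k = k_cache[:bsz, :end_pos].transpose(1, 2)  # (bsz, n_heads, end_pos, head_dim)
    v = v_cache[:bsz, :end_pos].transpose(1, 2)
    qh = q.transpose(1, 2)                       # (bsz, n_heads, seqlen, head_dim)

    scores = (qh @ k.transpose(-2, -1)) / head_dim ** 0.5  # (bsz, n_heads, seqlen, end_pos)
    out = scores.softmax(dim=-1) @ v                        # (bsz, n_heads, seqlen, head_dim)
    return out.transpose(1, 2)                              # (bsz, seqlen, n_heads, head_dim)


# Smoke test with arbitrary sizes.
bsz, n_heads, head_dim, max_len = 2, 4, 8, 16
k_cache = torch.zeros(bsz, max_len, n_heads, head_dim)
v_cache = torch.zeros(bsz, max_len, n_heads, head_dim)
step = torch.randn(bsz, 1, n_heads, head_dim)
print(cached_attention_step(step, step, step, k_cache, v_cache, start_pos=5).shape)
# torch.Size([2, 1, 4, 8])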
In the HMLA forward above I didn't see anything like that pattern, nor anything like the cached decoding path in the example below:
Example NoPE MLA Code

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MLALayerOptimized(nn.Module):
    """
    A pure no-positional-encoding (NoPE), fully vectorized, optimized implementation
    of Multi-head Latent Attention (MLA).

    - Training / prefill mode: uses F.scaled_dot_product_attention for best performance
      (supports Flash Attention).
    - Inference mode: implements the MQA-style computation described in the paper,
      obtained via an identity transformation (absorbing W_k / W_v into the query side).
    - Supports both prefill and single-step decoding.
    - Keeps c_norm logically consistent across the two paths.
    """
    def __init__(self, d_model: int, num_heads: int, d_latent: int, d_head: int = None, output_dim: int = None, **kwargs):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_latent = d_latent
        self.d_head = d_head if d_head is not None else d_model // num_heads
        # Make sure num_heads * d_head is well defined even when d_head is given explicitly
        self.inner_dim = self.num_heads * self.d_head

        # Projection matrices
        self.W_q = nn.Linear(d_model, self.inner_dim, bias=False)
        self.W_c = nn.Linear(d_model, d_latent, bias=False)
        self.W_k = nn.Linear(d_latent, self.inner_dim, bias=False)
        self.W_v = nn.Linear(d_latent, self.inner_dim, bias=False)
        self.W_o = nn.Linear(self.inner_dim, d_model if not output_dim else output_dim, bias=False)

        self.c_norm = nn.RMSNorm(d_latent)
        self.q_norm = nn.RMSNorm(self.inner_dim)
    def forward(self, x: torch.Tensor, use_cache: bool = False, cache: torch.Tensor = None, attn_mask=None, **kwargs):
        batch_size, seq_len, _ = x.shape

        # ------------------------------------------------------------------
        # Path 1: single-step decoding - only when use_cache=True and a cache already exists
        # ------------------------------------------------------------------
        if use_cache and cache is not None:
            if seq_len != 1:
                raise ValueError(f"Decoding with cache requires seq_len=1, got seq_len={seq_len}.")

            # 1. Compute c for the current token and extend the cache
            c = self.W_c(x)  # x: (B, 1, d_model) -> c: (B, 1, d_latent)
            c = self.c_norm(c)
            c_full = torch.cat([cache, c], dim=1)

            # 2. Compute Q for the current token
            q = self.W_q(x)  # (B, 1, inner_dim)
            q = self.q_norm(q)  # applying q_norm here is safe
            q_current = q.view(batch_size, 1, self.num_heads, self.d_head)

            # 3. Core optimization: q' = q @ W_k.T (absorb W_k into the query)
            # q_current: (B, 1, H, D_h)
            # W_k.weight: (H*D_h, D_l) -> (H, D_h, D_l)
            # q_prime: (B, 1, H, D_l)
            W_k_reshaped = self.W_k.weight.view(self.num_heads, self.d_head, self.d_latent)
            q_prime = torch.einsum('bqhd,hdl->bqhl', q_current, W_k_reshaped)

            # 4. Attention scores q' @ c.T
            # q_prime: (B, 1, H, D_l); c_full: (B, L, D_l); attn_scores: (B, 1, H, L)
            attn_scores = torch.einsum('bqhl,bkl->bqhk', q_prime, c_full) / math.sqrt(self.d_head)

            # 5. Softmax, then take the weighted sum over c ("sum first")
            # attn_weights: (B, 1, H, L); intermediate: (B, 1, H, D_l)
            attn_weights = F.softmax(attn_scores - attn_scores.max(dim=-1, keepdim=True)[0], dim=-1)
            intermediate = torch.einsum('bqhk,bkl->bqhl', attn_weights, c_full)

            # 6. Transform the intermediate result with W_v ("transform later")
            # W_v.weight: (H*D_h, D_l) -> (H, D_h, D_l); head_output: (B, 1, H, D_h)
            W_v_reshaped = self.W_v.weight.view(self.num_heads, self.d_head, self.d_latent)
            head_output = torch.einsum('bqhl,hdl->bqhd', intermediate, W_v_reshaped)

            # 7. Merge the heads and project out
            combined_heads = head_output.contiguous().view(batch_size, 1, -1)
            output = self.W_o(combined_heads)
            return output, c_full
        # ------------------------------------------------------------------
        # Path 2: parallel processing (training, or the prefill stage of inference)
        # ------------------------------------------------------------------
        c = self.W_c(x)  # (B, L, d_latent)
        c = self.c_norm(c)  # normalize here too, so c matches the cached/decoding path
        q = self.W_q(x)
        q = self.q_norm(q)  # likewise, keep q_norm consistent with the decoding path
        q = q.view(batch_size, seq_len, self.num_heads, self.d_head)
        k = self.W_k(c).view(batch_size, seq_len, self.num_heads, self.d_head)
        v = self.W_v(c).view(batch_size, seq_len, self.num_heads, self.d_head)

        # Use the efficient PyTorch 2.0+ kernel; (B, L, H, D) -> (B, H, L, D) for SDPA
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        # F.scaled_dot_product_attention handles scaling and softmax internally.
        # It does not accept attn_mask together with is_causal=True, so only set
        # is_causal when no explicit mask is passed (an explicit mask is assumed
        # to already encode causality).
        head_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=attn_mask is None)
        combined_heads = head_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        output = self.W_o(combined_heads)

        cache_to_return = c if use_cache else None
        return output, cache_to_return

I mean, HMLA doesn't look like it does anything with a KV cache; it looks more like MHA without a KV-cache inference path.
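For what it's worth, this is roughly how I would drive the example layer above during generation (toy sizes chosen by me; assumes PyTorch 2.4+ for nn.RMSNorm and 2.0+ for SDPA; with c_norm/q_norm kept consistent as noted in the code, the two paths should agree up to floating-point error):

# Rough usage sketch for MLALayerOptimized above (toy sizes, no grad).
torch.manual_seed(0)
layer = MLALayerOptimized(d_model=64, num_heads=4, d_latent=16).eval()
x = torch.randn(2, 10, 64)  # (batch, seq_len, d_model)

with torch.no_grad():
    # Prefill on the first 9 tokens; returns the latent cache c of shape (B, 9, d_latent).
    prefill_out, cache = layer(x[:, :9], use_cache=True)

    # Single-step decode of token 10 against the cached latents.
    decode_out, cache = layer(x[:, 9:10], use_cache=True, cache=cache)

    # Sanity check: the absorbed (MQA-style) decode path should reproduce the
    # parallel path's output for the last position, up to floating-point error.
    full_out, _ = layer(x, use_cache=True)
    print((decode_out - full_out[:, -1:]).abs().max())

That second, cache-driven call path is the part I couldn't find in the HMLA forward.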