
Hello, I would like to ask: is the HMLA implementation genuine? There is no KV cache implementation #5

@aaababaaz

Description


helm/helm/modules/hmla.py

Lines 126 to 176 in e8b4821

def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], attn_impl='naive'):
    """
    Forward pass for the Multi-Headed Attention Layer (MLA).

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
        start_pos (int): Starting position in the sequence for caching.
        freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
        mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.

    Returns:
        torch.Tensor: Output tensor with the same shape as the input.
    """
    bsz, seqlen, embed_dim = x.size()
    end_pos = start_pos + seqlen
    if self.q_lora_rank == 0:
        q = self.wq(x, return_space=True)
    else:
        q = self.wq_b(self.q_norm(self.wq_a(x, return_space=True), space_only=True), return_space=True)
    q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim - 1)  # space-like
    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim - 1], dim=-1)  # space-like
    q_pe = apply_rotary_emb(q_pe, freqs_cis)  # space-like
    kv = self.wkv_a(x, return_space=True)
    kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim - 1], dim=-1)  # space-like
    k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)  # space-like
    q = torch.cat([q_nope, q_pe], dim=-1)  # space-like
    kv = self.wkv_b(self.kv_norm(kv, space_only=True), return_space=True)  # space-like
    kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim - 1)
    k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim - 1], dim=-1)
    k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
    # self.k_cache[:bsz, start_pos:end_pos] = k
    # self.v_cache[:bsz, start_pos:end_pos] = v
    # MLA based on hyperbolic distance
    qs = self.project(q)
    ks = self.project(k)
    scores = 2 * self.manifold.c + 2 * self.manifold.cinner(qs.transpose(1, 2), ks.transpose(1, 2))  # [B, S, N, N]
    scores = scores / self.softmax_scale + self.bias
    if mask is not None:
        mask = self.shape_mask(mask, bsz, self.n_local_heads, seqlen)
        scores = scores.masked_fill(mask, -1e18)
    scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
    # vs = self.project(self.v_cache[:bsz, :end_pos])
    vs = self.project(v)
    x = self.manifold.lorentzian_centroid(vs.transpose(1, 2), scores).transpose(1, 2)  # [B, S, H, N]
    x = self.wo(x.flatten(2))
    return x
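
For context, the commented-out self.k_cache / self.v_cache lines above are exactly the hook points for the usual "naive" caching pattern: write only the new positions into persistent per-layer buffers, then attend over the whole [:end_pos] prefix. Below is a minimal, self-contained sketch of that pattern, just to show what I mean; the class name, buffer sizes, and the plain dot-product attention are made up for illustration and are not the hyperbolic scoring that HMLA uses.

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class NaiveCachedAttention(nn.Module):
    """Minimal illustration of the slice-write / prefix-read KV-cache pattern (inference-time, plain MHA)."""

    def __init__(self, dim: int, n_heads: int, max_batch: int = 8, max_seqlen: int = 4096):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)
        # Persistent per-layer caches, analogous to the commented-out self.k_cache / self.v_cache above.
        self.register_buffer("k_cache", torch.zeros(max_batch, max_seqlen, n_heads, self.head_dim), persistent=False)
        self.register_buffer("v_cache", torch.zeros(max_batch, max_seqlen, n_heads, self.head_dim), persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None):
        bsz, seqlen, _ = x.shape
        end_pos = start_pos + seqlen
        q = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim)
        k = self.wk(x).view(bsz, seqlen, self.n_heads, self.head_dim)
        v = self.wv(x).view(bsz, seqlen, self.n_heads, self.head_dim)
        # Write only the new positions into the cache ...
        self.k_cache[:bsz, start_pos:end_pos] = k
        self.v_cache[:bsz, start_pos:end_pos] = v
        # ... but attend over the whole prefix, so decoding can feed one token at a time
        # (pass a causal mask for prefill, mask=None for single-token decode steps).
        queries = q.transpose(1, 2)                            # (B, H, seqlen, Dh)
        keys = self.k_cache[:bsz, :end_pos].transpose(1, 2)    # (B, H, end_pos, Dh)
        values = self.v_cache[:bsz, :end_pos].transpose(1, 2)  # (B, H, end_pos, Dh)
        out = F.scaled_dot_product_attention(queries, keys, values, attn_mask=mask)
        return self.wo(out.transpose(1, 2).reshape(bsz, seqlen, -1))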

In the HMLA forward above, I didn't see anything like the example below:

Example NoPE MLA Code
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MLALayerOptimized(nn.Module):
    """
    A pure, positional-encoding-free (NoPE), fully vectorized and optimized
    implementation of Multi-head Latent Attention (MLA).
    - Training / prefill mode: uses F.scaled_dot_product_attention for best performance (supports Flash Attention).
    - Inference mode: implements the MQA-style computation optimization from the paper, obtained via an identity transformation.
    - Supports prefill and single-step decoding.
    - Applies c_norm consistently across both paths.
    """
    def __init__(self, d_model: int, num_heads: int, d_latent: int, d_head: int = None, output_dim: int = None, **kwargs):
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_latent = d_latent
        self.d_head = d_head if d_head is not None else d_model // num_heads

        # Make sure num_heads * d_head is well defined
        self.inner_dim = self.num_heads * self.d_head
        # Projection matrices
        self.W_q = nn.Linear(d_model, self.inner_dim, bias=False)
        self.W_c = nn.Linear(d_model, d_latent, bias=False)
        self.W_k = nn.Linear(d_latent, self.inner_dim, bias=False)
        self.W_v = nn.Linear(d_latent, self.inner_dim, bias=False)
        self.W_o = nn.Linear(self.inner_dim, d_model if not output_dim else output_dim, bias=False)
        self.c_norm = nn.RMSNorm(d_latent)
        self.q_norm = nn.RMSNorm(self.inner_dim)

    def forward(self, x: torch.Tensor, use_cache: bool = False, cache: torch.Tensor = None, attn_mask=None, **kwargs):
        batch_size, seq_len, _ = x.shape

        # ------------------------------------------------------------------
        # Path 1: single-step decoding - only when use_cache=True and a cache already exists
        # ------------------------------------------------------------------
        if use_cache and cache is not None:
            if seq_len != 1:
                raise ValueError(f"Decoding with cache requires seq_len=1, got seq_len={seq_len}.")

            # 1. Compute c for the current token and append it to the cache
            c = self.W_c(x)  # x: (B, 1, d_model) -> c: (B, 1, d_latent)
            c = self.c_norm(c)
            c_full = torch.cat([cache, c], dim=1)

            # 2. Compute Q for the current token
            q = self.W_q(x)  # (B, 1, inner_dim)
            q = self.q_norm(q)
            q_current = q.view(batch_size, 1, self.num_heads, self.d_head)

            # 3. Core optimization: absorb W_k into the query, q' = q @ W_k
            # q_current: (B, 1, H, D_h)
            # W_k.weight: (H*D_h, D_l) -> (H, D_h, D_l)
            # q_prime: (B, 1, H, D_l)
            W_k_reshaped = self.W_k.weight.view(self.num_heads, self.d_head, self.d_latent)
            q_prime = torch.einsum('bqhd,hdl->bqhl', q_current, W_k_reshaped)

            # 4. Attention scores q' @ c.T
            # q_prime: (B, 1, H, D_l)
            # c_full: (B, L, D_l)
            # attn_scores: (B, 1, H, L)
            attn_scores = torch.einsum('bqhl,bkl->bqhk', q_prime, c_full) / math.sqrt(self.d_head)

            # 5. Softmax weights and weighted sum over c ("sum first")
            # attn_weights: (B, 1, H, L)
            # intermediate: (B, 1, H, D_l)
            attn_weights = F.softmax(attn_scores - attn_scores.max(dim=-1, keepdim=True)[0], dim=-1)
            intermediate = torch.einsum('bqhk,bkl->bqhl', attn_weights, c_full)

            # 6. Transform the intermediate result with W_v ("transform afterwards")
            # W_v.weight: (H*D_h, D_l) -> (H, D_h, D_l)
            # head_output: (B, 1, H, D_h)
            W_v_reshaped = self.W_v.weight.view(self.num_heads, self.d_head, self.d_latent)
            head_output = torch.einsum('bqhl,hdl->bqhd', intermediate, W_v_reshaped)

            # 7. Merge heads and project to the output
            combined_heads = head_output.contiguous().view(batch_size, 1, -1)
            output = self.W_o(combined_heads)
            return output, c_full

        # ------------------------------------------------------------------
        # Path 2: parallel processing (training, or the prefill stage of inference)
        # ------------------------------------------------------------------
        c = self.W_c(x)  # (B, L, d_latent)
        c = self.c_norm(c)  # same normalization as the decode path, so the returned cache is consistent

        q = self.W_q(x)
        q = self.q_norm(q)
        q = q.view(batch_size, seq_len, self.num_heads, self.d_head)
        k = self.W_k(c).view(batch_size, seq_len, self.num_heads, self.d_head)
        v = self.W_v(c).view(batch_size, seq_len, self.num_heads, self.d_head)

        # Use the efficient PyTorch 2.0+ kernel; is_causal=True applies the causal mask automatically.
        # (B, L, H, D) -> (B, H, L, D) for SDPA
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        # F.scaled_dot_product_attention handles the softmax and scaling internally.
        # SDPA forbids passing attn_mask together with is_causal=True, so use the explicit mask if one is given.
        head_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=attn_mask is None)
        combined_heads = head_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        output = self.W_o(combined_heads)
        cache_to_return = c if use_cache else None
        return output, cache_to_return
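
To show how the two paths are meant to be used (prefill once, then decode token by token with the returned latent cache), here is a small usage sketch; the hyperparameters and tensors are made up for illustration.

import torch

layer = MLALayerOptimized(d_model=512, num_heads=8, d_latent=64)
layer.eval()

with torch.no_grad():
    # Prefill: process the prompt in parallel and get the latent cache c back.
    prompt = torch.randn(2, 16, 512)            # (batch, prompt_len, d_model)
    out, cache = layer(prompt, use_cache=True)  # cache: (2, 16, d_latent)

    # Decode: feed one token at a time; only the compressed c is cached, never full per-head K/V.
    for _ in range(4):
        next_tok = torch.randn(2, 1, 512)
        out, cache = layer(next_tok, use_cache=True, cache=cache)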

I mean, HMLA does not appear to implement any KV cache at all; it looks more like plain MHA without a KV-cache inference path.
