|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +from dataclasses import dataclass |
| 4 | +from typing import Optional |
| 5 | + |
| 6 | +import torch |
| 7 | + |
| 8 | +from vllm.attention import Attention |
| 9 | +from vllm.config import CacheConfig |
| 10 | +from vllm.model_executor.custom_op import CustomOp |
| 11 | +from vllm.model_executor.layers.quantization import QuantizationConfig |
| 12 | + |
| 13 | + |
| 14 | +@dataclass |
| 15 | +class MLAModules: |
| 16 | + """Modules used in MLA. |
| 17 | + """ |
| 18 | + kv_a_layernorm: torch.nn.Module |
| 19 | + kv_b_proj: torch.nn.Module |
| 20 | + rotary_emb: torch.nn.Module |
| 21 | + o_proj: torch.nn.Module |
| 22 | + fused_qkv_a_proj: Optional[torch.nn.Module] |
| 23 | + kv_a_proj_with_mqa: Optional[torch.nn.Module] |
| 24 | + q_a_layernorm: Optional[torch.nn.Module] |
| 25 | + q_b_proj: Optional[torch.nn.Module] |
| 26 | + q_proj: Optional[torch.nn.Module] |
| 27 | + |
| 28 | + |
| 29 | +@CustomOp.register("multi_head_latent_attention") |
| 30 | +class MultiHeadLatentAttention(CustomOp): |
| 31 | + """MLA layer registered as CustomOp. |
| 32 | + Note that currently MLA ignores the enable/disable mechanism of CustomOp |
| 33 | + because there is only one in-tree implementation in forward_native. |
| 34 | + TODO: implement this with a new PluggableLayer mechanism. |
| 35 | +
|
| 36 | + This class takes positions and hidden_states as input. |
| 37 | + The input tensors can either contain prefill tokens or decode tokens. |
| 38 | + The class does the following: |
| 39 | +
|
| 40 | + 1. MLA Preprocess. |
| 41 | + 2. Perform multi-head attention to prefill tokens and |
| 42 | + multi-query attention to decode tokens separately. |
| 43 | + 3. Return the output tensor. |
| 44 | + """ |
| 45 | + |
| 46 | + def __init__( |
| 47 | + self, |
| 48 | + hidden_size: int, |
| 49 | + num_heads: int, |
| 50 | + scale: float, |
| 51 | + qk_nope_head_dim: int, |
| 52 | + qk_rope_head_dim: int, |
| 53 | + v_head_dim: int, |
| 54 | + q_lora_rank: Optional[int], |
| 55 | + kv_lora_rank: int, |
| 56 | + mla_modules: MLAModules, |
| 57 | + cache_config: Optional[CacheConfig] = None, |
| 58 | + quant_config: Optional[QuantizationConfig] = None, |
| 59 | + prefix: str = "", |
| 60 | + ) -> None: |
| 61 | + super().__init__() |
| 62 | + self.hidden_size = hidden_size |
| 63 | + self.qk_nope_head_dim = qk_nope_head_dim |
| 64 | + self.qk_rope_head_dim = qk_rope_head_dim |
| 65 | + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim |
| 66 | + self.v_head_dim = v_head_dim |
| 67 | + self.q_lora_rank = q_lora_rank |
| 68 | + self.kv_lora_rank = kv_lora_rank |
| 69 | + self.num_heads = num_heads |
| 70 | + self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj |
| 71 | + self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa |
| 72 | + self.q_a_layernorm = mla_modules.q_a_layernorm |
| 73 | + self.q_b_proj = mla_modules.q_b_proj |
| 74 | + self.q_proj = mla_modules.q_proj |
| 75 | + self.kv_a_layernorm = mla_modules.kv_a_layernorm |
| 76 | + self.kv_b_proj = mla_modules.kv_b_proj |
| 77 | + self.rotary_emb = mla_modules.rotary_emb |
| 78 | + self.o_proj = mla_modules.o_proj |
| 79 | + |
| 80 | + # In the MLA backend, kv_cache includes both k_c and |
| 81 | + # pe (i.e. decoupled position embeddings). In particular, |
| 82 | + # the concat_and_cache_mla op requires |
| 83 | + # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) |
| 84 | + # i.e. |
| 85 | + # kv_lora_rank + qk_rope_head_dim == head_size |
| 86 | + self.mla_attn = Attention( |
| 87 | + num_heads=self.num_heads, |
| 88 | + head_size=self.kv_lora_rank + self.qk_rope_head_dim, |
| 89 | + scale=scale, |
| 90 | + num_kv_heads=1, |
| 91 | + cache_config=cache_config, |
| 92 | + quant_config=quant_config, |
| 93 | + prefix=f"{prefix}.attn", |
| 94 | + use_mla=True, |
| 95 | + # MLA Args |
| 96 | + q_lora_rank=self.q_lora_rank, |
| 97 | + kv_lora_rank=self.kv_lora_rank, |
| 98 | + qk_nope_head_dim=self.qk_nope_head_dim, |
| 99 | + qk_rope_head_dim=self.qk_rope_head_dim, |
| 100 | + qk_head_dim=self.qk_head_dim, |
| 101 | + v_head_dim=self.v_head_dim, |
| 102 | + kv_b_proj=self.kv_b_proj, |
| 103 | + ) |
| 104 | + |
| 105 | + self.prefix = prefix |
| 106 | + self.debug_layer_idx = int(self.prefix.split(".")[-2]) |
| 107 | + |
| 108 | + def forward_native( |
| 109 | + self, |
| 110 | + positions: torch.Tensor, |
| 111 | + hidden_states: torch.Tensor, |
| 112 | + ) -> torch.Tensor: |
| 113 | + q_c = None |
| 114 | + kv_lora = None |
| 115 | + |
| 116 | + if self.q_lora_rank is not None: |
| 117 | + assert self.fused_qkv_a_proj is not None, \ |
| 118 | + "fused_qkv_a_proj is required when q_lora_rank is not None" |
| 119 | + assert self.q_a_layernorm is not None, \ |
| 120 | + "q_a_layernorm is required when q_lora_rank is not None" |
| 121 | + assert self.q_b_proj is not None, \ |
| 122 | + "q_b_proj is required when q_lora_rank is not None" |
| 123 | + qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] |
| 124 | + q_c, kv_lora = qkv_lora.split( |
| 125 | + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], |
| 126 | + dim=-1, |
| 127 | + ) |
| 128 | + q_c = self.q_a_layernorm(q_c) |
| 129 | + q = self.q_b_proj(q_c)[0] |
| 130 | + else: |
| 131 | + assert self.kv_a_proj_with_mqa is not None, \ |
| 132 | + "kv_a_proj_with_mqa is required when q_lora_rank is None" |
| 133 | + assert self.q_proj is not None, \ |
| 134 | + "q_proj is required when q_lora_rank is None" |
| 135 | + kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0] |
| 136 | + q = self.q_proj(hidden_states)[0] |
| 137 | + |
| 138 | + kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], |
| 139 | + dim=-1) |
| 140 | + kv_c_normed = self.kv_a_layernorm(kv_c) |
| 141 | + |
| 142 | + q = q.view(-1, self.num_heads, self.qk_head_dim) |
| 143 | + # Add head dim of 1 to k_pe |
| 144 | + k_pe = k_pe.unsqueeze(1) |
| 145 | + |
| 146 | + q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( |
| 147 | + positions, q[..., self.qk_nope_head_dim:], k_pe) |
| 148 | + |
| 149 | + attn_out = self.mla_attn( |
| 150 | + q, |
| 151 | + kv_c_normed, |
| 152 | + k_pe, |
| 153 | + output_shape=(hidden_states.shape[0], |
| 154 | + self.num_heads * self.v_head_dim)) |
| 155 | + return self.o_proj(attn_out)[0] |
| 156 | + |
| 157 | + def forward_cuda(self, *args, **kwargs): |
| 158 | + return self.forward_native(*args, **kwargs) |
0 commit comments