
Commit db8e878

feat/MLA (#2113)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3e8f40c commit db8e878

File tree

4 files changed: +315 -3 lines changed


extensions/thunder/pretrain.py

Lines changed: 2 additions & 2 deletions

@@ -23,7 +23,7 @@
 from litgpt import Tokenizer
 from litgpt.args import EvalArgs, LogArgs, TrainArgs
 from litgpt.data import DataModule, TinyLlama
-from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP
+from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP, MultiheadLatentAttention
 from litgpt.utils import (
     CLI,
     CycleIterator,
@@ -461,7 +461,7 @@ def init_weights(module, std):

     # need a separate loop because `mod.proj` below is a `nn.Linear` too
     for mod in model.modules():
-        if isinstance(mod, (LLaMAMLP, CausalSelfAttention)):
+        if isinstance(mod, (LLaMAMLP, CausalSelfAttention, MultiheadLatentAttention)):
             mod.proj.reset_parameters = partial(init_weights, mod.proj, std=(1 / math.sqrt(n_embd) / n_layer))

     if not isinstance(fabric.strategy, FSDPStrategy):

litgpt/config.py

Lines changed: 18 additions & 0 deletions

@@ -93,6 +93,7 @@ class Config:
     final_logit_softcapping: Optional[float] = None
     norm_1: bool = True
     norm_2: bool = True
+    latent_attention: Optional[dict] = None
     # The base period of the RoPE embeddings for local attention.
     # If not provided, rope_theta will be used for both local and global attention.
     rope_local_base_freq: Optional[float] = None
@@ -133,6 +134,23 @@ def __post_init__(self):
         if self.rope_local_base_freq is not None and self.rope_indices is None:
             self.rope_indices = [1] * self.n_layer

+        if self.latent_attention is not None:
+            self.q_lora_rank = self.latent_attention.get("q_lora_rank")
+            self.kv_lora_rank = self.latent_attention.get("kv_lora_rank")
+            self.qk_rope_head_dim = self.latent_attention.get("qk_rope_head_dim")
+            self.qk_nope_head_dim = self.latent_attention.get("qk_nope_head_dim")
+            self.v_head_dim = self.latent_attention.get("v_head_dim")
+            assert (
+                self.q_lora_rank
+                and self.kv_lora_rank
+                and self.qk_rope_head_dim
+                and self.qk_nope_head_dim
+                and self.v_head_dim
+            ) is not None
+            assert self.n_head == self.n_query_groups, "Latent attention does not support MQA/GQA"
+            self.qk_head_dim = self.qk_rope_head_dim + self.qk_nope_head_dim
+            self.rope_n_elem = self.qk_rope_head_dim
+
     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
         if name not in name_to_config:
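For reference, `latent_attention` is a plain dict whose five keys are unpacked in `__post_init__` and turned into the derived `qk_head_dim` and `rope_n_elem` fields. A minimal sketch of opting a config into MLA, using the toy dimensions from the tests added later in this commit (illustrative values, not a shipped preset):

from litgpt import Config

# Hypothetical toy configuration; the dimensions mirror the tests in this commit.
config = Config(
    n_embd=64,
    n_head=4,
    n_query_groups=4,  # MQA/GQA is rejected: n_head must equal n_query_groups
    head_size=16,
    latent_attention={
        "q_lora_rank": 32,
        "kv_lora_rank": 16,
        "qk_rope_head_dim": 8,
        "qk_nope_head_dim": 8,
        "v_head_dim": 16,
    },
)

# Derived in __post_init__ from the dict above:
print(config.qk_head_dim)  # 16 (= qk_rope_head_dim + qk_nope_head_dim)
print(config.rope_n_elem)  # 8  (RoPE is applied only to the rope portion of each head)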

litgpt/model.py

Lines changed: 145 additions & 1 deletion

@@ -272,7 +272,11 @@ def __init__(
             )

         self.norm_1 = nn.Identity() if not config.norm_1 else config.norm_class(config.n_embd, eps=config.norm_eps)
-        self.attn = CausalSelfAttention(config, block_idx)
+        self.attn = (
+            CausalSelfAttention(config, block_idx)
+            if not config.latent_attention
+            else MultiheadLatentAttention(config, block_idx)
+        )
         self.post_attention_norm = (
             config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity()
         )
@@ -549,6 +553,146 @@ def _load_from_state_dict(self, state_dict: dict, prefix: str, *args: Any, **kwa
         super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


+class MultiheadLatentAttention(nn.Module):
+    def __init__(self, config: Config, block_idx: int) -> None:
+        super().__init__()
+
+        self.q_a_proj = nn.Linear(config.n_embd, config.q_lora_rank, bias=config.attn_bias)
+        self.q_a_norm = RMSNorm(config.q_lora_rank, eps=config.norm_eps)
+        self.q_b_proj = nn.Linear(config.q_lora_rank, config.n_head * config.qk_head_dim, bias=config.bias)
+
+        self.kv_a_proj_with_mqa = nn.Linear(
+            config.n_embd, config.kv_lora_rank + config.qk_rope_head_dim, bias=config.attn_bias
+        )
+        self.kv_a_norm = RMSNorm(config.kv_lora_rank, eps=config.norm_eps)
+        self.kv_b_proj = nn.Linear(
+            config.kv_lora_rank,
+            config.n_query_groups * (config.qk_nope_head_dim + config.v_head_dim),
+            bias=config.bias,
+        )
+
+        # output projection
+        self.proj = nn.Linear(config.n_head * config.v_head_dim, config.n_embd, bias=config.bias)
+        # disabled by default
+        self.kv_cache: Optional[KVCache] = None
+
+        self.config = config
+        self.block_idx = block_idx
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        input_pos: Optional[torch.Tensor] = None,
+        input_pos_maxp1: Optional[int] = None,
+    ) -> torch.Tensor:
+        # Notation:
+        # - B          | batch size
+        # - T          | time-step (sequence length)
+        # - C          | model's embeddings size (n_embd)
+        # - C*         | attention's embeddings size
+        # - hs         | head size
+        # - nh_(q,k,v) | number of heads for query, key and value
+        # - n_query_groups = nh_k = nh_v | number of query groups sharing key and value heads
+        #   alternative notation: num_kv_groups = n_query_groups
+        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+
+        q = self.q_b_proj(self.q_a_norm(self.q_a_proj(x)))  # (B, T, n_head * qk_head_dim)
+        q = q.view(B, T, -1, self.config.qk_head_dim)  # (B, T, n_head, qk_head_dim)
+        q = q.transpose(1, 2)  # (B, n_head, T, qk_head_dim)
+        q_pass, q_rot = torch.split(q, [self.config.qk_nope_head_dim, self.config.qk_rope_head_dim], dim=-1)
+
+        compressed_kv = self.kv_a_proj_with_mqa(x)  # (B, T, kv_lora_rank + qk_rope_head_dim)
+        k_pass, k_rot = torch.split(compressed_kv, [self.config.kv_lora_rank, self.config.qk_rope_head_dim], dim=-1)
+
+        k_pass = self.kv_b_proj(self.kv_a_norm(k_pass))
+        k_pass = k_pass.view(B, T, self.config.n_query_groups, -1)
+        k_pass = k_pass.transpose(1, 2)
+
+        k_pass, v = torch.split(k_pass, [self.config.qk_nope_head_dim, self.config.v_head_dim], dim=-1)
+        k_rot = k_rot.view(B, 1, T, self.config.qk_rope_head_dim)  # (B, 1, T, qk_rope_head_dim)
+
+        # Unlike standard positional embeddings, rotary embeddings must be applied at every layer.
+        q_roped = apply_rope(q_rot, cos, sin)
+        k_roped = apply_rope(k_rot, cos, sin)
+        k_roped = k_roped.expand(*k_pass.shape[:-1], -1)  # (B, n_head, T, qk_rope_head_dim)
+
+        q = torch.cat((q_pass, q_roped), dim=-1)
+        k = torch.cat((k_pass, k_roped), dim=-1)
+
+        # Apply kv-cache during inference.
+        if input_pos is not None:
+            if not isinstance(self.kv_cache, KVCache):
+                raise TypeError("You need to call `gpt.set_kv_cache()`")
+            k, v = self.kv_cache(input_pos, k, v)
+            if input_pos_maxp1 is not None:
+                # Subselect along sequence dimension
+                k = k[..., :input_pos_maxp1, :]
+                v = v[..., :input_pos_maxp1, :]
+            # k, v: (B, nh_k, input_pos_maxp1, hs)
+            # If input_pos_maxp1 is None -> max_seq_length
+
+        # Grouped queries: balance the number of heads across all three matrices.
+        # NOTE: flash attention requires it in training mode.
+        # Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting.
+        if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1):
+            q_per_kv = self.config.n_head // self.config.n_query_groups
+            k = k.repeat_interleave(q_per_kv, dim=1)  # (B, nh_q, T, hs)
+            v = v.repeat_interleave(q_per_kv, dim=1)  # (B, nh_q, T, hs)
+
+        # Efficient attention using Flash Attention CUDA kernels.
+        # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled.
+        # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
+        y = self.scaled_dot_product_attention(q, k, v, mask)
+
+        # Re-assemble all head outputs side by side.
+        y = y.reshape(B, T, self.config.n_head * self.config.v_head_dim)
+
+        # Output projection.
+        return self.proj(y)  # (B, T, C)
+
+    def scaled_dot_product_attention(
+        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.qk_head_dim)
+
+        # with softcapping we cannot use SDPA
+        if self.config.attention_logit_softcapping is not None:
+            scores = q @ k.mT * scale
+            scores = do_softcapping(scores, self.config.attention_logit_softcapping)
+            if mask is None:
+                mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1)
+                mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min)
+            scores = scores + mask
+            scores = F.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype)
+            y = scores @ v
+        else:
+            y = F.scaled_dot_product_attention(
+                q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
+            )
+        return y.transpose(1, 2)
+
+    def build_kv_cache(
+        self,
+        batch_size: int,
+        max_seq_length: int,
+        rope_cache_length: Optional[int] = None,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> "KVCache":
+        v_shape = (batch_size, self.config.n_head, max_seq_length, self.config.v_head_dim)
+        k_shape = (batch_size, self.config.n_head, max_seq_length, self.config.qk_head_dim)
+
+        if rope_cache_length is not None:
+            print("Warning: `rope_cache_length` has no effect on MultiheadLatentAttention!")
+        if self.config.rotary_percentage != 1.0:
+            print("Warning: `rotary_percentage` has no effect on MultiheadLatentAttention!")
+
+        return KVCache(k_shape, v_shape, device=device, dtype=dtype)
+
+
 class GptNeoxMLP(nn.Module):
     def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:
         super().__init__()
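The query and key/value paths above factor through low-rank "a"/"b" projections (`q_lora_rank`, `kv_lora_rank`) instead of a single fused QKV matrix. The following is illustrative arithmetic only (not part of the commit), counting projection weights for the toy dimensions used in the tests below against a hypothetical fused QKV plus output projection of the same head geometry:

# Illustrative weight-count comparison; dimensions mirror the toy test config below.
n_embd, n_head = 64, 4
q_lora_rank, kv_lora_rank = 32, 16
qk_rope, qk_nope, v_dim = 8, 8, 16
qk_head_dim = qk_rope + qk_nope  # 16

mla_weights = (
    n_embd * q_lora_rank                         # q_a_proj:           64 * 32 = 2048
    + q_lora_rank * n_head * qk_head_dim         # q_b_proj:           32 * 64 = 2048
    + n_embd * (kv_lora_rank + qk_rope)          # kv_a_proj_with_mqa: 64 * 24 = 1536
    + kv_lora_rank * n_head * (qk_nope + v_dim)  # kv_b_proj:          16 * 96 = 1536
    + n_head * v_dim * n_embd                    # proj:               64 * 64 = 4096
)

# Hypothetical dense baseline: fused QKV (head_size=16) plus output projection.
dense_weights = n_embd * (3 * n_head * 16) + (n_head * 16) * n_embd

print(mla_weights, dense_weights)  # 11264 vs 16384 for these toy sizes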
Lines changed: 150 additions & 0 deletions

@@ -0,0 +1,150 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import pytest
import torch
from transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM

from litgpt import Config
from litgpt.model import MultiheadLatentAttention


@torch.inference_mode()
def test_multihead_latent_attention_kv_cache():
    """Test KV cache functionality"""
    config = Config(
        block_size=32,
        n_embd=64,
        n_head=4,
        n_query_groups=4,
        head_size=16,
        latent_attention={
            "q_lora_rank": 32,
            "kv_lora_rank": 16,
            "qk_rope_head_dim": 8,
            "qk_nope_head_dim": 8,
            "v_head_dim": 16,
        },
    )

    mla = MultiheadLatentAttention(config, block_idx=0)

    # Build KV cache
    kv_cache = mla.build_kv_cache(batch_size=2, max_seq_length=32, device=torch.device("cpu"), dtype=torch.float32)

    # Check cache shapes
    assert kv_cache.k.shape == (2, config.n_head, 32, config.qk_head_dim)
    assert kv_cache.v.shape == (2, config.n_head, 32, config.v_head_dim)


@torch.inference_mode()
def test_multihead_latent_attention_with_mask():
    """Test attention with causal mask"""
    config = Config(
        n_embd=64,
        n_head=4,
        n_query_groups=4,
        head_size=16,
        latent_attention={
            "q_lora_rank": 32,
            "kv_lora_rank": 16,
            "qk_rope_head_dim": 8,
            "qk_nope_head_dim": 8,
            "v_head_dim": 16,
        },
    )

    mla = MultiheadLatentAttention(config, block_idx=0)

    batch_size, seq_len = 1, 8
    x = torch.randn(batch_size, seq_len, config.n_embd)
    cos = torch.randn(1, seq_len, config.qk_rope_head_dim)
    sin = torch.randn(1, seq_len, config.qk_rope_head_dim)

    # Create causal mask
    mask = torch.ones(seq_len, seq_len, dtype=x.dtype).triu(diagonal=1)
    mask.masked_fill_(mask.bool(), float("-inf"))
    mask = mask.view(1, 1, seq_len, seq_len)

    # Forward pass with mask
    output = mla(x, cos, sin, mask=mask)

    assert output.shape == (batch_size, seq_len, config.n_embd)


@torch.inference_mode()
@pytest.mark.parametrize("batch_size", (1, 2))
@pytest.mark.parametrize("seq_len", (8, 16))
@pytest.mark.parametrize("device", [torch.device("cpu")])
def test_multihead_latent_attention_litgpt_vs_hf(batch_size, seq_len, device):
    """Test MLA litgpt vs hf"""
    config_litgpt = Config(
        n_embd=64,
        n_head=4,
        n_query_groups=4,
        head_size=16,
        norm_eps=1e-6,
        bias=False,
        latent_attention={
            "q_lora_rank": 32,
            "kv_lora_rank": 16,
            "qk_rope_head_dim": 8,
            "qk_nope_head_dim": 8,
            "v_head_dim": 16,
        },
    )

    config_hf = DeepseekV3Config(
        padded_vocab_size=10000,
        num_hidden_layers=1,
        vocab_size=10000,
        hidden_size=64,
        num_attention_heads=4,
        num_key_value_heads=4,
        q_lora_rank=32,
        kv_lora_rank=16,
        qk_rope_head_dim=8,
        qk_nope_head_dim=8,
        v_head_dim=16,
        rope_interleave=False,
    )

    mla_litgpt = MultiheadLatentAttention(config_litgpt, block_idx=0).to(device)
    model_hf = DeepseekV3ForCausalLM(config_hf).to(device)
    mla_hf = model_hf.model.layers[0].self_attn

    mla_litgpt.eval()
    mla_hf.eval()

    sync_weights(mla_litgpt, mla_hf)

    hidden_states = torch.randn(batch_size, seq_len, config_litgpt.n_embd, device=device)

    # Prepare RoPE sin/cos tables
    rope_head_dim = config_litgpt.latent_attention["qk_rope_head_dim"]
    cos = torch.randn(batch_size, seq_len, rope_head_dim, device=device, dtype=hidden_states.dtype)
    sin = torch.randn(batch_size, seq_len, rope_head_dim, device=device, dtype=hidden_states.dtype)

    causal_mask = torch.triu(
        torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=hidden_states.dtype), diagonal=1
    )
    attention_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1)

    # Run forward passes
    output_litgpt = mla_litgpt(hidden_states, cos, sin)
    output_hf = mla_hf(hidden_states, position_embeddings=(cos, sin), attention_mask=attention_mask)[0]

    assert torch.allclose(output_litgpt, output_hf, atol=1e-5)


def sync_weights(litgpt_model, hf_model):
    """Copies weights from the litgpt module to the HF module."""
    print("Synchronizing weights...")
    with torch.no_grad():
        hf_model.q_a_proj.weight.copy_(litgpt_model.q_a_proj.weight)
        hf_model.q_a_layernorm.weight.copy_(litgpt_model.q_a_norm.weight)
        hf_model.q_b_proj.weight.copy_(litgpt_model.q_b_proj.weight)
        hf_model.kv_a_proj_with_mqa.weight.copy_(litgpt_model.kv_a_proj_with_mqa.weight)
        hf_model.kv_a_layernorm.weight.copy_(litgpt_model.kv_a_norm.weight)
        hf_model.kv_b_proj.weight.copy_(litgpt_model.kv_b_proj.weight)
        hf_model.o_proj.weight.copy_(litgpt_model.proj.weight)
    print("Synchronization complete.")
