
Commit d993cf1

Authored by ysjprojects (with pre-commit-ci[bot], KaelanDt, and Borda)
[FEAT] Add Grouped Topk Routing to LLaMAMoE (Based on DeepseekV3MoE) (#2134)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: KaelanDt <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 062fff2 commit d993cf1

File tree

3 files changed: +200 −6 lines changed


litgpt/config.py

Lines changed: 14 additions & 0 deletions

@@ -86,7 +86,14 @@ class Config:
     mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
     gelu_approximate: str = "none"
     n_expert: int = 0
+    n_shared_expert: Optional[int] = None
+    n_expert_groups: Optional[int] = None
+    n_topk_groups: Optional[int] = None
+    n_topk_scores_per_group: Optional[int] = None
     n_expert_per_token: int = 0
+    first_k_dense_replace: Optional[int] = None
+    routed_scaling_factor: float = 1.0
+    norm_topk_prob: bool = False
     # GPT before/after blocks
     scale_embeddings: bool = False
     lm_head_bias: bool = False

@@ -150,6 +157,13 @@ def __post_init__(self):
             assert self.n_head == self.n_query_groups, "Latent attention does not support MQA/GQA"
             self.qk_head_dim = self.qk_rope_head_dim + self.qk_nope_head_dim
             self.rope_n_elem = self.qk_rope_head_dim
+        if self.first_k_dense_replace is not None:
+            assert self.mlp_class_name == "LLaMAMoE"
+        if self.n_expert_groups is not None:
+            assert self.n_expert % self.n_expert_groups == 0 and self.n_expert_groups > 1
+            assert self.n_topk_groups is not None
+            experts_per_group = self.n_expert // self.n_expert_groups
+            assert self.n_topk_scores_per_group is not None and self.n_topk_scores_per_group <= experts_per_group
 
     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
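
For orientation (not part of the commit), here is a minimal sketch of how the new routing fields combine. The routing values are borrowed from tests/test_deepseek_moe.py in this commit; the remaining fields are illustrative and lean on Config's existing defaults, plus the first_k_dense_replace handling added in litgpt/model.py below.

    from litgpt import Config
    from litgpt.model import GPT, LLaMAMLP, LLaMAMoE

    # Illustrative 2-layer MoE config; routing values mirror the new test.
    config = Config(
        n_layer=2,
        n_embd=64,
        n_head=4,
        n_query_groups=4,
        mlp_class_name="LLaMAMoE",
        intermediate_size=128,        # dense MLP width for the first (dense) block
        moe_intermediate_size=20,     # per-expert MLP width
        n_expert=16,                  # routed experts
        n_expert_per_token=2,         # experts activated per token
        n_expert_groups=4,            # 16 experts split into 4 groups of 4
        n_topk_groups=2,              # route only inside the best 2 groups
        n_topk_scores_per_group=2,    # group score = sum of its top-2 expert scores
        n_shared_expert=1,            # one always-active shared expert
        first_k_dense_replace=1,      # block 0 stays a dense LLaMAMLP
        routed_scaling_factor=2.5,
        norm_topk_prob=True,
    )

    model = GPT(config)
    assert isinstance(model.transformer.h[0].mlp, LLaMAMLP)  # dense block below first_k_dense_replace
    assert isinstance(model.transformer.h[1].mlp, LLaMAMoE)  # grouped-routing MoE block
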

litgpt/model.py

Lines changed: 72 additions & 6 deletions

@@ -73,7 +73,9 @@ def reset_parameters(self) -> None:
 
     def _init_weights(self, module: nn.Module) -> None:
         """Meant to be used with `gpt.apply(gpt._init_weights)`."""
-        if isinstance(module, nn.Linear):
+        if isinstance(module, GroupedTopkRouter):
+            torch.nn.init.normal_(module.weight.data, mean=0.0, std=0.02)
+        elif isinstance(module, nn.Linear):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
             if module.bias is not None:
                 torch.nn.init.zeros_(module.bias)

@@ -286,6 +288,8 @@ def __init__(
             else (None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps))
         )
         self.mlp = config.mlp_class(config)
+        if config.first_k_dense_replace is not None and block_idx < config.first_k_dense_replace:
+            self.mlp = LLaMAMLP(config)
         self.post_mlp_norm = (
             config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_mlp_norm else nn.Identity()
         )

@@ -734,10 +738,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class LLaMAMoE(nn.Module):
     def __init__(self, config: Config) -> None:
         super().__init__()
-        self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False)
+        self.gate = (
+            nn.Linear(config.n_embd, config.n_expert, bias=False)
+            if not config.n_expert_groups
+            else GroupedTopkRouter(config)
+        )
         self.experts = nn.ModuleList(
             LLaMAMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert)
         )
+        if config.n_shared_expert:
+            self.shared_experts = LLaMAMLP(
+                config, intermediate_size=config.moe_intermediate_size * config.n_shared_expert
+            )
         self.config = config
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:

@@ -746,17 +758,71 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         See also figure 1 in https://arxiv.org/abs/2211.15841
         """
         B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+        residual_x = x.clone()
         x = x.view(-1, C)  # (B*T, C)
-        router = self.gate(x)  # (B*T, n_expert)
-        probs, indices = torch.topk(router, self.config.n_expert_per_token)  # (B*T, n_expert_per_token)
-        probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
+        if not self.config.n_expert_groups:
+            router = self.gate(x)  # (B*T, n_expert)
+            probs, indices = torch.topk(router, self.config.n_expert_per_token)  # (B*T, n_expert_per_token)
+            probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
+        else:
+            probs, indices = self.gate(x)
+            if self.config.routed_scaling_factor != 1.0:
+                probs = probs * self.config.routed_scaling_factor
         masks = indices.unsqueeze(-1) == torch.arange(self.config.n_expert, device=x.device)
         masks = masks.permute(2, 0, 1)  # (n_expert, B*T, n_expert_per_token)
         y = torch.zeros_like(x)  # (B*T, C)
         for mask, expert in zip(masks, self.experts):
             token_idx, expert_idx = torch.where(mask)
             y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
-        return y.view(B, T, C)
+
+        y = y.view(B, T, C)
+        if self.config.n_shared_expert:
+            y = y + self.shared_experts(residual_x)
+        return y
+
+
+class GroupedTopkRouter(nn.Module):
+    """
+    Derived from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py.
+    DeepseekV3TopkRouter class.
+    """
+
+    def __init__(self, config: Config) -> None:
+        super().__init__()
+        self.config = config
+        self.weight = nn.Parameter(torch.empty(config.n_expert, config.n_embd))
+        self.register_buffer("e_score_correction_bias", torch.zeros(config.n_expert))
+
+    @torch.no_grad()
+    def get_topk_indices(self, scores: torch.Tensor) -> torch.Tensor:
+        scores_for_choice = scores.view(-1, self.config.n_expert) + self.e_score_correction_bias.unsqueeze(0)
+        group_scores = (
+            scores_for_choice.view(-1, self.config.n_expert_groups, self.config.n_expert // self.config.n_expert_groups)
+            .topk(self.config.n_topk_scores_per_group, dim=-1)[0]  # Top k scores for each group
+            .sum(dim=-1)
+        )
+
+        group_idx = torch.topk(group_scores, k=self.config.n_topk_groups, dim=-1, sorted=False)[1]
+        group_mask = torch.zeros_like(group_scores)
+        group_mask.scatter_(1, group_idx, 1)
+        score_mask = (
+            group_mask.unsqueeze(-1)
+            .expand(-1, self.config.n_expert_groups, self.config.n_expert // self.config.n_expert_groups)
+            .reshape(-1, self.config.n_expert)
+        )
+        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
+        topk_indices = torch.topk(scores_for_choice, k=self.config.n_expert_per_token, dim=-1, sorted=False)[1]
+        return topk_indices
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        router_logits = F.linear(x.type(torch.float32), self.weight.type(torch.float32))
+        scores = router_logits.sigmoid()
+        topk_indices = self.get_topk_indices(scores)
+        topk_weights = scores.gather(1, topk_indices)
+        if self.config.norm_topk_prob:
+            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+            topk_weights /= denominator
+        return topk_weights, topk_indices
 
 
 def build_rope_cache(
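
To make the grouped selection concrete, here is a small self-contained walk-through (not part of the commit) of the logic in GroupedTopkRouter.get_topk_indices, using 8 experts in 4 groups; the numbers are illustrative and the zero-initialized e_score_correction_bias is omitted.

    import torch

    n_expert, n_expert_groups = 8, 4          # 2 experts per group
    n_topk_scores_per_group, n_topk_groups = 2, 2
    n_expert_per_token = 2

    # One token's sigmoid affinity per expert (made-up values).
    scores = torch.tensor([[0.9, 0.1, 0.2, 0.3, 0.8, 0.7, 0.05, 0.6]])

    # Group score = sum of each group's top-2 expert scores.
    group_scores = (
        scores.view(-1, n_expert_groups, n_expert // n_expert_groups)
        .topk(n_topk_scores_per_group, dim=-1)[0]
        .sum(dim=-1)
    )  # -> [[1.00, 0.50, 1.50, 0.65]]

    # Keep the best 2 groups, zero out every expert in the other groups...
    group_idx = torch.topk(group_scores, k=n_topk_groups, dim=-1)[1]
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(-1, n_expert_groups, n_expert // n_expert_groups)
        .reshape(-1, n_expert)
    )
    masked_scores = scores.masked_fill(~score_mask.bool(), 0.0)

    # ...then pick the per-token top-2 experts from the surviving groups.
    topk_indices = torch.topk(masked_scores, k=n_expert_per_token, dim=-1)[1]
    print(topk_indices)  # tensor([[0, 4]]): experts from groups 0 and 2 only

The forward pass then gathers the sigmoid scores at these indices, optionally renormalizes them (norm_topk_prob), and LLaMAMoE.forward scales them by routed_scaling_factor before mixing the expert outputs.
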

tests/test_deepseek_moe.py

Lines changed: 114 additions & 0 deletions

@@ -0,0 +1,114 @@
+# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+
+import pytest
+import torch
+from transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM
+
+from litgpt import Config
+from litgpt.model import GPT, LLaMAMLP
+
+
+@torch.inference_mode()
+@pytest.mark.parametrize("batch_size", (1, 2))
+@pytest.mark.parametrize("seq_len", (8, 16))
+@pytest.mark.parametrize("device", [torch.device("cpu")])
+def test_deepseek_moe_litgpt_vs_hf(batch_size, seq_len, device):
+    """Test MOE litgpt vs hf"""
+    config_litgpt = Config(
+        padded_vocab_size=10000,
+        n_layer=2,
+        vocab_size=10000,
+        n_embd=64,
+        n_head=4,
+        n_query_groups=4,
+        head_size=16,
+        norm_eps=1e-6,
+        bias=False,
+        latent_attention={
+            "q_lora_rank": 32,
+            "kv_lora_rank": 16,
+            "qk_rope_head_dim": 8,
+            "qk_nope_head_dim": 8,
+            "v_head_dim": 16,
+        },
+        n_expert=16,
+        n_shared_expert=1,
+        n_expert_per_token=2,
+        n_expert_groups=4,
+        n_topk_groups=2,
+        n_topk_scores_per_group=2,  # Note: Deepseek hardcodes this to `2`
+        first_k_dense_replace=1,
+        routed_scaling_factor=2.5,
+        norm_topk_prob=True,
+        moe_intermediate_size=20,
+        mlp_class_name="LLaMAMoE",
+    )
+
+    config_hf = DeepseekV3Config(
+        padded_vocab_size=10000,
+        num_hidden_layers=2,
+        vocab_size=10000,
+        hidden_size=64,
+        num_attention_heads=4,
+        num_key_value_heads=4,
+        q_lora_rank=32,
+        kv_lora_rank=16,
+        qk_rope_head_dim=8,
+        qk_nope_head_dim=8,
+        v_head_dim=16,
+        rope_interleave=False,
+        first_k_dense_replace=1,
+        routed_scaling_factor=2.5,
+        norm_topk_prob=True,
+        n_routed_experts=config_litgpt.n_expert,
+        n_shared_experts=config_litgpt.n_shared_expert,
+        num_experts_per_tok=config_litgpt.n_expert_per_token,
+        n_group=config_litgpt.n_expert_groups,
+        topk_group=config_litgpt.n_topk_groups,
+        moe_intermediate_size=config_litgpt.moe_intermediate_size,
+    )
+
+    model_litgpt = GPT(config_litgpt).to(device)
+    model_litgpt.apply(model_litgpt._init_weights)
+
+    mlp_litgpt = model_litgpt.transformer.h[0].mlp
+    assert isinstance(mlp_litgpt, LLaMAMLP)  # Test first_k_dense_replace (k=1)
+
+    moe_litgpt = model_litgpt.transformer.h[1].mlp
+    model_hf = DeepseekV3ForCausalLM(config_hf).to(device)
+    moe_hf = model_hf.model.layers[1].mlp
+
+    moe_litgpt.eval()
+    moe_hf.eval()
+
+    sync_weights(moe_litgpt, moe_hf)
+
+    hidden_states = torch.randn(batch_size, seq_len, config_litgpt.n_embd, device=device)
+
+    output_litgpt = moe_litgpt(hidden_states)
+    output_hf = moe_hf(hidden_states)
+
+    assert torch.allclose(output_litgpt, output_hf, atol=1e-5)
+
+
+def sync_weights(litgpt_model, hf_model):
+    print("Synchronizing MoE weights...")
+
+    with torch.no_grad():
+        if hasattr(litgpt_model, "gate"):
+            if hasattr(litgpt_model.gate, "weight"):
+                hf_model.gate.weight.copy_(litgpt_model.gate.weight)
+            if hasattr(litgpt_model.gate, "e_score_correction_bias"):
+                hf_model.gate.e_score_correction_bias.copy_(litgpt_model.gate.e_score_correction_bias)
+
+        for i, (litgpt_expert, hf_expert) in enumerate(zip(litgpt_model.experts, hf_model.experts)):
+            hf_expert.gate_proj.weight.copy_(litgpt_expert.fc_1.weight)
+            hf_expert.up_proj.weight.copy_(litgpt_expert.fc_2.weight)
+            hf_expert.down_proj.weight.copy_(litgpt_expert.proj.weight)
+
+        if hasattr(litgpt_model, "shared_experts") and hasattr(hf_model, "shared_experts"):
+            hf_model.shared_experts.gate_proj.weight.copy_(litgpt_model.shared_experts.fc_1.weight)
+            hf_model.shared_experts.up_proj.weight.copy_(litgpt_model.shared_experts.fc_2.weight)
+            hf_model.shared_experts.down_proj.weight.copy_(litgpt_model.shared_experts.proj.weight)
+
+    print("MoE weight synchronization complete.")
