14 changes: 14 additions & 0 deletions litgpt/config.py
@@ -86,7 +86,14 @@ class Config:
mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
gelu_approximate: str = "none"
n_expert: int = 0
n_shared_expert: Optional[int] = None
n_expert_groups: Optional[int] = None
n_topk_groups: Optional[int] = None
n_topk_scores_per_group: Optional[int] = None
n_expert_per_token: int = 0
first_k_dense_replace: Optional[int] = None
routed_scaling_factor: float = 1.0
norm_topk_prob: bool = False
# GPT before/after blocks
scale_embeddings: bool = False
lm_head_bias: bool = False
@@ -150,6 +157,13 @@ def __post_init__(self):
assert self.n_head == self.n_query_groups, "Latent attention does not support MQA/GQA"
self.qk_head_dim = self.qk_rope_head_dim + self.qk_nope_head_dim
self.rope_n_elem = self.qk_rope_head_dim
if self.first_k_dense_replace is not None:
assert self.mlp_class_name == "LLaMAMoE"
if self.n_expert_groups is not None:
assert self.n_expert % self.n_expert_groups == 0 and self.n_expert_groups > 1
assert self.n_topk_groups is not None
experts_per_group = self.n_expert // self.n_expert_groups
assert self.n_topk_scores_per_group is not None and self.n_topk_scores_per_group <= experts_per_group

@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
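For context, a minimal sketch of how the new routing fields fit together, mirroring the values used in the test below; the numbers are illustrative and the remaining `Config` defaults are assumed to be acceptable:

```python
from litgpt import Config

# Illustrative DeepSeek-V3-style MoE settings (not taken from any released checkpoint).
config = Config(
    n_embd=64,
    n_head=4,
    n_query_groups=4,
    mlp_class_name="LLaMAMoE",
    moe_intermediate_size=20,
    n_expert=16,                # total routed experts
    n_expert_groups=4,          # 16 % 4 == 0 and > 1, so the group assertion passes
    n_topk_groups=2,            # groups kept per token
    n_topk_scores_per_group=2,  # must be <= 16 // 4 experts per group
    n_expert_per_token=2,       # routed experts selected per token
    n_shared_expert=1,          # one always-active shared expert
    first_k_dense_replace=1,    # layer 0 keeps a dense LLaMAMLP
    routed_scaling_factor=2.5,
    norm_topk_prob=True,
)
```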
78 changes: 72 additions & 6 deletions litgpt/model.py
@@ -73,7 +73,9 @@ def reset_parameters(self) -> None:

def _init_weights(self, module: nn.Module) -> None:
"""Meant to be used with `gpt.apply(gpt._init_weights)`."""
if isinstance(module, nn.Linear):
if isinstance(module, GroupedTopkRouter):
torch.nn.init.normal_(module.weight.data, mean=0.0, std=0.02)
elif isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
@@ -286,6 +288,8 @@ def __init__(
else (None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps))
)
self.mlp = config.mlp_class(config)
if config.first_k_dense_replace is not None and block_idx < config.first_k_dense_replace:
self.mlp = LLaMAMLP(config)
self.post_mlp_norm = (
config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_mlp_norm else nn.Identity()
)
@@ -734,10 +738,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class LLaMAMoE(nn.Module):
def __init__(self, config: Config) -> None:
super().__init__()
self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False)
self.gate = (
nn.Linear(config.n_embd, config.n_expert, bias=False)
if not config.n_expert_groups
else GroupedTopkRouter(config)
)
self.experts = nn.ModuleList(
LLaMAMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert)
)
if config.n_shared_expert:
self.shared_experts = LLaMAMLP(
config, intermediate_size=config.moe_intermediate_size * config.n_shared_expert
)
self.config = config

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -746,17 +758,71 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
See also figure 1 in https://arxiv.org/abs/2211.15841
"""
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
residual_x = x.clone()
x = x.view(-1, C) # (B*T, C)
router = self.gate(x) # (B*T, n_expert)
probs, indices = torch.topk(router, self.config.n_expert_per_token) # (B*T, n_expert_per_token)
probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
if not self.config.n_expert_groups:
router = self.gate(x) # (B*T, n_expert)
probs, indices = torch.topk(router, self.config.n_expert_per_token) # (B*T, n_expert_per_token)
probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
else:
probs, indices = self.gate(x)
if self.config.routed_scaling_factor != 1.0:
probs = probs * self.config.routed_scaling_factor
masks = indices.unsqueeze(-1) == torch.arange(self.config.n_expert, device=x.device)
masks = masks.permute(2, 0, 1) # (n_expert, B*T, n_expert_per_token)
y = torch.zeros_like(x) # (B*T, C)
for mask, expert in zip(masks, self.experts):
token_idx, expert_idx = torch.where(mask)
y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
return y.view(B, T, C)

y = y.view(B, T, C)
if self.config.n_shared_expert:
y = y + self.shared_experts(residual_x)
return y


class GroupedTopkRouter(nn.Module):
"""
Derived from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py.
DeepseekV3TopkRouter class.
"""

def __init__(self, config: Config) -> None:
super().__init__()
self.config = config
self.weight = nn.Parameter(torch.empty(config.n_expert, config.n_embd))
self.register_buffer("e_score_correction_bias", torch.zeros(config.n_expert))

@torch.no_grad()
def get_topk_indices(self, scores: torch.Tensor) -> torch.Tensor:
scores_for_choice = scores.view(-1, self.config.n_expert) + self.e_score_correction_bias.unsqueeze(0)
group_scores = (
scores_for_choice.view(-1, self.config.n_expert_groups, self.config.n_expert // self.config.n_expert_groups)
.topk(self.config.n_topk_scores_per_group, dim=-1)[0] # Top k scores for each group
.sum(dim=-1)
)

group_idx = torch.topk(group_scores, k=self.config.n_topk_groups, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1)
.expand(-1, self.config.n_expert_groups, self.config.n_expert // self.config.n_expert_groups)
.reshape(-1, self.config.n_expert)
)
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
topk_indices = torch.topk(scores_for_choice, k=self.config.n_expert_per_token, dim=-1, sorted=False)[1]
return topk_indices

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
router_logits = F.linear(x.type(torch.float32), self.weight.type(torch.float32))
scores = router_logits.sigmoid()
topk_indices = self.get_topk_indices(scores)
topk_weights = scores.gather(1, topk_indices)
if self.config.norm_topk_prob:
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
topk_weights /= denominator
return topk_weights, topk_indices


def build_rope_cache(
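As a plain-tensor walkthrough of the grouped routing implemented above (a sketch mirroring `GroupedTopkRouter.get_topk_indices` plus the weight normalization and scaling; shapes and values are illustrative):

```python
import torch

n_expert, n_groups = 8, 4                       # 2 experts per group
n_topk_groups, n_topk_scores_per_group = 2, 2
n_expert_per_token = 2

scores = torch.rand(3, n_expert)                # sigmoid(router logits) for 3 tokens
bias = torch.zeros(n_expert)                    # e_score_correction_bias

scores_for_choice = scores + bias
# Score each group by the sum of its best `n_topk_scores_per_group` expert scores.
group_scores = (
    scores_for_choice.view(-1, n_groups, n_expert // n_groups)
    .topk(n_topk_scores_per_group, dim=-1)[0]
    .sum(dim=-1)
)                                               # (3, n_groups)
# Keep the top groups and zero out experts belonging to discarded groups.
group_idx = group_scores.topk(n_topk_groups, dim=-1)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = group_mask.repeat_interleave(n_expert // n_groups, dim=-1)
masked_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
# Final per-token expert indices; weights are gathered from the unbiased scores.
topk_indices = masked_scores.topk(n_expert_per_token, dim=-1)[1]
topk_weights = scores.gather(1, topk_indices)
topk_weights = topk_weights / (topk_weights.sum(-1, keepdim=True) + 1e-20)  # norm_topk_prob
topk_weights = topk_weights * 2.5                                           # routed_scaling_factor
```

The only difference from the class is cosmetic: `repeat_interleave` replaces the `unsqueeze`/`expand`/`reshape` chain used to broadcast the group mask over the experts in each group.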
114 changes: 114 additions & 0 deletions tests/test_deepseek_moe.py
@@ -0,0 +1,114 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import pytest
import torch
from transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM

from litgpt import Config
from litgpt.model import GPT, LLaMAMLP


@torch.inference_mode()
@pytest.mark.parametrize("batch_size", (1, 2))
@pytest.mark.parametrize("seq_len", (8, 16))
@pytest.mark.parametrize("device", [torch.device("cpu")])
def test_deepseek_moe_litgpt_vs_hf(batch_size, seq_len, device):
"""Test MOE litgpt vs hf"""
config_litgpt = Config(
padded_vocab_size=10000,
n_layer=2,
vocab_size=10000,
n_embd=64,
n_head=4,
n_query_groups=4,
head_size=16,
norm_eps=1e-6,
bias=False,
latent_attention={
"q_lora_rank": 32,
"kv_lora_rank": 16,
"qk_rope_head_dim": 8,
"qk_nope_head_dim": 8,
"v_head_dim": 16,
},
n_expert=16,
n_shared_expert=1,
n_expert_per_token=2,
n_expert_groups=4,
n_topk_groups=2,
n_topk_scores_per_group=2, # Note: Deepseek hardcodes this to `2`
first_k_dense_replace=1,
routed_scaling_factor=2.5,
norm_topk_prob=True,
moe_intermediate_size=20,
mlp_class_name="LLaMAMoE",
)

config_hf = DeepseekV3Config(
padded_vocab_size=10000,
num_hidden_layers=2,
vocab_size=10000,
hidden_size=64,
num_attention_heads=4,
num_key_value_heads=4,
q_lora_rank=32,
kv_lora_rank=16,
qk_rope_head_dim=8,
qk_nope_head_dim=8,
v_head_dim=16,
rope_interleave=False,
first_k_dense_replace=1,
routed_scaling_factor=2.5,
norm_topk_prob=True,
n_routed_experts=config_litgpt.n_expert,
n_shared_experts=config_litgpt.n_shared_expert,
num_experts_per_tok=config_litgpt.n_expert_per_token,
n_group=config_litgpt.n_expert_groups,
topk_group=config_litgpt.n_topk_groups,
moe_intermediate_size=config_litgpt.moe_intermediate_size,
)

model_litgpt = GPT(config_litgpt).to(device)
model_litgpt.apply(model_litgpt._init_weights)

mlp_litgpt = model_litgpt.transformer.h[0].mlp
assert isinstance(mlp_litgpt, LLaMAMLP) # Test first_k_dense_replace (k=1)

moe_litgpt = model_litgpt.transformer.h[1].mlp
model_hf = DeepseekV3ForCausalLM(config_hf).to(device)
moe_hf = model_hf.model.layers[1].mlp

moe_litgpt.eval()
moe_hf.eval()

sync_weights(moe_litgpt, moe_hf)

hidden_states = torch.randn(batch_size, seq_len, config_litgpt.n_embd, device=device)

output_litgpt = moe_litgpt(hidden_states)
output_hf = moe_hf(hidden_states)

assert torch.allclose(output_litgpt, output_hf, atol=1e-5)


def sync_weights(litgpt_model, hf_model):
print("Synchronizing MoE weights...")

with torch.no_grad():
if hasattr(litgpt_model, "gate"):
if hasattr(litgpt_model.gate, "weight"):
hf_model.gate.weight.copy_(litgpt_model.gate.weight)
if hasattr(litgpt_model.gate, "e_score_correction_bias"):
hf_model.gate.e_score_correction_bias.copy_(litgpt_model.gate.e_score_correction_bias)

for i, (litgpt_expert, hf_expert) in enumerate(zip(litgpt_model.experts, hf_model.experts)):
hf_expert.gate_proj.weight.copy_(litgpt_expert.fc_1.weight)
hf_expert.up_proj.weight.copy_(litgpt_expert.fc_2.weight)
hf_expert.down_proj.weight.copy_(litgpt_expert.proj.weight)

if hasattr(litgpt_model, "shared_experts") and hasattr(hf_model, "shared_experts"):
hf_model.shared_experts.gate_proj.weight.copy_(litgpt_model.shared_experts.fc_1.weight)
hf_model.shared_experts.up_proj.weight.copy_(litgpt_model.shared_experts.fc_2.weight)
hf_model.shared_experts.down_proj.weight.copy_(litgpt_model.shared_experts.proj.weight)

print("MoE weight synchronization complete.")