diff --git a/litgpt/config.py b/litgpt/config.py
index 97549a114d..da7d3ee5bb 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -86,7 +86,14 @@ class Config:
     mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
     gelu_approximate: str = "none"
     n_expert: int = 0
+    n_shared_expert: Optional[int] = None  # number of always-active shared experts (DeepSeek-V3)
+    n_expert_groups: Optional[int] = None  # number of groups the routed experts are split into
+    n_topk_groups: Optional[int] = None  # number of expert groups each token may route into
+    n_topk_scores_per_group: Optional[int] = None  # how many top expert scores are summed to rank a group
     n_expert_per_token: int = 0
+    first_k_dense_replace: Optional[int] = None  # keep the first k transformer blocks dense instead of MoE
+    routed_scaling_factor: float = 1.0  # scaling applied to the routed-expert weights
+    norm_topk_prob: bool = False  # renormalize the top-k routing weights to sum to 1
     # GPT before/after blocks
     scale_embeddings: bool = False
     lm_head_bias: bool = False
@@ -150,6 +157,13 @@ def __post_init__(self):
             assert self.n_head == self.n_query_groups, "Latent attention does not support MQA/GQA"
             self.qk_head_dim = self.qk_rope_head_dim + self.qk_nope_head_dim
             self.rope_n_elem = self.qk_rope_head_dim
+        if self.first_k_dense_replace is not None:
+            assert self.mlp_class_name == "LLaMAMoE", "first_k_dense_replace requires the 'LLaMAMoE' MLP class"
+        if self.n_expert_groups is not None:
+            assert self.n_expert % self.n_expert_groups == 0 and self.n_expert_groups > 1
+            assert self.n_topk_groups is not None, "n_topk_groups must be set when n_expert_groups is used"
+            experts_per_group = self.n_expert // self.n_expert_groups
+            assert self.n_topk_scores_per_group is not None and self.n_topk_scores_per_group <= experts_per_group
 
     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
diff --git a/litgpt/model.py b/litgpt/model.py
index 01ea83ad4a..1c4094f478 100644
--- a/litgpt/model.py
+++ b/litgpt/model.py
@@ -73,7 +73,9 @@ def reset_parameters(self) -> None:
 
     def _init_weights(self, module: nn.Module) -> None:
         """Meant to be used with `gpt.apply(gpt._init_weights)`."""
-        if isinstance(module, nn.Linear):
+        if isinstance(module, GroupedTopkRouter):
+            torch.nn.init.normal_(module.weight.data, mean=0.0, std=0.02)
+        elif isinstance(module, nn.Linear):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
             if module.bias is not None:
                 torch.nn.init.zeros_(module.bias)
@@ -286,6 +288,8 @@ def __init__(
             else (None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps))
         )
         self.mlp = config.mlp_class(config)
+        if config.first_k_dense_replace is not None and block_idx < config.first_k_dense_replace:
+            self.mlp = LLaMAMLP(config)  # the first `first_k_dense_replace` blocks stay dense (DeepSeek-V3)
         self.post_mlp_norm = (
             config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_mlp_norm else nn.Identity()
         )
@@ -734,10 +738,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class LLaMAMoE(nn.Module):
     def __init__(self, config: Config) -> None:
         super().__init__()
-        self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False)
+        self.gate = (
+            nn.Linear(config.n_embd, config.n_expert, bias=False)
+            if not config.n_expert_groups
+            else GroupedTopkRouter(config)
+        )
         self.experts = nn.ModuleList(
             LLaMAMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert)
         )
+        if config.n_shared_expert:
+            self.shared_experts = LLaMAMLP(
+                config, intermediate_size=config.moe_intermediate_size * config.n_shared_expert
+            )
         self.config = config
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -746,17 +758,71 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         See also figure 1 in https://arxiv.org/abs/2211.15841
         """
         B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+        residual_x = x.clone()
         x = x.view(-1, C)  # (B*T, C)
-        router = self.gate(x)  # (B*T, n_expert)
-        probs, indices = torch.topk(router, self.config.n_expert_per_token)  # (B*T, n_expert_per_token)
-        probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
+        if not self.config.n_expert_groups:
+            router = self.gate(x)  # (B*T, n_expert)
+            probs, indices = torch.topk(router, self.config.n_expert_per_token)  # (B*T, n_expert_per_token)
+            probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
+        else:
+            probs, indices = self.gate(x)  # grouped top-k routing (DeepSeek-V3)
+            if self.config.routed_scaling_factor != 1.0:
+                probs = probs * self.config.routed_scaling_factor
         masks = indices.unsqueeze(-1) == torch.arange(self.config.n_expert, device=x.device)
         masks = masks.permute(2, 0, 1)  # (n_expert, B*T, n_expert_per_token)
         y = torch.zeros_like(x)  # (B*T, C)
         for mask, expert in zip(masks, self.experts):
             token_idx, expert_idx = torch.where(mask)
             y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
-        return y.view(B, T, C)
+
+        y = y.view(B, T, C)
+        if self.config.n_shared_expert:
+            y = y + self.shared_experts(residual_x)
+        return y
+
+
+class GroupedTopkRouter(nn.Module):
+    """
+    Grouped top-k (group-limited) router, derived from the `DeepseekV3TopkRouter` class in
+    https://github.com/huggingface/transformers/blob/main/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py.
+    """
+
+    def __init__(self, config: Config) -> None:
+        super().__init__()
+        self.config = config
+        self.weight = nn.Parameter(torch.empty(config.n_expert, config.n_embd))
+        self.register_buffer("e_score_correction_bias", torch.zeros(config.n_expert))
+
+    @torch.no_grad()
+    def get_topk_indices(self, scores: torch.Tensor) -> torch.Tensor:
+        scores_for_choice = scores.view(-1, self.config.n_expert) + self.e_score_correction_bias.unsqueeze(0)
+        group_scores = (
+            scores_for_choice.view(-1, self.config.n_expert_groups, self.config.n_expert // self.config.n_expert_groups)
+            .topk(self.config.n_topk_scores_per_group, dim=-1)[0]  # top-k scores for each group
+            .sum(dim=-1)
+        )
+
+        group_idx = torch.topk(group_scores, k=self.config.n_topk_groups, dim=-1, sorted=False)[1]  # selected groups
+        group_mask = torch.zeros_like(group_scores)
+        group_mask.scatter_(1, group_idx, 1)
+        score_mask = (
+            group_mask.unsqueeze(-1)
+            .expand(-1, self.config.n_expert_groups, self.config.n_expert // self.config.n_expert_groups)
+            .reshape(-1, self.config.n_expert)
+        )
+        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # drop experts outside the selected groups
+        topk_indices = torch.topk(scores_for_choice, k=self.config.n_expert_per_token, dim=-1, sorted=False)[1]
+        return topk_indices
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        router_logits = F.linear(x.type(torch.float32), self.weight.type(torch.float32))
+        scores = router_logits.sigmoid()  # DeepSeek-V3 uses sigmoid gating instead of softmax
+        topk_indices = self.get_topk_indices(scores)
+        topk_weights = scores.gather(1, topk_indices)
+        if self.config.norm_topk_prob:
+            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+            topk_weights /= denominator
+        return topk_weights, topk_indices
 
 
 def build_rope_cache(
diff --git a/tests/test_deepseek_moe.py b/tests/test_deepseek_moe.py
new file mode 100644
index 0000000000..03d867cd8e
--- /dev/null
+++ b/tests/test_deepseek_moe.py
@@ -0,0 +1,114 @@
+# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+
+import pytest
+import torch
+from transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM
+
+from litgpt import Config
+from litgpt.model import GPT, LLaMAMLP
+
+
+@torch.inference_mode()
+@pytest.mark.parametrize("batch_size", (1, 2))
+@pytest.mark.parametrize("seq_len", (8, 16))
+@pytest.mark.parametrize("device", [torch.device("cpu")])
+def test_deepseek_moe_litgpt_vs_hf(batch_size, seq_len, device):
+    """Compare the litgpt MoE block output against the Hugging Face DeepseekV3 reference."""
+    config_litgpt = Config(
+        padded_vocab_size=10000,
+        n_layer=2,
+        vocab_size=10000,
+        n_embd=64,
+        n_head=4,
+        n_query_groups=4,
+        head_size=16,
+        norm_eps=1e-6,
+        bias=False,
+        latent_attention={
+            "q_lora_rank": 32,
+            "kv_lora_rank": 16,
+            "qk_rope_head_dim": 8,
+            "qk_nope_head_dim": 8,
+            "v_head_dim": 16,
+        },
+        n_expert=16,
+        n_shared_expert=1,
+        n_expert_per_token=2,
+        n_expert_groups=4,
+        n_topk_groups=2,
+        n_topk_scores_per_group=2,  # Note: DeepSeek hardcodes this to `2`
+        first_k_dense_replace=1,
+        routed_scaling_factor=2.5,
+        norm_topk_prob=True,
+        moe_intermediate_size=20,
+        mlp_class_name="LLaMAMoE",
+    )
+
+    config_hf = DeepseekV3Config(
+        padded_vocab_size=10000,
+        num_hidden_layers=2,
+        vocab_size=10000,
+        hidden_size=64,
+        num_attention_heads=4,
+        num_key_value_heads=4,
+        q_lora_rank=32,
+        kv_lora_rank=16,
+        qk_rope_head_dim=8,
+        qk_nope_head_dim=8,
+        v_head_dim=16,
+        rope_interleave=False,
+        first_k_dense_replace=1,
+        routed_scaling_factor=2.5,
+        norm_topk_prob=True,
+        n_routed_experts=config_litgpt.n_expert,
+        n_shared_experts=config_litgpt.n_shared_expert,
+        num_experts_per_tok=config_litgpt.n_expert_per_token,
+        n_group=config_litgpt.n_expert_groups,
+        topk_group=config_litgpt.n_topk_groups,
+        moe_intermediate_size=config_litgpt.moe_intermediate_size,
+    )
+
+    model_litgpt = GPT(config_litgpt).to(device)
+    model_litgpt.apply(model_litgpt._init_weights)
+
+    mlp_litgpt = model_litgpt.transformer.h[0].mlp
+    assert isinstance(mlp_litgpt, LLaMAMLP)  # first_k_dense_replace=1 keeps block 0 dense
+
+    moe_litgpt = model_litgpt.transformer.h[1].mlp
+    model_hf = DeepseekV3ForCausalLM(config_hf).to(device)
+    moe_hf = model_hf.model.layers[1].mlp
+
+    moe_litgpt.eval()
+    moe_hf.eval()
+
+    sync_weights(moe_litgpt, moe_hf)
+
+    hidden_states = torch.randn(batch_size, seq_len, config_litgpt.n_embd, device=device)
+
+    output_litgpt = moe_litgpt(hidden_states)
+    output_hf = moe_hf(hidden_states)
+
+    assert torch.allclose(output_litgpt, output_hf, atol=1e-5)
+
+
+def sync_weights(litgpt_model, hf_model):
+    """Copy the litgpt MoE weights into the Hugging Face module so the outputs are comparable."""
+    print("Synchronizing MoE weights...")
+
+    with torch.no_grad():
+        if hasattr(litgpt_model, "gate"):
+            if hasattr(litgpt_model.gate, "weight"):
+                hf_model.gate.weight.copy_(litgpt_model.gate.weight)
+            if hasattr(litgpt_model.gate, "e_score_correction_bias"):
+                hf_model.gate.e_score_correction_bias.copy_(litgpt_model.gate.e_score_correction_bias)
+
+        for litgpt_expert, hf_expert in zip(litgpt_model.experts, hf_model.experts):
+            hf_expert.gate_proj.weight.copy_(litgpt_expert.fc_1.weight)
+            hf_expert.up_proj.weight.copy_(litgpt_expert.fc_2.weight)
+            hf_expert.down_proj.weight.copy_(litgpt_expert.proj.weight)
+
+        if hasattr(litgpt_model, "shared_experts") and hasattr(hf_model, "shared_experts"):
+            hf_model.shared_experts.gate_proj.weight.copy_(litgpt_model.shared_experts.fc_1.weight)
+            hf_model.shared_experts.up_proj.weight.copy_(litgpt_model.shared_experts.fc_2.weight)
+            hf_model.shared_experts.down_proj.weight.copy_(litgpt_model.shared_experts.proj.weight)
+
+    print("MoE weight synchronization complete.")
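Reviewer note (not part of the diff): the sketch below is a minimal, standalone illustration of the group-limited top-k selection that `GroupedTopkRouter.get_topk_indices` implements, using plain torch on toy tensors. All variable names and hyperparameter values here are made up for illustration, and the `e_score_correction_bias` term (zero at initialization) is omitted for brevity. The idea: experts are partitioned into groups, each group is ranked by the sum of its top expert scores, only the best groups survive, and the per-token top-k experts are then chosen from the surviving groups.

# Illustrative sketch of DeepSeek-V3-style group-limited top-k routing (assumed toy sizes).
import torch

n_tokens, n_expert = 4, 16
n_expert_groups, n_topk_groups = 4, 2            # 16 experts in 4 groups; keep the best 2 groups
n_topk_scores_per_group, n_expert_per_token = 2, 2
experts_per_group = n_expert // n_expert_groups  # 4

scores = torch.rand(n_tokens, n_expert)          # stand-in for sigmoid(router logits)

# Rank each group by the sum of its top `n_topk_scores_per_group` expert scores.
group_scores = (
    scores.view(n_tokens, n_expert_groups, experts_per_group)
    .topk(n_topk_scores_per_group, dim=-1)[0]
    .sum(dim=-1)
)                                                # (n_tokens, n_expert_groups)
group_idx = group_scores.topk(n_topk_groups, dim=-1)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)

# Zero out experts outside the selected groups, then pick the per-token top-k experts.
score_mask = group_mask.repeat_interleave(experts_per_group, dim=-1).bool()
topk_indices = scores.masked_fill(~score_mask, 0.0).topk(n_expert_per_token, dim=-1)[1]
print(topk_indices.shape)                        # torch.Size([4, 2])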