
Commit 0864027

ysjprojects authored and committed
Qwen3 MoE (Lightning-AI#2060)
Co-authored-by: shijie.yu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2d4c988 commit 0864027

8 files changed: +208 additions, -16 deletions


README.md

Lines changed: 1 addition & 0 deletions
@@ -151,6 +151,7 @@ Every model is written from scratch to maximize performance and remove layers of
 | QwQ | 32B | Alibaba Group | [Qwen Team 2025](https://qwenlm.github.io/blog/qwq-32b/) |
 | QwQ-Preview | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
 | Qwen3 | 0.6B, 1.7B, 4B, 8B, 14B, 32B | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/) |
+| Qwen3 MoE | 30B, 235B | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/) |
 | R1 Distill Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf) |
 | SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |
 | Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) |

litgpt/adapter_v2.py

Lines changed: 8 additions & 5 deletions
@@ -158,11 +158,12 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
 
 
 class LLaMAMLP(litgpt.model.LLaMAMLP):
-    def __init__(self, config: Config) -> None:
+    def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:
         nn.Module.__init__(self)
-        self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-        self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-        self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)
+        self.intermediate_size = intermediate_size or config.intermediate_size
+        self.fc_1 = AdapterV2Linear(config.n_embd, self.intermediate_size, bias=config.bias)
+        self.fc_2 = AdapterV2Linear(config.n_embd, self.intermediate_size, bias=config.bias)
+        self.proj = AdapterV2Linear(self.intermediate_size, config.n_embd, bias=config.bias)
         self.config = config
 
     def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:

@@ -191,7 +192,9 @@ class LLaMAMoE(litgpt.model.LLaMAMoE):
     def __init__(self, config: Config) -> None:
         nn.Module.__init__(self)
         self.gate = AdapterV2Linear(config.n_embd, config.n_expert, bias=False)
-        self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert))
+        self.experts = nn.ModuleList(
+            LLaMAMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert)
+        )
         self.config = config
 
     def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
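The optional `intermediate_size` argument is what lets each MoE expert use the narrower `moe_intermediate_size` while dense blocks keep `config.intermediate_size`. Below is a minimal, standalone sketch of that constructor pattern; `TinyConfig` and `TinyLLaMAMLP` are illustrative stand-ins (plain `nn.Linear` in place of `AdapterV2Linear`), not litgpt code.

```python
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn


@dataclass
class TinyConfig:
    n_embd: int = 64
    intermediate_size: int = 256
    moe_intermediate_size: int = 32
    bias: bool = False


class TinyLLaMAMLP(nn.Module):
    def __init__(self, config: TinyConfig, intermediate_size: Optional[int] = None) -> None:
        super().__init__()
        # Fall back to the dense width when no per-expert override is given.
        self.intermediate_size = intermediate_size or config.intermediate_size
        self.fc_1 = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
        self.fc_2 = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(self.intermediate_size, config.n_embd, bias=config.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU-style gating, mirroring LLaMAMLP's forward
        return self.proj(nn.functional.silu(self.fc_1(x)) * self.fc_2(x))


cfg = TinyConfig()
dense = TinyLLaMAMLP(cfg)                                                # 64 -> 256 -> 64
expert = TinyLLaMAMLP(cfg, intermediate_size=cfg.moe_intermediate_size)  # 64 -> 32 -> 64
print(dense.fc_1.weight.shape, expert.fc_1.weight.shape)
```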

litgpt/config.py

Lines changed: 85 additions & 0 deletions
@@ -2583,6 +2583,9 @@ def norm_class(self) -> Type:
         copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
         configs.append(copy)
 
+##########
+# QwQ
+##########
 qwq = [
     # https://huggingface.co/Qwen/QwQ-32B/blob/main/config.json
     dict(

@@ -2630,6 +2633,9 @@ def norm_class(self) -> Type:
 
 configs.extend(qwq)
 
+##########
+# Qwen3
+##########
 qwen_3 = [
     # https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/config.json
     dict(

@@ -2771,6 +2777,85 @@ def norm_class(self) -> Type:
 ]
 configs.extend(qwen_3_32b)
 
+qwen_3_moe = [
+    # https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/config.json
+    dict(
+        name="Qwen3-30B-A3B",
+        hf_config=dict(org="Qwen", name="Qwen3-30B-A3B"),
+        block_size=40960,
+        head_size=128,
+        vocab_size=151643,
+        padded_vocab_size=151936,
+        n_layer=48,
+        n_head=32,
+        n_embd=2048,
+        n_query_groups=4,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMoE",
+        intermediate_size=6144,
+        moe_intermediate_size=768,
+        norm_eps=1e-6,
+        rope_base=1000000,
+        norm_qk=True,
+        n_expert=128,
+        n_expert_per_token=8,
+    ),
+    # https://huggingface.co/Qwen/Qwen3-30B-A3B-Base/blob/main/config.json
+    dict(
+        name="Qwen3-30B-A3B-Base",
+        hf_config=dict(org="Qwen", name="Qwen3-30B-A3B-Base"),
+        block_size=40960,
+        head_size=128,
+        vocab_size=151643,
+        padded_vocab_size=151936,
+        n_layer=48,
+        n_head=32,
+        n_embd=2048,
+        n_query_groups=4,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMoE",
+        intermediate_size=6144,
+        moe_intermediate_size=768,
+        norm_eps=1e-6,
+        rope_base=1000000,
+        norm_qk=True,
+        n_expert=128,
+        n_expert_per_token=8,
+    ),
+    # https://huggingface.co/Qwen/Qwen3-235B-A22B/blob/main/config.json
+    dict(
+        name="Qwen3-235B-A22B",
+        hf_config=dict(org="Qwen", name="Qwen3-235B-A22B"),
+        block_size=40960,
+        head_size=128,
+        vocab_size=151643,
+        padded_vocab_size=151936,
+        n_layer=94,
+        n_head=64,
+        n_embd=4096,
+        n_query_groups=4,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMoE",
+        intermediate_size=12288,
+        moe_intermediate_size=1536,
+        norm_eps=1e-6,
+        rope_base=1000000,
+        norm_qk=True,
+        n_expert=128,
+        n_expert_per_token=8,
+    ),
+]
+configs.extend(qwen_3_moe)
+
 
 #############
 # Salamandra
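With these presets registered, the new checkpoints can be addressed by name. A short, hedged example (assuming litgpt installed from this commit; the expected values come from the dicts above):

```python
from litgpt import Config

cfg = Config.from_name("Qwen3-30B-A3B")
assert cfg.mlp_class_name == "LLaMAMoE"
assert (cfg.n_expert, cfg.n_expert_per_token) == (128, 8)
assert (cfg.intermediate_size, cfg.moe_intermediate_size) == (6144, 768)

# Fields can be overridden for quick experiments, as the new test below does:
tiny = Config.from_name("Qwen3-30B-A3B", n_layer=2, n_expert=4, n_expert_per_token=2)
```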

litgpt/lora.py

Lines changed: 8 additions & 5 deletions
@@ -609,11 +609,12 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
 
 
 class LLaMAMLP(litgpt.model.LLaMAMLP):
-    def __init__(self, config: Config) -> None:
+    def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:
         nn.Module.__init__(self)
-        self.fc_1 = create_lora_linear(config, config.n_embd, config.intermediate_size)
-        self.fc_2 = create_lora_linear(config, config.n_embd, config.intermediate_size)
-        self.proj = create_lora_linear(config, config.intermediate_size, config.n_embd)
+        self.intermediate_size = intermediate_size or config.intermediate_size
+        self.fc_1 = create_lora_linear(config, config.n_embd, self.intermediate_size)
+        self.fc_2 = create_lora_linear(config, config.n_embd, self.intermediate_size)
+        self.proj = create_lora_linear(config, self.intermediate_size, config.n_embd)
         self.config = config
 
     def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:

@@ -642,7 +643,9 @@ class LLaMAMoE(litgpt.model.LLaMAMoE):
     def __init__(self, config: Config) -> None:
         nn.Module.__init__(self)
         self.gate = create_lora_linear(config, config.n_embd, config.n_expert, bias=False)
-        self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert))
+        self.experts = nn.ModuleList(
+            LLaMAMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert)
+        )
         self.config = config
 
     def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
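As in adapter_v2.py, only the construction of the expert MLPs changes here; the top-k routing itself is inherited from litgpt.model.LLaMAMoE. For orientation, a toy sketch of that routing idea (a simplified illustration, not litgpt's actual forward pass):

```python
import torch
import torch.nn as nn

n_embd, n_expert, n_expert_per_token = 64, 4, 2
gate = nn.Linear(n_embd, n_expert, bias=False)
experts = nn.ModuleList(
    nn.Sequential(nn.Linear(n_embd, 32), nn.SiLU(), nn.Linear(32, n_embd)) for _ in range(n_expert)
)

x = torch.randn(3, n_embd)                               # three token embeddings
weights, idx = torch.topk(gate(x), n_expert_per_token, dim=-1)
weights = weights.softmax(dim=-1)                        # normalize over the selected experts
out = torch.zeros_like(x)
for t in range(x.size(0)):
    for w, e in zip(weights[t], idx[t]):
        out[t] += w * experts[int(e)](x[t])              # weighted sum of the chosen experts
print(out.shape)  # torch.Size([3, 64])
```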

litgpt/scripts/convert_hf_checkpoint.py

Lines changed: 19 additions & 3 deletions
@@ -652,12 +652,28 @@ def copy_weights_qwen_3(
         "model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.norm_k.weight",
         "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
         "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
-        "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
-        "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
-        "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
         "model.norm.weight": "transformer.ln_f.weight",
         "lm_head.weight": "lm_head.weight",
     }
+    if config.mlp_class_name == "LLaMAMoE":
+        weight_map.update(
+            {
+                "model.layers.{}.mlp.experts.{}.gate_proj.weight": "transformer.h.{}.mlp.experts.{}.fc_1.weight",
+                "model.layers.{}.mlp.experts.{}.up_proj.weight": "transformer.h.{}.mlp.experts.{}.fc_2.weight",
+                "model.layers.{}.mlp.experts.{}.down_proj.weight": "transformer.h.{}.mlp.experts.{}.proj.weight",
+                "model.layers.{}.mlp.gate.weight": "transformer.h.{}.mlp.gate.weight",
+            }
+        )
+    elif config.mlp_class_name == "LLaMAMLP":
+        weight_map.update(
+            {
+                "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
+                "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
+                "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
+            }
+        )
+    else:
+        raise NotImplementedError
 
     if progress_per_file is not None:
         progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
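The MoE entries use two `{}` placeholders, one for the layer index and one for the expert index. A hedged sketch of how such a template pairs an HF parameter name with its litgpt counterpart; `map_name` is a hypothetical helper for illustration only, the real `copy_weights_qwen_3` also handles QKV fusion and progress reporting:

```python
import re

weight_map = {
    "model.layers.{}.mlp.experts.{}.gate_proj.weight": "transformer.h.{}.mlp.experts.{}.fc_1.weight",
}

def map_name(hf_name: str) -> str:
    # Recover the template by replacing the numeric indices with "{}",
    # then re-insert the captured indices into the litgpt-side template.
    indices = re.findall(r"\.(\d+)\.", hf_name)
    template = re.sub(r"\.\d+\.", ".{}.", hf_name)
    return weight_map[template].format(*indices)

print(map_name("model.layers.3.mlp.experts.17.gate_proj.weight"))
# -> transformer.h.3.mlp.experts.17.fc_1.weight
```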

litgpt/scripts/convert_lit_checkpoint.py

Lines changed: 19 additions & 3 deletions
@@ -465,12 +465,28 @@ def copy_weights_qwen_3(
         "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
         "transformer.h.{}.attn.norm_q.weight": "model.layers.{}.self_attn.q_norm.weight",
         "transformer.h.{}.attn.norm_k.weight": "model.layers.{}.self_attn.k_norm.weight",
-        "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight",
-        "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight",
-        "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
         "transformer.ln_f.weight": "model.norm.weight",
         "lm_head.weight": "lm_head.weight",
     }
+    if config.mlp_class_name == "LLaMAMoE":
+        weight_map.update(
+            {
+                "transformer.h.{}.mlp.gate.weight": "model.layers.{}.mlp.gate.weight",
+                "transformer.h.{}.mlp.experts.{}.fc_1.weight": "model.layers.{}.mlp.experts.{}.gate_proj.weight",
+                "transformer.h.{}.mlp.experts.{}.fc_2.weight": "model.layers.{}.mlp.experts.{}.up_proj.weight",
+                "transformer.h.{}.mlp.experts.{}.proj.weight": "model.layers.{}.mlp.experts.{}.down_proj.weight",
+            }
+        )
+    elif config.mlp_class_name == "LLaMAMLP":
+        weight_map.update(
+            {
+                "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight",
+                "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight",
+                "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
+            }
+        )
+    else:
+        raise NotImplementedError
 
     for from_name, param in lit_weights.items():
         if from_name == "lm_head.weight" and untie_weights:
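This map is intended to be the exact inverse of the HF-to-litgpt map added in convert_hf_checkpoint.py above, so a quick sanity check is that inverting one reproduces the other. A short sketch using only the MoE entries from the two diffs:

```python
hf_to_lit = {
    "model.layers.{}.mlp.experts.{}.gate_proj.weight": "transformer.h.{}.mlp.experts.{}.fc_1.weight",
    "model.layers.{}.mlp.experts.{}.up_proj.weight": "transformer.h.{}.mlp.experts.{}.fc_2.weight",
    "model.layers.{}.mlp.experts.{}.down_proj.weight": "transformer.h.{}.mlp.experts.{}.proj.weight",
    "model.layers.{}.mlp.gate.weight": "transformer.h.{}.mlp.gate.weight",
}
lit_to_hf = {
    "transformer.h.{}.mlp.gate.weight": "model.layers.{}.mlp.gate.weight",
    "transformer.h.{}.mlp.experts.{}.fc_1.weight": "model.layers.{}.mlp.experts.{}.gate_proj.weight",
    "transformer.h.{}.mlp.experts.{}.fc_2.weight": "model.layers.{}.mlp.experts.{}.up_proj.weight",
    "transformer.h.{}.mlp.experts.{}.proj.weight": "model.layers.{}.mlp.experts.{}.down_proj.weight",
}
# Inverting the HF->litgpt map must give back the litgpt->HF map.
assert {v: k for k, v in hf_to_lit.items()} == lit_to_hf
```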

tests/test_model.py

Lines changed: 67 additions & 0 deletions
@@ -32,6 +32,7 @@
 from transformers.models.olmo2 import Olmo2Config, Olmo2ForCausalLM
 from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
 from transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM
+from transformers.models.qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM
 
 import litgpt.config as config_module
 from litgpt import GPT, Config

@@ -1139,6 +1140,72 @@ def test_against_original_qwen_3(model_name, device, dtype):
     torch.testing.assert_close(ours_y, theirs_y)
 
 
+@torch.inference_mode()
+@pytest.mark.parametrize("model_name", ["Qwen3-30B-A3B", "Qwen3-235B-A22B"])
+@pytest.mark.parametrize(
+    ("device", "dtype"),
+    [
+        (torch.device("cpu"), torch.float32),
+        pytest.param(
+            torch.device("cuda"),
+            torch.float16,
+            marks=[
+                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
+                # is slightly different
+                pytest.mark.xfail(raises=AssertionError, strict=False),
+                _RunIf(min_cuda_gpus=1),
+            ],
+        ),
+    ],
+)
+def test_against_original_qwen_3_moe(model_name, device, dtype):
+    torch.set_default_dtype(dtype)
+
+    T = 20
+    ours_config = Config.from_name(
+        model_name,
+        block_size=T,
+        n_layer=2,
+        n_head=16,
+        n_embd=32,
+        intermediate_size=86,
+        moe_intermediate_size=20,
+        n_expert=4,
+        n_expert_per_token=2,
+    )
+    theirs_config = Qwen3MoeConfig(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        head_dim=ours_config.head_size,
+        num_attention_heads=ours_config.n_head,
+        num_hidden_layers=ours_config.n_layer,
+        intermediate_size=ours_config.intermediate_size,
+        moe_intermediate_size=ours_config.moe_intermediate_size,
+        max_position_embeddings=ours_config.block_size,
+        rms_norm_eps=ours_config.norm_eps,
+        num_key_value_heads=ours_config.n_query_groups,
+        rope_theta=ours_config.rope_base,
+        tie_word_embeddings=False,
+        num_experts=ours_config.n_expert,
+        num_experts_per_tok=ours_config.n_expert_per_token,
+        norm_topk_prob=True,
+    )
+
+    theirs_model = Qwen3MoeForCausalLM(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+    state_dict = {}
+    copy_weights_qwen_3(ours_config, {}, state_dict, theirs_state_dict)
+    ours_model = GPT(ours_config).to(device)
+    ours_model.load_state_dict(state_dict)
+
+    # test end to end
+    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y)
+
+
 @torch.inference_mode()
 @pytest.mark.parametrize("model_name", ("salamandra-2b", "salamandra-7b"))
 @pytest.mark.parametrize(
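The parity test scales the presets down drastically (2 layers, 4 experts) so it runs on CPU; the CUDA/fp16 variant is expected to deviate slightly, hence the xfail marker. A hedged smoke-test sketch along the same lines, reusing the test's overrides (not part of the test suite); the new test itself can be selected with, e.g., `pytest tests/test_model.py -k qwen_3_moe`.

```python
import torch
from litgpt import GPT, Config

# Scaled-down Qwen3-30B-A3B, matching the overrides used in the test above.
cfg = Config.from_name(
    "Qwen3-30B-A3B", block_size=20, n_layer=2, n_head=16, n_embd=32,
    intermediate_size=86, moe_intermediate_size=20, n_expert=4, n_expert_per_token=2,
)
with torch.inference_mode():
    model = GPT(cfg)
    x = torch.randint(0, cfg.padded_vocab_size, (1, 20))
    logits = model(x)
print(logits.shape)  # (1, 20, cfg.padded_vocab_size)
```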

tutorials/download_model_weights.md

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
 | QwQ | 32B | Alibaba Group | [Qwen Team 2025](https://qwenlm.github.io/blog/qwq-32b/) |
 | QwQ-Preview | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
 | Qwen3 | 0.6B, 1.7B, 4B, 8B, 14B, 32B | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/) |
+| Qwen3 MoE | 30B, 235B | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/) |
 | R1 Distll Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf) |
 | RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
 | SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |
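With the presets above registered, the new checkpoints should be downloadable the same way as the other Qwen models in this tutorial, e.g. `litgpt download Qwen/Qwen3-30B-A3B`, subject to the disk and memory requirements of 30B/235B MoE weights.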
