|
 from transformers.models.olmo2 import Olmo2Config, Olmo2ForCausalLM
 from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
 from transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM
+from transformers.models.qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM

 import litgpt.config as config_module
 from litgpt import GPT, Config
@@ -1139,6 +1140,72 @@ def test_against_original_qwen_3(model_name, device, dtype):
     torch.testing.assert_close(ours_y, theirs_y)


+@torch.inference_mode()
+@pytest.mark.parametrize("model_name", ["Qwen3-30B-A3B", "Qwen3-235B-A22B"])
+@pytest.mark.parametrize(
+    ("device", "dtype"),
+    [
+        (torch.device("cpu"), torch.float32),
+        pytest.param(
+            torch.device("cuda"),
+            torch.float16,
+            marks=[
+                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
+                # is slightly different
+                pytest.mark.xfail(raises=AssertionError, strict=False),
+                _RunIf(min_cuda_gpus=1),
+            ],
+        ),
+    ],
+)
+def test_against_original_qwen_3_moe(model_name, device, dtype):
+    torch.set_default_dtype(dtype)
+
+    T = 20
+    ours_config = Config.from_name(
+        model_name,
+        block_size=T,
+        n_layer=2,
+        n_head=16,
+        n_embd=32,
+        intermediate_size=86,
+        moe_intermediate_size=20,
+        n_expert=4,
+        n_expert_per_token=2,
+    )
+    theirs_config = Qwen3MoeConfig(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        head_dim=ours_config.head_size,
+        num_attention_heads=ours_config.n_head,
+        num_hidden_layers=ours_config.n_layer,
+        intermediate_size=ours_config.intermediate_size,
+        moe_intermediate_size=ours_config.moe_intermediate_size,
+        max_position_embeddings=ours_config.block_size,
+        rms_norm_eps=ours_config.norm_eps,
+        num_key_value_heads=ours_config.n_query_groups,
+        rope_theta=ours_config.rope_base,
+        tie_word_embeddings=False,
+        num_experts=ours_config.n_expert,
+        num_experts_per_tok=ours_config.n_expert_per_token,
+        norm_topk_prob=True,
+    )
+
+    theirs_model = Qwen3MoeForCausalLM(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+    state_dict = {}
+    copy_weights_qwen_3(ours_config, {}, state_dict, theirs_state_dict)
+    ours_model = GPT(ours_config).to(device)
+    ours_model.load_state_dict(state_dict)
+
+    # test end to end
+    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y)
+
+
 @torch.inference_mode()
 @pytest.mark.parametrize("model_name", ("salamandra-2b", "salamandra-7b"))
 @pytest.mark.parametrize(