Skip to content

Commit 776b733

Browse files
committed
Fix checkpoint conversion and tensor parallelism (tp)
1 parent 635db73 commit 776b733

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

mixtral-moe/scripts/convert_hf_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def convert_hf_checkpoint(
7979
elif "w1" in key or "w3" in key:
8080
final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).contiguous()
8181
elif "w2" in key:
82-
final_result[key] = final_result[key].reshape(config.num_experts, config.dim, config.intermediate_size).contiguous()
82+
final_result[key] = final_result[key].reshape(config.num_experts, config.intermediate_size, config.dim).permute(0, 2, 1).contiguous()
8383
elif "gate" in key:
8484
final_result[key] = final_result[key].contiguous()
8585

mixtral-moe/tp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def _apply_tp_moe_ffn(mlp: MOEFeedForward) -> None:
104104
if hasattr(mlp.cond_ffn, "scales1"):
105105
mlp.cond_ffn.scales1 = nn.Parameter(shard(mlp.cond_ffn.scales1, 1), requires_grad=False)
106106
mlp.cond_ffn.scales3 = nn.Parameter(shard(mlp.cond_ffn.scales3, 1), requires_grad=False)
107-
mlp.cond_ffn.scales2 = nn.Parameter(shard(mlp.cond_ffn.scales2, 1), requires_grad=False)
107+
mlp.cond_ffn.scales2 = nn.Parameter(mlp.cond_ffn.scales2, requires_grad=False)
108108

109109
world_size = _get_world_size()
110110
mlp.cond_ffn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(

0 commit comments

Comments (0)