Commit b6d74e2

Fix BF16 CUDA version of OpenAI's gpt-oss (microsoft#1706)
### Description

This PR fixes a bug with generating the BF16 CUDA version of OpenAI's `gpt-oss` models.

### Motivation and Context

The shape of the MLP 1 weights in the MoE layer differs between the OpenAI and Hugging Face implementations. OpenAI's implementation uses the shape `(32, 5760, 2880)`, while Hugging Face's uses `(32, 2880, 5760)`. The [original PR changes](microsoft#1678) were made against the OpenAI implementation, so a transpose is inserted to close this gap.
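The layout mismatch can be sketched with a small NumPy example. Dimensions are scaled down for illustration (the real values from the description are 32 experts, `hidden_size` 2880, and a fused gate/up dimension of 5760), and NumPy's `transpose`/`reshape` stand in for the PyTorch `transpose(-1, -2).view(...)` call in the diff:

```python
import numpy as np

# Scaled-down stand-ins for gpt-oss (real values: 32 experts,
# hidden_size 2880, fused gate/up dimension 5760).
num_experts, hidden_size, gate_up_dim = 4, 6, 12

# Hugging Face layout: (num_experts, hidden_size, gate_up_dim)
hf_gate_up = np.arange(num_experts * hidden_size * gate_up_dim, dtype=np.float32)
hf_gate_up = hf_gate_up.reshape(num_experts, hidden_size, gate_up_dim)

# The builder expects the OpenAI layout, (num_experts, gate_up_dim, hidden_size),
# so the last two axes are swapped before flattening.
openai_gate_up = hf_gate_up.transpose(0, 2, 1).reshape(num_experts, -1, hidden_size)

print(openai_gate_up.shape)  # (4, 12, 6)

# A plain reshape without the transpose would produce the right shape but
# scramble the values; the transpose preserves each expert's weight matrix.
assert np.array_equal(openai_gate_up[0], hf_gate_up[0].T)
```

This is why the fix inserts `transpose(-1, -2)` before the `view` rather than only changing the `view` arguments: reshaping alone would reinterpret the memory in the wrong order.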
1 parent ad2a6ed commit b6d74e2

File tree

1 file changed (+2, -2 lines)


src/python/py/models/builder.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -4149,8 +4149,8 @@ def make_moe_fused(self, layer_id, mlp, root_input):

         if op_type == "MoE":
             # Save non-quantized MoE weights as initializers
-            self.make_initializer(mlp.experts.gate_up_proj.view(self.moe_attrs["num_experts"], -1, self.hidden_size), gate_up_proj_weight, to=self.io_dtype)
-            self.make_initializer(mlp.experts.down_proj.view(self.moe_attrs["num_experts"], self.hidden_size, self.intermediate_size), down_proj_weight, to=self.io_dtype)
+            self.make_initializer(mlp.experts.gate_up_proj.transpose(-1, -2).view(self.moe_attrs["num_experts"], -1, self.hidden_size), gate_up_proj_weight, to=self.io_dtype)
+            self.make_initializer(mlp.experts.down_proj.transpose(-1, -2).view(self.moe_attrs["num_experts"], self.hidden_size, self.intermediate_size), down_proj_weight, to=self.io_dtype)
         else:
             # Create and save quantized MoE weights as initializers
             gate_up_proj_qweight_list, gate_up_proj_scales_list = [], []
```
