Skip to content

Commit 3de4d5d

Browse files
committed
Use the correct layout
Signed-off-by: Dongfeng Yu <[email protected]>
1 parent fa4cc3e commit 3de4d5d

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

tensorrt_llm/_torch/models/modeling_gpt_oss.py

Lines changed: 20 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -794,16 +794,19 @@ def load_nvfp4_weights(self, weights: Dict):
794 794
gate_up_bias = module_weights.get('gate_up_proj_bias', None)
795 795
down_bias = module_weights.get('down_proj_bias', None)
796 796

797-
# Optional deinterleave for checkpoints that interleave gate/up
798-
if gate_up is not None and gate_up.dim() == 3:
799-
try:
800-
g, u = gate_up[:, :, ::2], gate_up[:, :, 1::2]
801-
gate_up = torch.cat([g, u], dim=-1)
802-
if gate_up_bias is not None:
803-
gb, ub = gate_up_bias[:, ::2], gate_up_bias[:, 1::2]
804-
gate_up_bias = torch.cat([gb, ub], dim=-1)
805-
except Exception:
806-
pass
797+
def deinterleave(tensor):
798+
g, u = tensor[..., ::2], tensor[..., 1::2]
799+
return torch.cat([g, u], dim=-1)
800+
801+
print("up projection shape before deinterleave:", gate_up.shape)
802+
gate_up = deinterleave(gate_up)
803+
print("up projection shape after deinterleave:", gate_up.shape)
804+
805+
print("up projection bias shape before deinterleave:",
806+
gate_up_bias.shape)
807+
gate_up_bias = deinterleave(gate_up_bias)
808+
print("up projection bias shape after deinterleave:",
809+
gate_up_bias.shape)
807 810

808 811
# Only fp32 bias is supported for NVFP4 MoE.
809 812
if gate_up_bias.dtype != torch.float32:
@@ -832,6 +835,13 @@ def load_nvfp4_weights(self, weights: Dict):
832 835
# Per-expert block scales (transpose to expected layout)
833 836
if 'gate_up_proj_weight_scale' in module_weights:
834 837
gu_ws = module_weights['gate_up_proj_weight_scale']
838+
print(
839+
"up projection weight scale shape before deinterleave:",
840+
gu_ws.shape)
841+
gu_ws = deinterleave(gu_ws)
842+
print(
843+
"up projection weight scale shape after deinterleave:",
844+
gu_ws.shape)
835 845
moe_weights['gate_up_proj_weight_scale'] = [
836 846
gu_ws[i, :, :] for i in range(num_expert)
837 847
]

0 commit comments

Comments
 (0)