Skip to content

Commit fa4cc3e

Browse files
committed
Revert "remove wrong interleave"
This reverts commit b143db2.
1 parent b143db2 commit fa4cc3e

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tensorrt_llm/_torch/models/modeling_gpt_oss.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -794,6 +794,17 @@ def load_nvfp4_weights(self, weights: Dict):
794 794
gate_up_bias = module_weights.get('gate_up_proj_bias', None)
795 795
down_bias = module_weights.get('down_proj_bias', None)
796 796

797+
# Optional deinterleave for checkpoints that interleave gate/up
798+
if gate_up is not None and gate_up.dim() == 3:
799+
try:
800+
g, u = gate_up[:, :, ::2], gate_up[:, :, 1::2]
801+
gate_up = torch.cat([g, u], dim=-1)
802+
if gate_up_bias is not None:
803+
gb, ub = gate_up_bias[:, ::2], gate_up_bias[:, 1::2]
804+
gate_up_bias = torch.cat([gb, ub], dim=-1)
805+
except Exception:
806+
pass
807+
797 808
# Only fp32 bias is supported for NVFP4 MoE.
798 809
if gate_up_bias.dtype != torch.float32:
799 810
gate_up_bias = gate_up_bias.to(torch.float32)

0 commit comments

Comments
 (0)