Commit 51be8cf

Weights loading part
Signed-off-by: Dongfeng Yu <[email protected]>
1 parent ea250c8 commit 51be8cf

1 file changed: +66 -0 lines changed


tensorrt_llm/_torch/models/modeling_gpt_oss.py

Lines changed: 66 additions & 0 deletions
@@ -657,6 +657,72 @@ def load_hf_weights(self, weights: Dict):
             module_weights = filter_weights(name, weights)

             if isinstance(module, MoE):
+                # Fast-path: NVFP4 HF ckpt for fused gate_up MoE
+                if getattr(module, "quant_config", None) is not None and \
+                        module.quant_config.quant_mode.has_nvfp4():
+                    gate_up = module_weights.get('gate_up_proj', None)
+                    down = module_weights.get('down_proj', None)
+                    gate_up_bias = module_weights.get('gate_up_proj_bias', None)
+                    down_bias = module_weights.get('down_proj_bias', None)
+
+                    # Optional deinterleave for checkpoints that interleave gate/up
+                    if gate_up is not None and gate_up.dim() == 3:
+                        try:
+                            g, u = gate_up[:, :, ::2], gate_up[:, :, 1::2]
+                            gate_up = torch.cat([g, u], dim=-1)
+                            if gate_up_bias is not None:
+                                gb, ub = gate_up_bias[:, ::
+                                                      2], gate_up_bias[:, 1::2]
+                                gate_up_bias = torch.cat([gb, ub], dim=-1)
+                        except Exception:
+                            pass
+
+                    moe_weights = {}
+                    if gate_up is not None:
+                        moe_weights['gate_up_proj'] = [
+                            gate_up[i, :, :].transpose(0, 1)
+                            for i in range(num_expert)
+                        ]
+                    if down is not None:
+                        moe_weights['down_proj'] = [
+                            down[i, :, :].transpose(0, 1)
+                            for i in range(num_expert)
+                        ]
+                    if gate_up_bias is not None:
+                        moe_weights['gate_up_proj.bias'] = [
+                            gate_up_bias[i, :] for i in range(num_expert)
+                        ]
+                    if down_bias is not None:
+                        moe_weights['down_proj.bias'] = [
+                            down_bias[i, :] for i in range(num_expert)
+                        ]
+
+                    # Per-expert block scales (transpose to expected layout)
+                    if 'gate_up_proj_weight_scale' in module_weights:
+                        gu_ws = module_weights['gate_up_proj_weight_scale']
+                        moe_weights['gate_up_proj_weight_scale'] = [
+                            gu_ws[i, :, :].transpose(0, 1)
+                            for i in range(num_expert)
+                        ]
+                    if 'down_proj_weight_scale' in module_weights:
+                        dp_ws = module_weights['down_proj_weight_scale']
+                        moe_weights['down_proj_weight_scale'] = [
+                            dp_ws[i, :, :].transpose(0, 1)
+                            for i in range(num_expert)
+                        ]
+
+                    # Module-level globals for NVFP4 loaders
+                    for src_key in [
+                            'gate_up_proj_weight_scale_2',
+                            'down_proj_weight_scale_2',
+                            'gate_up_proj_input_scale',
+                            'down_proj_input_scale',
+                    ]:
+                        if src_key in module_weights:
+                            moe_weights[src_key] = module_weights[src_key]
+
+                    module.load_weights(weights=[moe_weights])
+                    continue
                 try:
                     # For BF16 ckpt.
                     # Deinterleave for gate and up.
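
The deinterleave branch above targets checkpoints that store gate and up columns interleaved along the last dimension. A minimal standalone sketch of that slicing, with hypothetical toy sizes (not taken from any real checkpoint):

import torch

# Toy fused tensor: [num_experts, rows, 2 * intermediate], with gate/up
# columns interleaved as [g0, u0, g1, u1, ...] along the last dim.
num_experts, rows, two_intermediate = 2, 4, 8
gate_up = torch.arange(num_experts * rows * two_intermediate,
                       dtype=torch.float32).reshape(num_experts, rows,
                                                    two_intermediate)

g, u = gate_up[:, :, ::2], gate_up[:, :, 1::2]  # even cols = gate, odd = up
deinterleaved = torch.cat([g, u], dim=-1)       # contiguous [gate | up]

assert deinterleaved.shape == gate_up.shape
print(deinterleaved[0, 0])  # tensor([0., 2., 4., 6., 1., 3., 5., 7.])

The bias gets the same even/odd split along its last dimension, and the broad except/pass means a checkpoint that does not match this layout simply keeps its original tensors rather than failing the load.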

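The per-expert lists built above all follow one pattern: slice the fused tensor on the expert dimension, then transpose(0, 1). A rough shape check, assuming a fused [num_experts, in_features, out_features] layout (inferred from the transpose in the diff, not confirmed by it):

import torch

# Hypothetical fused tensor in the layout the loop above implies.
num_experts, in_features, out_features = 2, 4, 8
down = torch.randn(num_experts, in_features, out_features)

# One [out_features, in_features] weight per expert, as the diff builds them.
per_expert = [down[i, :, :].transpose(0, 1) for i in range(num_experts)]

assert len(per_expert) == num_experts
assert per_expert[0].shape == (out_features, in_features)

The gate_up_proj_weight_scale / down_proj_weight_scale blocks reuse the same slice-and-transpose pattern, while the *_weight_scale_2 and *_input_scale globals are copied through unchanged before module.load_weights(weights=[moe_weights]) is called.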