@@ -657,6 +657,72 @@ def load_hf_weights(self, weights: Dict):
657657 module_weights = filter_weights (name , weights )
658658
659659 if isinstance (module , MoE ):
660+ # Fast-path: NVFP4 HF ckpt for fused gate_up MoE
661+ if getattr (module , "quant_config" , None ) is not None and \
662+ module .quant_config .quant_mode .has_nvfp4 ():
663+ gate_up = module_weights .get ('gate_up_proj' , None )
664+ down = module_weights .get ('down_proj' , None )
665+ gate_up_bias = module_weights .get ('gate_up_proj_bias' , None )
666+ down_bias = module_weights .get ('down_proj_bias' , None )
667+
668+ # Optional deinterleave for checkpoints that interleave gate/up
669+ if gate_up is not None and gate_up .dim () == 3 :
670+ try :
671+ g , u = gate_up [:, :, ::2 ], gate_up [:, :, 1 ::2 ]
672+ gate_up = torch .cat ([g , u ], dim = - 1 )
673+ if gate_up_bias is not None :
674+ gb , ub = gate_up_bias [:, ::
675+ 2 ], gate_up_bias [:, 1 ::2 ]
676+ gate_up_bias = torch .cat ([gb , ub ], dim = - 1 )
677+ except Exception :
678+ pass
679+
680+ moe_weights = {}
681+ if gate_up is not None :
682+ moe_weights ['gate_up_proj' ] = [
683+ gate_up [i , :, :].transpose (0 , 1 )
684+ for i in range (num_expert )
685+ ]
686+ if down is not None :
687+ moe_weights ['down_proj' ] = [
688+ down [i , :, :].transpose (0 , 1 )
689+ for i in range (num_expert )
690+ ]
691+ if gate_up_bias is not None :
692+ moe_weights ['gate_up_proj.bias' ] = [
693+ gate_up_bias [i , :] for i in range (num_expert )
694+ ]
695+ if down_bias is not None :
696+ moe_weights ['down_proj.bias' ] = [
697+ down_bias [i , :] for i in range (num_expert )
698+ ]
699+
700+ # Per-expert block scales (transpose to expected layout)
701+ if 'gate_up_proj_weight_scale' in module_weights :
702+ gu_ws = module_weights ['gate_up_proj_weight_scale' ]
703+ moe_weights ['gate_up_proj_weight_scale' ] = [
704+ gu_ws [i , :, :].transpose (0 , 1 )
705+ for i in range (num_expert )
706+ ]
707+ if 'down_proj_weight_scale' in module_weights :
708+ dp_ws = module_weights ['down_proj_weight_scale' ]
709+ moe_weights ['down_proj_weight_scale' ] = [
710+ dp_ws [i , :, :].transpose (0 , 1 )
711+ for i in range (num_expert )
712+ ]
713+
714+ # Module-level globals for NVFP4 loaders
715+ for src_key in [
716+ 'gate_up_proj_weight_scale_2' ,
717+ 'down_proj_weight_scale_2' ,
718+ 'gate_up_proj_input_scale' ,
719+ 'down_proj_input_scale' ,
720+ ]:
721+ if src_key in module_weights :
722+ moe_weights [src_key ] = module_weights [src_key ]
723+
724+ module .load_weights (weights = [moe_weights ])
725+ continue
660726 try :
661727 # For BF16 ckpt.
662728 # Deinterleave for gate and up.
0 commit comments