@@ -794,16 +794,19 @@ def load_nvfp4_weights(self, weights: Dict):
         gate_up_bias = module_weights.get('gate_up_proj_bias', None)
         down_bias = module_weights.get('down_proj_bias', None)
 
-        # Optional deinterleave for checkpoints that interleave gate/up
-        if gate_up is not None and gate_up.dim() == 3:
-            try:
-                g, u = gate_up[:, :, ::2], gate_up[:, :, 1::2]
-                gate_up = torch.cat([g, u], dim=-1)
-                if gate_up_bias is not None:
-                    gb, ub = gate_up_bias[:, ::2], gate_up_bias[:, 1::2]
-                    gate_up_bias = torch.cat([gb, ub], dim=-1)
-            except Exception:
-                pass
+        def deinterleave(tensor):
+            g, u = tensor[..., ::2], tensor[..., 1::2]
+            return torch.cat([g, u], dim=-1)
+
+        print("up projection shape before deinterleave:", gate_up.shape)
+        gate_up = deinterleave(gate_up)
+        print("up projection shape after deinterleave:", gate_up.shape)
+
+        print("up projection bias shape before deinterleave:",
+              gate_up_bias.shape)
+        gate_up_bias = deinterleave(gate_up_bias)
+        print("up projection bias shape after deinterleave:",
+              gate_up_bias.shape)
 
         # Only fp32 bias is supported for NVFP4 MoE.
         if gate_up_bias.dtype != torch.float32:
@@ -832,6 +835,13 @@ def load_nvfp4_weights(self, weights: Dict):
         # Per-expert block scales (transpose to expected layout)
         if 'gate_up_proj_weight_scale' in module_weights:
             gu_ws = module_weights['gate_up_proj_weight_scale']
+            print(
+                "up projection weight scale shape before deinterleave:",
+                gu_ws.shape)
+            gu_ws = deinterleave(gu_ws)
+            print(
+                "up projection weight scale shape after deinterleave:",
+                gu_ws.shape)
             moe_weights['gate_up_proj_weight_scale'] = [
                 gu_ws[i, :, :] for i in range(num_expert)
             ]
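
For reference, a minimal standalone sketch of the even/odd deinterleave the new helper performs, assuming a checkpoint that stores gate/up interleaved along the last dimension (even columns = gate, odd = up); the toy tensor and its values below are illustrative, not taken from an actual checkpoint:

import torch

def deinterleave(tensor: torch.Tensor) -> torch.Tensor:
    # Split even/odd positions along the last dim, then concat as [gate | up].
    g, u = tensor[..., ::2], tensor[..., 1::2]
    return torch.cat([g, u], dim=-1)

# Toy example: 1 expert, 2 rows, 4 interleaved columns g0, u0, g1, u1.
t = torch.tensor([[[0., 10., 1., 11.],
                   [2., 12., 3., 13.]]])
print(deinterleave(t))
# tensor([[[ 0.,  1., 10., 11.],
#          [ 2.,  3., 12., 13.]]])

Because the helper slices with `[..., ::2]`, the same function works for the 3-D weights, the 2-D biases, and the block-scale tensors, which is why the patch applies it to all three.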