@@ -2039,23 +2039,6 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
         return super().load_expert_w2_weight_scale_nvfp4(
             module, w2_weight_scale, dst_w2_weight_scale, 32)
 
-    def load_all_fp4_weight_scales_and_alphas(
-            self, module: torch.nn.Module, weights: Dict,
-            load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor,
-            dst_w2_weight_scale: torch.Tensor, dst_fc31_alpha: torch.Tensor,
-            dst_fc2_alpha: torch.Tensor):
-        super().load_all_fp4_weight_scales_and_alphas(
-            module, weights, load_expert_ids, dst_w3_w1_weight_scale,
-            dst_w2_weight_scale, dst_fc31_alpha, dst_fc2_alpha)
-        # The kernel we use will convert nvfp4 to e4m3 before matmul,
-        # so the range of the scale factor can only be [0, 448/6].
-        dst_w3_w1_weight_scale.copy_((dst_w3_w1_weight_scale.to(torch.float32) /
-                                      6.0).to(torch.float8_e4m3fn))
-        dst_w2_weight_scale.copy_((dst_w2_weight_scale.to(torch.float32) /
-                                   6.0).to(torch.float8_e4m3fn))
-        dst_fc31_alpha.copy_(dst_fc31_alpha * 6.0)
-        dst_fc2_alpha.copy_(dst_fc2_alpha * 6.0)
-
 
 def _get_weight_alignment(weight_alignment, scaling_vector_size, tp_size,
                           shard_dim_size):
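
For context on the rescaling this hunk removes, here is a minimal standalone sketch (assumed checkpoint values, not the actual TRT-LLM kernel): dividing the per-block e4m3 weight scale by 6 while multiplying the global alpha by 6 leaves the dequantized product unchanged, but guarantees the stored scale stays within the kernel's allowed range of [0, 448/6], where 448 is the float8_e4m3fn maximum and 6 is the NVFP4 (e2m1) maximum magnitude.

```python
# Sketch of why the removed loader divided the scales by 6 and multiplied the
# alphas by 6: dequantization applies x * scale * alpha, so
#   x * (scale / 6) * (alpha * 6) == x * scale * alpha.
import torch

E2M1_MAX = 6.0    # largest magnitude representable in NVFP4 (e2m1)
E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3fn

# Hypothetical per-block scales and a global alpha as loaded from a
# checkpoint; checkpoint scales are e4m3, so they lie in [0, 448].
scale = torch.tensor([12.0, 30.0, 96.0, 192.0])
alpha = torch.tensor(0.01)

# Rescale the same way load_all_fp4_weight_scales_and_alphas did.
new_scale = (scale / E2M1_MAX).to(torch.float8_e4m3fn)
new_alpha = alpha * E2M1_MAX

# The effective product scale * alpha is preserved (up to e4m3 rounding of
# the rescaled scale; the values chosen here happen to round exactly).
assert torch.allclose(new_scale.to(torch.float32) * new_alpha, scale * alpha)

# The rescaled scale also respects the kernel's constraint: the kernel
# converts nvfp4 weights (|x| <= 6) to e4m3 before the matmul, so
# scale * |x| must stay <= 448, i.e. scale <= 448 / 6.
assert (new_scale.to(torch.float32) <= E4M3_MAX / E2M1_MAX).all()
```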