@@ -672,13 +672,17 @@ def __init__(
         # Disable fusion for small models due to accuracy issues
         self.enable_fusion &= config.hidden_size > 4096
 
-        use_fused_gemm_allreduce = True
-        use_fused_gemm_allreduce &= (not mpi_disabled())
-        use_fused_gemm_allreduce &= (self.mapping.tp_size > 1)
-        use_fused_gemm_allreduce &= (config.torch_dtype
-                                     in (torch.float16, torch.bfloat16))
-        use_fused_gemm_allreduce &= (self.is_nvfp4 is not None
-                                     and self.is_nvfp4)
+        mpi_enabled = not mpi_disabled()
+        dtype_supported = config.torch_dtype in (torch.float16, torch.bfloat16)
+        tp_valid = self.mapping.tp_size > 1
+        quant_valid = self.is_nvfp4 is not None and self.is_nvfp4
+        use_fused_gemm_allreduce = all(
+            [mpi_enabled, dtype_supported, tp_valid, quant_valid])
+
+        def check_in_out_features(in_features, out_features):
+            in_feature_valid = in_features % 128 == 0 and in_features >= 1024
+            out_feature_valid = out_features % 64 == 0 and out_features >= 1024
+            return all([in_feature_valid, out_feature_valid])
 
         num_heads = config.num_attention_heads
         head_dim = getattr(config, 'head_dim', None)
@@ -687,21 +691,22 @@ def __init__(
 
         in_features = num_heads * head_dim
         out_features = config.hidden_size
-        in_features_div_by = 128
-        attn_fused_gemm_allreduce = use_fused_gemm_allreduce and in_features % in_features_div_by == 0 and in_features >= 1024
-        attn_fused_gemm_allreduce &= (out_features % 64 == 0
-                                      and out_features >= 1024)
+        in_out_features_valid = check_in_out_features(in_features, out_features)
 
+        attn_fused_gemm_allreduce = all(
+            [use_fused_gemm_allreduce, in_out_features_valid])
         self.PRE_MLP_FUSION = not attn_fused_gemm_allreduce and self.mapping.has_tp(
         ) and not self.enable_attention_dp and self.enable_fusion
 
         in_features = config.intermediate_size
         out_features = config.hidden_size
-        in_features_div_by = 128 * self.mapping.tp_size
-        mlp_fused_gemm_allreduce = use_fused_gemm_allreduce and in_features % in_features_div_by == 0 and in_features >= 1024
-        mlp_fused_gemm_allreduce &= (out_features % 64 == 0
-                                     and out_features >= 1024)
-
+        in_features_aligned_with_tp = in_features % self.mapping.tp_size == 0
+        in_out_features_valid = check_in_out_features(
+            in_features // self.mapping.tp_size, out_features)
+        mlp_fused_gemm_allreduce = all([
+            use_fused_gemm_allreduce, in_features_aligned_with_tp,
+            in_out_features_valid
+        ])
         self.POST_MLP_FUSION = not mlp_fused_gemm_allreduce and self.mapping.has_tp(
         ) and self.enable_fusion
 
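For context, a minimal standalone sketch of the gating checks introduced above. The `tp_size` value and the example GEMM shapes are hypothetical values chosen only to exercise the divisibility and minimum-size conditions; they are not taken from the patched model code.

```python
# Standalone sketch of the refactored fusion-gating checks from the diff above.
# tp_size and the example GEMM shapes are hypothetical illustration values.


def check_in_out_features(in_features: int, out_features: int) -> bool:
    """Fused GEMM+all-reduce requires aligned, sufficiently large GEMM dims."""
    in_feature_valid = in_features % 128 == 0 and in_features >= 1024
    out_feature_valid = out_features % 64 == 0 and out_features >= 1024
    return all([in_feature_valid, out_feature_valid])


tp_size = 4  # hypothetical tensor-parallel degree

# Attention projection: validate the full (num_heads * head_dim, hidden_size) GEMM.
print(check_in_out_features(4096, 4096))  # True

# MLP down-projection: intermediate_size is sharded across TP ranks, so the diff
# first checks divisibility by tp_size and then validates the per-rank shard.
in_features, out_features = 14336, 4096
mlp_ok = (in_features % tp_size == 0
          and check_in_out_features(in_features // tp_size, out_features))
print(mlp_ok)  # True: 14336 // 4 = 3584, a multiple of 128 and >= 1024
```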