 # limitations under the License.
 import copy
 import re
+from collections import defaultdict
 from warnings import warn
 
 import torch
 from megatron.core.parallel_state import (
     get_expert_model_parallel_group,
     get_expert_tensor_parallel_group,
+    get_expert_tensor_parallel_rank,
     initialize_model_parallel,
     is_pipeline_first_stage,
     is_pipeline_last_stage,
@@ -190,7 +192,7 @@ def squared_relu(x):
         pipeline_model_parallel_size=pipeline_model_parallel_size,
         expert_model_parallel_size=expert_model_parallel_size,
         expert_tensor_parallel_size=expert_tensor_parallel_size,
-        sequence_parallel=expert_model_parallel_size > 1,
+        sequence_parallel=False,
         moe_grouped_gemm=moe_grouped_gemm,
         num_layers=num_layers,
         num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage,
@@ -565,8 +567,7 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True)
             # Check for both TEGrouped and sequential MoE patterns
             if "local_experts" in name or ("experts" in name and "linear_fc" in name):
                 # Convert to scalar only if tensor has a single element
-                amax_val = module.amax.detach().clone().cpu()
-                expert_amax_values[name] = amax_val
+                expert_amax_values[name] = module.amax.detach().clone().cpu()
 
     # Early return if no expert quantizers found
     assert expert_amax_values, "No expert quantizers found"
@@ -577,19 +578,16 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True)
     torch.distributed.all_gather_object(all_amax_values, expert_amax_values)
 
     # Group quantizers by type (ignoring specific expert indices) and check sync
-    expert_quantizers = {}
+    expert_quantizers = defaultdict(dict)
     for rank_idx, rank_amax in enumerate(all_amax_values):
         for name, amax_val in rank_amax.items():
             # Create quantizer type key by normalizing the name
-            if "local_experts" in name:
-                # sequential MoE: replace expert index with wildcard
-                quantizer_type = re.sub(r"local_experts\.\d+", "local_experts.*", name)
-            else:
-                # TEGrouped MoE: use the name as-is since experts are grouped
-                quantizer_type = name
-
-            if quantizer_type not in expert_quantizers:
-                expert_quantizers[quantizer_type] = {}
+            quantizer_type = (
+                re.sub(r"local_experts\.\d+", "local_experts.*", name)
+                if "local_experts" in name
+                else name
+            )
+
             if (
                 quantizer_type in expert_quantizers
                 and rank_idx in expert_quantizers[quantizer_type]
@@ -608,21 +606,52 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True)
                 )
             expert_quantizers[quantizer_type][rank_idx] = amax_val
 
-    # Check synchronization - fail fast on first inconsistency
+    rank_info = {
+        "global_rank": torch.distributed.get_rank(),
+        "etp_rank": get_expert_tensor_parallel_rank(),
+    }
+
+    all_rank_info = [None] * world_size
+    torch.distributed.all_gather_object(all_rank_info, rank_info)
+
+    # Group ranks by ETP rank for fc1 (ColumnParallel: same output channels should match)
+    etp_groups = defaultdict(list)
+    for info in all_rank_info:
+        etp_groups[info["etp_rank"] if info["etp_rank"] else 0].append(info["global_rank"])
+
     for quantizer_type, rank_values in expert_quantizers.items():
-        if len(rank_values) > 1:  # Only check if we have multiple ranks
-            values = list(rank_values.values())
-            # Handle both scalar and tensor comparisons
-            first_val = values[0]
-            if isinstance(first_val, torch.Tensor):
-                # For tensors, check if all values are close to the first one
-                for val in values[1:]:
-                    if not torch.allclose(first_val, val, rtol=1e-6, atol=1e-6):
-                        return False, quantizer_type, rank_values
-            else:
-                # For scalars, use numeric comparison
-                max_diff = max(values) - min(values)
-                if max_diff > 1e-6:  # Allow for small floating point differences
-                    return False, quantizer_type, rank_values
+        # Determine which ranks are expected to hold the same amax value
+        # for this quantizer.
+        #
+        # fc1: ColumnParallel: X @ [A_1, A_2] (weights split along Cout),
+        #      so amax should match across ranks with the same ETP rank.
+        #      With EP=2, ETP=2 we have 4 ranks: EP1/ETP1: 0, EP1/ETP2: 1, EP2/ETP1: 2, EP2/ETP2: 3,
+        #      so amax is compared within the ETP groups [0, 2] and [1, 3].
+        #
+        # fc2: RowParallel: [X_1, X_2] @ [A_1
+        #                                 A_2] (weights split along Cin),
+        #      so amax should match across all ranks.
+
+        rank_groups = (
+            list(etp_groups.values())
+            if "linear_fc1" in quantizer_type
+            else [list(range(world_size))]
+        )
+        # Check each group independently
+        for group in rank_groups:
+            group_values = [rank_values[r] for r in group if r in rank_values]
+            if len(group_values) > 1:
+                # All values in this group should be identical
+                first_val = group_values[0]
+                for val in group_values[1:]:
+                    if isinstance(first_val, torch.Tensor):
+                        if not torch.allclose(first_val, val, rtol=1e-6, atol=1e-6):
+                            group_rank_values = {
+                                r: rank_values[r] for r in group if r in rank_values
+                            }
+                            return False, f"{quantizer_type} (group {group})", group_rank_values
+                    elif abs(first_val - val) > 1e-6:
+                        group_rank_values = {r: rank_values[r] for r in group if r in rank_values}
+                        return False, f"{quantizer_type} (group {group})", group_rank_values
 
     return True, None, None
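
For reference, below is a minimal standalone sketch (not part of the change above) of the grouping the new check performs before comparing amax values: with EP=2 and ETP=2, linear_fc1 (ColumnParallel) quantizers are compared within ETP groups such as [0, 2] and [1, 3], while linear_fc2 (RowParallel) quantizers are compared across all ranks. The rank-to-ETP mapping used here (global_rank % expert_tensor_parallel_size) is an illustrative assumption, not Megatron's guaranteed rank ordering.

from collections import defaultdict

# Illustrative sketch of the expected amax sharing groups.
# Assumption: 4 ranks with EP=2, ETP=2, and a hypothetical layout where
# etp_rank = global_rank % expert_tensor_parallel_size.
world_size = 4
expert_tensor_parallel_size = 2

# Stand-in for the all_gather_object() result: each rank reports its ETP rank.
all_rank_info = [
    {"global_rank": r, "etp_rank": r % expert_tensor_parallel_size} for r in range(world_size)
]

# fc1 (ColumnParallel): only ranks holding the same output-channel shard,
# i.e. the same ETP rank, are expected to share an amax value.
etp_groups = defaultdict(list)
for info in all_rank_info:
    etp_groups[info["etp_rank"]].append(info["global_rank"])

fc1_groups = list(etp_groups.values())  # [[0, 2], [1, 3]]
fc2_groups = [list(range(world_size))]  # [[0, 1, 2, 3]] (fc2 is RowParallel: all ranks)

print(fc1_groups, fc2_groups)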