@@ -515,20 +515,21 @@ def copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_mo
515515
516516 # Map grouped weights to sequential weights
517517 weight_mapping = {}
518- sequential_key_template = "decoder.layers.{}.mlp.experts.local_experts.{}.linear_fc{}.weight "
518+ sequential_key_template = "decoder.layers.{}.mlp.experts.local_experts.{}.linear_fc{}"
519519 for key , value in te_grouped_state .items ():
520- if "experts.linear_fc" in key and "weight" in key :
520+ if "experts.linear_fc" in key and any ( param in key for param in ( "weight" , "bias" )) :
521521 # Extract expert index from grouped weight name
522522 # Format: decoder.layers.X.mlp.experts.linear_fcY.weightZ
523523 parts = key .split ("." )
524524 layer_idx = parts [2 ] # X
525525 fc_idx = parts [5 ] # Y (linear_fc1 or linear_fc2)
526- weight_idx = parts [6 ] # Z ( weight0, weight1, etc.)
527-
528- # Map to sequential format: decoder.layers.X.mlp.experts.local_experts.Y.linear_fcZ.weight
529- expert_idx = weight_idx . replace ( "weight" , "" )
526+ param_idx = parts [6 ] # weight0 / bias0 / etc.
527+ match = re . search ( r"\d+" , param_idx )
528+ expert_idx = match . group ( 0 ) if match else "0" # Z for expert index
529+ # Map to sequential format: decoder.layers.X.mlp.experts.local_experts.Y.linear_fcZ
530530 sequential_key = sequential_key_template .format (layer_idx , expert_idx , fc_idx [- 1 ])
531- weight_mapping [sequential_key ] = value
531+ param_name = "weight" if "weight" in param_idx else "bias"
532+ weight_mapping [f"{ sequential_key } .{ param_name } " ] = value
532533 elif isinstance (value , torch .Tensor ):
533534 weight_mapping [key ] = value
534535
@@ -540,7 +541,7 @@ def copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_mo
540541 sequential_moe_model .load_state_dict (sequential_state )
541542
542543
543- def compare_amax_sync_across_expert_parallel (model ):
544+ def compare_amax_sync_across_expert_parallel (model , compare_across_experts = True ):
544545 """
545546 Test if amax values are synchronized across expert parallel groups.
546547
@@ -591,11 +592,12 @@ def compare_amax_sync_across_expert_parallel(model):
591592 quantizer_type in expert_quantizers
592593 and rank_idx in expert_quantizers [quantizer_type ]
593594 ):
594- # compare expert value across expert for sequential MoE
595- assert expert_quantizers [quantizer_type ][rank_idx ] == amax_val , (
596- f"{ rank_idx } , { quantizer_type } , expert_quantizers[quantizer_type][rank_idx]: "
597- f"{ expert_quantizers [quantizer_type ][rank_idx ]} , amax_val: { amax_val } "
598- )
595+ if compare_across_experts :
596+ # compare expert value across expert for sequential MoE
597+ assert expert_quantizers [quantizer_type ][rank_idx ] == amax_val , (
598+ f"{ rank_idx } , { quantizer_type } , expert_quantizers[quantizer_type][rank_idx]: "
599+ f"{ expert_quantizers [quantizer_type ][rank_idx ]} , amax_val: { amax_val } "
600+ )
599601 expert_quantizers [quantizer_type ][rank_idx ] = amax_val
600602
601603 # Check synchronization - fail fast on first inconsistency