 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+import re
 from warnings import warn
 
 import torch
@@ -57,7 +58,6 @@
     save_sharded_modelopt_state,
 )
 from modelopt.torch.utils import to_empty_if_meta_device
-from modelopt.torch.utils.distributed import DistributedProcessGroup
 
 try:
     from megatron.core.extensions.transformer_engine import TENorm
@@ -143,7 +143,7 @@ def get_mcore_gpt_model(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
     expert_model_parallel_size: int = 1,
-    expert_tensor_parallel_size: int = 1,
+    expert_tensor_parallel_size: int | None = None,
     initialize_megatron: bool = False,
     *,
     num_layers: int = 2,
@@ -497,61 +497,6 @@ def convert_maybe_fp8(v):
     )
 
 
-def compare_model_outputs(grouped_model, non_grouped_model, forward_fn, tolerance=1e-6):
-    """Compare outputs of grouped and non-grouped models."""
-    # Set both models to eval mode
-    grouped_model.eval()
-    non_grouped_model.eval()
-
-    with torch.no_grad():
-        # Get outputs from both models
-        grouped_output = forward_fn(grouped_model)
-        non_grouped_output = forward_fn(non_grouped_model)
-
-        # Compare outputs
-        if isinstance(grouped_output, tuple):
-            grouped_output = grouped_output[0]
-        if isinstance(non_grouped_output, tuple):
-            non_grouped_output = non_grouped_output[0]
-
-        output_close = torch.allclose(
-            grouped_output, non_grouped_output, atol=tolerance, rtol=tolerance
-        )
-        return output_close
-
-
-def sync_amax(model):
-    amax_dict = {
-        "linear_fc1.input_quantizer": {},
-        "linear_fc1.weight_quantizer": {},
-        "linear_fc2.input_quantizer": {},
-        "linear_fc2.weight_quantizer": {},
-    }
-    for name, module in model.named_modules():
-        if not isinstance(module, mtq.nn.TensorQuantizer):
-            continue
-        if not hasattr(module, "_amax"):
-            continue
-        if "local_experts" not in name:
-            continue
-        expert_name, local_expert_name = name.split("local_experts")
-        for key in amax_dict:
-            if key in local_expert_name:
-                amax_dict[key][expert_name] = max(amax_dict[key].get(expert_name, 0), module.amax)
-
-    for name, module in model.named_modules():
-        if not isinstance(module, mtq.nn.TensorQuantizer):
-            continue
-        if not hasattr(module, "_amax"):
-            continue
-        if "local_experts" not in name:
-            continue
-        expert_name, local_expert_name = name.split("local_experts")
-        for key in amax_dict:
-            if key in local_expert_name:
-                module.amax = amax_dict[key][expert_name]
-
-
 def copy_weights_from_grouped_to_non_grouped(grouped_model, non_grouped_model):
     """Copy weights from grouped MoE model to non-grouped MoE model."""
     grouped_state = grouped_model.state_dict()
@@ -625,8 +570,6 @@ def compare_amax_sync_across_expert_parallel(model):
         # Create quantizer type key by normalizing the name
         if "local_experts" in name:
             # Non-grouped MoE: replace expert index with wildcard
-            import re
-
             quantizer_type = re.sub(r"local_experts\.\d+", "local_experts.*", name)
         else:
             # Grouped MoE: use the name as-is since experts are grouped
@@ -641,50 +584,7 @@ def compare_amax_sync_across_expert_parallel(model):
         if len(rank_values) > 1:  # Only check if we have multiple ranks
             values = list(rank_values.values())
             max_diff = max(values) - min(values)
-
             if max_diff > 1e-6:  # Allow for small floating point differences
-                return False
+                return False, quantizer_type, rank_values
 
-    return True
-
-
-def disable_distributed_parallel_sync(model, expert_parallel_type: str = "tensor"):
-    """Disable distributed parallel synchronization groups."""
-    module_parallel_groups = {}
-
-    for name, module in model.named_modules():
-        if isinstance(module, mtq.nn.QuantModule):
-            # Store original groups
-            module_parallel_groups[name] = {
-                "data_parallel_group": module.parallel_state.data_parallel_group,
-                "expert_tensor_parallel_group": module.parallel_state.expert_tensor_parallel_group,
-                "expert_model_parallel_group": module.parallel_state.expert_model_parallel_group,
-            }
-
-            # Disable groups
-            module.parallel_state.data_parallel_group = DistributedProcessGroup(-1)
-
-            if expert_parallel_type in ["tensor", "both"]:
-                module.parallel_state.expert_tensor_parallel_group = DistributedProcessGroup(-1)
-            if expert_parallel_type in ["model", "both"]:
-                module.parallel_state.expert_model_parallel_group = DistributedProcessGroup(-1)
-
-    return module_parallel_groups
-
-
-def enable_distributed_parallel_sync(
-    model, module_parallel_groups, expert_parallel_type: str = "tensor"
-):
-    """Re-enable distributed parallel synchronization groups."""
-    for name, module in model.named_modules():
-        if isinstance(module, mtq.nn.QuantModule) and name in module_parallel_groups:
-            groups = module_parallel_groups[name]
-
-            if expert_parallel_type in ["tensor", "both"]:
-                module.parallel_state.expert_tensor_parallel_group = groups[
-                    "expert_tensor_parallel_group"
-                ]
-            if expert_parallel_type in ["model", "both"]:
-                module.parallel_state.expert_model_parallel_group = groups[
-                    "expert_model_parallel_group"
-                ]
+    return True, None, None
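
With this change, compare_amax_sync_across_expert_parallel reports which quantizer is out of sync instead of returning a bare bool. A minimal usage sketch of the new return value; the wrapper name _assert_expert_amax_synced and its assertion message are assumptions, not part of this diff:

def _assert_expert_amax_synced(model):
    # Hypothetical test helper; only the (synced, quantizer_type, rank_values)
    # tuple shape comes from the diff above.
    synced, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model)
    assert synced, f"amax mismatch for {quantizer_type} across expert-parallel ranks: {rank_values}"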