 
 def sync_amax_across_sequential_mlp(model: torch.nn.Module):
     """Sync amax across experts in a SequentialMLP."""
-    amax_dict = {
-        "linear_fc1.input_quantizer": {},
-        "linear_fc1.weight_quantizer": {},
-        "linear_fc2.input_quantizer": {},
-        "linear_fc2.weight_quantizer": {},
-    }
-    # gather amax values from SequentialMLP experts
-    for name, module in model.named_modules():
+    amax_dict = {}
+
+    def get_sequential_mlp_expert_names(name: str, module: torch.nn.Module):
         if (
-            not isinstance(module, TensorQuantizer)
-            or not hasattr(module, "_amax")
-            or "local_experts" not in name
+            isinstance(module, TensorQuantizer)
+            and hasattr(module, "_amax")
+            and ".local_experts." in name
         ):
-            continue
-        expert_name, local_expert_name = name.split("local_experts")
-        for key in amax_dict:
-            if key in local_expert_name:
-                amax_dict[key][expert_name] = max(amax_dict[key].get(expert_name, 0), module.amax)
+            expert_name, local_expert_name = name.split(".local_experts.")
+            # extract quantizer name by removing the local expert index from the name
+            local_expert_name = ".".join(local_expert_name.split(".")[1:])
+            return expert_name, local_expert_name
+        return None, None
+
+    # gather amax values from SequentialMLP experts
+    for name, module in model.named_modules():
+        expert_name, local_expert_name = get_sequential_mlp_expert_names(name, module)
+        if expert_name and local_expert_name:
+            amax_dict[local_expert_name] = amax_dict.get(local_expert_name, {})
+            amax_dict[local_expert_name][expert_name] = max(
+                amax_dict[local_expert_name].get(expert_name, 0), module.amax
+            )
 
     # sync amax values across experts in SequentialMLP
     for name, module in model.named_modules():
-        if (
-            not isinstance(module, TensorQuantizer)
-            or not hasattr(module, "_amax")
-            or "local_experts" not in name
-        ):
-            continue
-        expert_name, local_expert_name = name.split("local_experts")
-        for key in amax_dict:
-            if key in local_expert_name:
-                module.amax = amax_dict[key][expert_name]
+        expert_name, local_expert_name = get_sequential_mlp_expert_names(name, module)
+        if expert_name and local_expert_name:
+            module.amax = amax_dict[local_expert_name][expert_name]
 
 
 CUSTOM_POST_CALIBRATION_PLUGINS.add(sync_amax_across_sequential_mlp)
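
For readers skimming the diff: the new helper keys `amax_dict` by quantizer name and by the prefix before `.local_experts.`, so the max is taken across the local experts of each SequentialMLP. A minimal sketch of the name splitting (the module path below is illustrative of a typical Megatron MoE layout, not taken from this change):

```python
# Hypothetical quantizer module name for local expert 3 of one SequentialMLP.
name = "decoder.layers.0.mlp.experts.local_experts.3.linear_fc1.input_quantizer"

expert_name, local_expert_name = name.split(".local_experts.")
# expert_name       -> "decoder.layers.0.mlp.experts"   (identifies the SequentialMLP)
# local_expert_name -> "3.linear_fc1.input_quantizer"

# Dropping the expert index leaves the quantizer key shared by all local experts.
local_expert_name = ".".join(local_expert_name.split(".")[1:])
# local_expert_name -> "linear_fc1.input_quantizer"
```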
@@ -523,6 +520,11 @@ def forward(self, input, *args, **kwargs):
 # Register the public te.pytorch.GroupedLinear class
 @QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"})
 class _QuantMegatronTEGroupedLinear(_MegatronParallelLinear):
+    _functionals_to_replace = [
+        (te_grouped_linear._GroupedLinear, "forward"),
+        (te_grouped_linear._GroupedLinear, "apply"),
+    ]
+
     def _setup(self):
         # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
         # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
@@ -531,46 +533,17 @@ def _setup(self):
         # Memorize the original weight.dtype for modelopt_post_restore given that
         # the dtype can change later.
         super()._setup()
-        # Revert the weight to None after setup.
-        self.weight = None
-
-    @property
-    def functionals_to_replace(self):
-        original_forward = te_grouped_linear._GroupedLinear.forward
-
-        def te_grouped_quantized_linear_fn(ctx, inp, m_splits, *args):
-            num_gemms = len(m_splits)
-            weights_and_biases = args[-2 * num_gemms :]
-            weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
-            quantized_inputs = self.input_quantizer(inp)
-            quantized_weights = [self.weight_quantizer(weight) for weight in weights]
-
-            output = original_forward(
-                ctx,
-                quantized_inputs,
-                m_splits,
-                *args[: -2 * num_gemms],
-                *quantized_weights,
-                *biases,
-            )
-            return self.output_quantizer(output)
-
-        return [
-            (
-                te_grouped_linear._GroupedLinear,
-                "forward",
-                te_grouped_quantized_linear_fn,
-            ),
-        ]
+        # Remove self.weight after setup.
+        delattr(self, "weight")
 
     def modelopt_post_restore(self, prefix: str = ""):
         # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
         # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
         # self.weight0 to self.weight to run the quantizer states initialization.
         self.weight = self.weight0
         super().modelopt_post_restore(prefix=prefix)
-        # Revert the weight to None after post_restore.
-        self.weight = None
+        # Remove self.weight after post_restore.
+        delattr(self, "weight")
 
     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in
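
One observable effect of swapping `self.weight = None` for `delattr(self, "weight")`: assigning `None` leaves a placeholder parameter registered on the module, while `delattr` removes the attribute entirely. A minimal sketch on a bare `torch.nn.Module` (not the grouped linear class itself):

```python
import torch

m = torch.nn.Module()
m.weight = torch.nn.Parameter(torch.zeros(1))  # registered in m._parameters

m.weight = None              # nn.Module keeps a None entry for "weight"
print(hasattr(m, "weight"))  # True

delattr(m, "weight")         # the entry is removed altogether
print(hasattr(m, "weight"))  # False
```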
@@ -585,10 +558,34 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs)
 
     def _process_quantizer_amax(self, k, v, quantizer_state_dict):
-        if v.ndim == 4:
-            quantizer_state_dict[k] = v.squeeze(1).squeeze(-1)
-        else:
-            quantizer_state_dict[k] = v.view(-1, 1) if v.numel() > 1 else v.view(-1)
+        assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization"
+        quantizer_state_dict[k] = v.view(-1)
+
+    @staticmethod
+    def te_grouped_quantized_linear_fn(package, func_name, self, *args):
+        idx = 1 if func_name == "_forward" else 0
+        inp = args[idx]
+        num_gemms = len(args[idx + 1])
+        weights_and_biases = args[-2 * num_gemms :]
+        weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
+        quantized_inputs = self.input_quantizer(inp)
+        quantized_weights = [self.weight_quantizer(weight) for weight in weights]
+
+        output = getattr(package, func_name)(
+            *(
+                args[0],
+                quantized_inputs,
+            )
+            if func_name == "_forward"
+            else (quantized_inputs,),
+            *args[idx + 1 : -2 * num_gemms],
+            *quantized_weights,
+            *biases,
+        )
+        return self.output_quantizer(output)
+
+    # Override the quantized linear function
+    _quantized_linear_fn = te_grouped_quantized_linear_fn
 
 
 @QuantModuleRegistry.register(
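
The new `te_grouped_quantized_linear_fn` relies on a positional-argument layout assumed from TE's grouped linear: the autograd `ctx` comes first only for the raw `_forward`, the input and `m_splits` follow, and the last `2 * num_gemms` positionals are the per-GEMM weights followed by biases. A rough, self-contained sketch of the slicing with string placeholders (the layout is mirrored from the code above, not verified against a specific Transformer Engine release):

```python
# Placeholder args for func_name == "_forward" with num_gemms == 2:
#   (ctx, inp, m_splits, <other TE args...>, w0, w1, b0, b1)
args = ("ctx", "inp", ["m0", "m1"], "other", "w0", "w1", "b0", "b1")

idx = 1                                       # would be 0 for "apply" (no ctx argument)
inp = args[idx]                               # "inp"
num_gemms = len(args[idx + 1])                # 2
weights_and_biases = args[-2 * num_gemms :]   # ("w0", "w1", "b0", "b1")
weights = weights_and_biases[:num_gemms]      # ("w0", "w1")
biases = weights_and_biases[num_gemms:]       # ("b0", "b1")
passthrough = args[idx + 1 : -2 * num_gemms]  # (["m0", "m1"], "other")
print(weights, biases, passthrough)
```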
@@ -614,42 +611,36 @@ class _MegatronTEGroupedRowParallelLinear(
 class _MegatronTEGroupedMLP(_MegatronMLP):
     def _setup(self):
         if not hasattr(self, "parallel_state") or self.parallel_state is None:
-            data_parallel_group = None
-            try:
-                data_parallel_group = get_data_parallel_group(with_context_parallel=True)
-            except AssertionError:
-                logger.warning(
-                    "Context parallel group is not initialized, using data parallel group"
-                )
-                data_parallel_group = get_data_parallel_group()
-
-            try:
-                expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group()
-            except AssertionError:
-                expert_tensor_parallel_group = None
             self.parallel_state = ParallelState(
-                data_parallel_group,
-                tensor_parallel_group=expert_tensor_parallel_group,
-                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
+                mcore_parallel.get_expert_data_parallel_group(check_initialized=False),
+                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(
+                    check_initialized=False
+                ),
+                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(
+                    check_initialized=False
+                ),
             )
+        # initialize parallel state for submodules linear_fc1 and linear_fc2
+        self.linear_fc1.parallel_state = self.parallel_state
+        self.linear_fc2.parallel_state = self.parallel_state
 
 
 # Register the public megatron_moe.SequentialMLP class
 @QuantModuleRegistry.register({megatron_moe.SequentialMLP: "megatron_moe_SequentialMLP"})
 class _MegatronSequentialMLP(_MegatronMLP):
     def _setup(self):
         if not hasattr(self, "parallel_state") or self.parallel_state is None:
-            try:
-                data_parallel_group = mcore_parallel.get_expert_data_parallel_group()
-            except AssertionError:
-                data_parallel_group = None
-
-            try:
-                expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group()
-            except AssertionError:
-                expert_tensor_parallel_group = None
             self.parallel_state = ParallelState(
-                data_parallel_group,
-                tensor_parallel_group=expert_tensor_parallel_group,
-                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
+                mcore_parallel.get_expert_data_parallel_group(check_initialized=False),
+                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(
+                    check_initialized=False
+                ),
+                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(
+                    check_initialized=False
+                ),
             )
+
+        # Initialize parallel state for submodules local_experts.*.linear_fc1 and local_experts.*.linear_fc2
+        for expert in self.local_experts:
+            expert.linear_fc1.parallel_state = self.parallel_state
+            expert.linear_fc2.parallel_state = self.parallel_state
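
Both `_setup` rewrites above trade try/except probing for `check_initialized=False`. The assumption, suggested by the removed `AssertionError` fallbacks, is that Megatron then returns `None` for a group that has not been initialized. A before/after sketch (needs Megatron-Core and initialized process groups for the old path; shown for comparison only):

```python
from megatron.core import parallel_state as mcore_parallel

# Old pattern (removed above): probe the group and fall back on AssertionError.
try:
    expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group()
except AssertionError:
    expert_tensor_parallel_group = None

# New pattern: ask for the group without the initialization check.
expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group(
    check_initialized=False
)
```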