 import megatron.core.transformer.mlp as megatron_mlp
 import megatron.core.transformer.moe.experts as megatron_moe
 import torch
-import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear
-from megatron.core.extensions import transformer_engine as megatron_te
 from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
 from megatron.core.transformer import MegatronModule
 from ..nn import QuantModuleRegistry, TensorQuantizer
 from ..nn.modules.quant_linear import RealQuantLinear
 from ..qtensor import QTensorWrapper
-from .custom import CUSTOM_MODEL_PLUGINS, CUSTOM_POST_CALIBRATION_PLUGINS, _ParallelLinear
+from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
 
-logger = logging.getLogger(__name__)
-
-__all__ = []
-
-
-def sync_amax_across_sequential_mlp(model: torch.nn.Module):
-    """Sync amax across experts in a SequentialMLP."""
-    amax_dict = {}
-
-    def get_sequential_mlp_expert_names(name: str, module: torch.nn.Module):
-        if (
-            isinstance(module, TensorQuantizer)
-            and hasattr(module, "_amax")
-            and ".local_experts." in name
-        ):
-            expert_name, local_expert_name = name.split(".local_experts.")
-            # extract quantizer name by removing local_expert number from the name
-            local_expert_name = ".".join(local_expert_name.split(".")[1:])
-            return f"{expert_name}.{local_expert_name}"
-        return None
+try:
+    from megatron.core.extensions.transformer_engine import (
+        TEColumnParallelGroupedLinear,
+        TERowParallelGroupedLinear,
+    )
 
-    # gather amax values from SequentialMLP experts
-    for name, module in model.named_modules():
-        expert_name = get_sequential_mlp_expert_names(name, module)
-        if expert_name and module.amax is not None:
-            stored_amax = amax_dict.get(expert_name)
-            amax_tensor = module.amax.detach().clone()
-            amax_dict[expert_name] = (
-                amax_tensor if stored_amax is None else torch.maximum(stored_amax, amax_tensor)
-            )
+    from .transformer_engine import _QuantTEGroupedLinear
 
-    # sync amax values across experts in SequentialMLP
-    for name, module in model.named_modules():
-        expert_name = get_sequential_mlp_expert_names(name, module)
-        if expert_name and module.amax is not None:
-            module.amax = amax_dict[expert_name].detach().clone().to(module.amax.device)
+    HAS_TE = True
+except ImportError:
+    HAS_TE = False
 
+logger = logging.getLogger(__name__)
 
-CUSTOM_POST_CALIBRATION_PLUGINS.add(sync_amax_across_sequential_mlp)
+__all__ = []
 
 
 def real_quant_module_get_extra_state(self) -> dict:
@@ -516,111 +490,6 @@ def forward(self, input, *args, **kwargs):
         return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)
 
 
-# Register the public te.pytorch.GroupedLinear class
-@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"})
-class _QuantMegatronTEGroupedLinear(_MegatronParallelLinear):
-    _functionals_to_replace = [
-        (te_grouped_linear._GroupedLinear, "forward"),
-        (te_grouped_linear._GroupedLinear, "apply"),
-    ]
-
-    def _setup(self):
-        # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
-        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
-        # self.weight0 to self.weight to run the quantizer states initialization.
-        self.weight = self.weight0
-        # Memorize the original weight.dtype for modelopt_post_restore given that
-        # the dtype can change later.
-        super()._setup()
-        # Remove self.weight after setup.
-        delattr(self, "weight")
-
-    def modelopt_post_restore(self, prefix: str = ""):
-        # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
-        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
-        # self.weight0 to self.weight to run the quantizer states initialization.
-        self.weight = self.weight0
-        super().modelopt_post_restore(prefix=prefix)
-        # Remove self.weight after post_restore.
-        delattr(self, "weight")
-
-    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
-        # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in
-        # sharded_state_dict which is same as _extra_state. The _extra_state{gemm_idx} is used for
-        # TE Fp8 checkpoint, we need to remove the _extra_state{gemm_idx} for gemm_idx:[1, num_gemms]
-        # for modelopt checkpoint restore
-        filtered_state_dict = {
-            k: v
-            for k, v in state_dict.items()
-            if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms))
-        }
-        return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs)
-
-    def _process_quantizer_amax(self, k, v, quantizer_state_dict):
-        assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization"
-        quantizer_state_dict[k] = v.view(-1)
-
-    @staticmethod
-    def te_grouped_quantized_linear_fn(package, func_name, self, *args):
-        idx = 1 if func_name == "_forward" else 0
-        inp = args[idx]
-        num_gemms = len(args[idx + 1])
-        weights_and_biases = args[-2 * num_gemms :]
-        weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
-        quantized_inputs = self.input_quantizer(inp)
-        quantized_weights = [self.weight_quantizer(weight) for weight in weights]
-
-        output = getattr(package, func_name)(
-            *(
-                args[0],
-                quantized_inputs,
-            )
-            if func_name == "_forward"
-            else (quantized_inputs,),
-            *args[idx + 1 : -2 * num_gemms],
-            *quantized_weights,
-            *biases,
-        )
-        return self.output_quantizer(output)
-
-    # Override the quantized linear function
-    _quantized_linear_fn = te_grouped_quantized_linear_fn
-
-
-@QuantModuleRegistry.register(
-    {megatron_te.TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"}
-)
-class _MegatronTEGroupedColumnParallelLinear(
-    _QuantMegatronTEGroupedLinear, _MegatronColumnParallelLinear
-):
-    _is_column_parallel = True
-
-
-@QuantModuleRegistry.register(
-    {megatron_te.TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"}
-)
-class _MegatronTEGroupedRowParallelLinear(
-    _QuantMegatronTEGroupedLinear, _MegatronRowParallelLinear
-):
-    _is_row_parallel = True
-
-
-# Register the public megatron_moe.TEGroupedMLP class
-@QuantModuleRegistry.register({megatron_moe.TEGroupedMLP: "megatron_moe_TEGroupedMLP"})
-class _MegatronTEGroupedMLP(_MegatronMLP):
-    def _setup(self):
-        if not hasattr(self, "parallel_state") or self.parallel_state is None:
-            self.parallel_state = ParallelState(
-                mcore_parallel.get_expert_data_parallel_group(),
-                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
-                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
-            )
-        # initialize parallel state for submodules linear_fc1 and linear_fc2
-        self.linear_fc1.parallel_state = self.parallel_state
-        self.linear_fc2.parallel_state = self.parallel_state
-
-
-# Register the public megatron_moe.SequentialMLP class
 @QuantModuleRegistry.register({megatron_moe.SequentialMLP: "megatron_moe_SequentialMLP"})
 class _MegatronSequentialMLP(_MegatronMLP):
     def _setup(self):
@@ -635,3 +504,73 @@ def _setup(self):
         for expert in self.local_experts:
             expert.linear_fc1.parallel_state = self.parallel_state
             expert.linear_fc2.parallel_state = self.parallel_state
+
+    def sync_moe_local_experts_amax(self):
+        """Sync amax across experts in a SequentialMLP."""
+        amax_dict = {}
+        # gather amax values from SequentialMLP experts
+        for expert in self.local_experts:
+            for name, module in expert.named_modules():
+                if isinstance(module, TensorQuantizer) and module.amax is not None:
+                    stored_amax = amax_dict.get(name)
+                    amax_tensor = module.amax.detach().clone()
+                    amax_dict[name] = (
+                        amax_tensor
+                        if stored_amax is None
+                        else torch.maximum(stored_amax, amax_tensor)
+                    )
+
+        # sync amax values across experts in SequentialMLP
+        for expert in self.local_experts:
+            for name, module in expert.named_modules():
+                if isinstance(module, TensorQuantizer) and module.amax is not None:
+                    module.amax = amax_dict[name].detach().clone().to(module.amax.device)
+
+
+if HAS_TE:
+    # Quantized subclasses to support TEGroupedMLP quantization
+    class _QuantMegatronTEGroupedLinear(_QuantTEGroupedLinear, _MegatronParallelLinear):
+        def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+            # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx in [1, num_gemms)
+            # to the sharded_state_dict, duplicating _extra_state. These extra entries are only
+            # needed for TE FP8 checkpoints, so drop them when restoring a modelopt checkpoint.
+            filtered_state_dict = {
+                k: v
+                for k, v in state_dict.items()
+                if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms))
+            }
+            return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs)
+
+        def _process_quantizer_amax(self, k, v, quantizer_state_dict):
+            assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization"
+            quantizer_state_dict[k] = v.view(-1)
+
+    @QuantModuleRegistry.register(
+        {TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"}
+    )
+    class _MegatronTEGroupedColumnParallelLinear(
+        _QuantMegatronTEGroupedLinear, _MegatronColumnParallelLinear
+    ):
+        pass
+
+    @QuantModuleRegistry.register(
+        {TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"}
+    )
+    class _MegatronTEGroupedRowParallelLinear(
+        _QuantMegatronTEGroupedLinear, _MegatronRowParallelLinear
+    ):
+        pass
+
+    @QuantModuleRegistry.register({megatron_moe.TEGroupedMLP: "megatron_moe_TEGroupedMLP"})
+    class _MegatronTEGroupedMLP(_MegatronMLP):
+        def _setup(self):
+            if not hasattr(self, "parallel_state") or self.parallel_state is None:
+                self.parallel_state = ParallelState(
+                    mcore_parallel.get_expert_data_parallel_group(),
+                    tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
+                    expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
+                )
+            # initialize parallel state for submodules linear_fc1 and linear_fc2
+            self.linear_fc1.parallel_state = self.parallel_state
+            self.linear_fc2.parallel_state = self.parallel_state
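
With this change, expert amax synchronization moves from the module-level sync_amax_across_sequential_mlp post-calibration plugin (previously registered via CUSTOM_POST_CALIBRATION_PLUGINS) to a sync_moe_local_experts_amax method on the quantized SequentialMLP itself. The call site for the new method is not part of this diff; the sketch below is only an illustration of how such a hook could be driven after calibration, and the helper name sync_all_sequential_mlp_amax is hypothetical.

# Illustration only (not part of the change): walk a calibrated model and invoke the
# new per-module hook wherever it exists.
import torch


def sync_all_sequential_mlp_amax(model: torch.nn.Module) -> None:
    """Run amax syncing on every quantized SequentialMLP found in the model."""
    for module in model.modules():
        # Duck-typed check; in practice these are _MegatronSequentialMLP instances.
        if hasattr(module, "sync_moe_local_experts_amax"):
            module.sync_moe_local_experts_amax()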