Added support for quantizing TEGroupedMLP for megatron-lm #403
Changes from 23 commits
Contributor review comment: this change looks good!
```diff
@@ -22,7 +22,10 @@
 import megatron.core.parallel_state as mcore_parallel
 import megatron.core.tensor_parallel.layers as megatron_parallel
 import megatron.core.transformer.mlp as megatron_mlp
+import megatron.core.transformer.moe.experts as megatron_moe
 import torch
+import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear
+from megatron.core.extensions import transformer_engine as megatron_te
 from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
 from megatron.core.transformer import MegatronModule
```
```diff
@@ -38,13 +41,49 @@
 from ..nn import QuantModuleRegistry, TensorQuantizer
 from ..nn.modules.quant_linear import RealQuantLinear
 from ..qtensor import QTensorWrapper
-from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
+from .custom import CUSTOM_MODEL_PLUGINS, CUSTOM_POST_CALIBRATION_PLUGINS, _ParallelLinear

 logger = logging.getLogger(__name__)

 __all__ = []


+def sync_amax_across_sequential_mlp(model: torch.nn.Module):
+    """Sync amax across experts in a SequentialMLP."""
+    amax_dict = {}
+
+    def get_sequential_mlp_expert_names(name: str, module: torch.nn.Module):
+        if (
+            isinstance(module, TensorQuantizer)
+            and hasattr(module, "_amax")
+            and ".local_experts." in name
+        ):
+            expert_name, local_expert_name = name.split(".local_experts.")
+            # extract the quantizer name by removing the local expert index from the name
+            local_expert_name = ".".join(local_expert_name.split(".")[1:])
+            return f"{expert_name}.{local_expert_name}"
+        return None
+
+    # gather amax values from SequentialMLP experts
+    for name, module in model.named_modules():
+        expert_name = get_sequential_mlp_expert_names(name, module)
+        if expert_name and module.amax is not None:
+            stored_amax = amax_dict.get(expert_name)
+            amax_tensor = module.amax.detach().clone()
+            amax_dict[expert_name] = (
+                amax_tensor if stored_amax is None else torch.maximum(stored_amax, amax_tensor)
+            )
+
+    # sync amax values across experts in SequentialMLP
+    for name, module in model.named_modules():
+        expert_name = get_sequential_mlp_expert_names(name, module)
+        if expert_name and module.amax is not None:
+            module.amax = amax_dict[expert_name].detach().clone().to(module.amax.device)
+
+
+CUSTOM_POST_CALIBRATION_PLUGINS.add(sync_amax_across_sequential_mlp)


 def real_quant_module_get_extra_state(self) -> dict:
     """Populating real_quantizer_state and q_tensor_state."""
     extra_state = {}
```
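For context on the name normalization this plugin relies on: each expert in a `SequentialMLP` appears under `...local_experts.<idx>...`, and stripping the index yields a shared key so all experts reduce into one amax entry. A minimal runnable sketch (the module names are hypothetical; the split/join mirrors `get_sequential_mlp_expert_names` above):

```python
# Hypothetical quantizer names for a SequentialMLP with two local experts.
names = [
    "decoder.layers.0.mlp.experts.local_experts.0.linear_fc1.weight_quantizer",
    "decoder.layers.0.mlp.experts.local_experts.1.linear_fc1.weight_quantizer",
]

for name in names:
    expert_name, local_expert_name = name.split(".local_experts.")
    # Drop the expert index so every expert maps to the same dictionary key.
    shared_key = f"{expert_name}.{'.'.join(local_expert_name.split('.')[1:])}"
    print(shared_key)
# Both names print the same key:
# decoder.layers.0.mlp.experts.linear_fc1.weight_quantizer
```

Taking `torch.maximum` over the per-expert amax values collected under that shared key, then writing the result back, gives every expert the same (most conservative) quantization scale.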
```diff
@@ -221,16 +260,19 @@ class _MegatronParallelLinear(_ParallelLinear):
     ]

     def _setup(self):
-        data_parallel_group = None
-        try:
-            data_parallel_group = get_data_parallel_group(with_context_parallel=True)
-        except AssertionError:
-            logger.warning("Context parallel group is not initialized, using data parallel group")
-            data_parallel_group = get_data_parallel_group()
-        self.parallel_state = ParallelState(
-            data_parallel_group,
-            mcore_parallel.get_tensor_model_parallel_group(),
-        )
+        if not hasattr(self, "parallel_state") or self.parallel_state is None:
+            data_parallel_group = None
+            try:
+                data_parallel_group = get_data_parallel_group(with_context_parallel=True)
+            except AssertionError:
+                logger.warning(
+                    "Context parallel group is not initialized, using data parallel group"
+                )
+                data_parallel_group = get_data_parallel_group()
+            self.parallel_state = ParallelState(
+                data_parallel_group,
+                mcore_parallel.get_tensor_model_parallel_group(),
+            )
         super()._setup()

     def _process_quantizer_amax(self, k, v, quantizer_state_dict):
```
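The `_setup` change wraps the original body in a guard so that a `parallel_state` pre-assigned by a parent module (as the grouped-MLP classes later in this diff do for their `linear_fc*` submodules) is not clobbered. A toy sketch of the pattern (class and values hypothetical):

```python
class ToyLinear:
    parallel_state = None

    def _setup(self):
        # Derive a default only when no parent has pre-assigned one.
        if not hasattr(self, "parallel_state") or self.parallel_state is None:
            self.parallel_state = "default-data-parallel-group"

pre_assigned = ToyLinear()
pre_assigned.parallel_state = "expert-parallel-group"  # set by a parent MLP
pre_assigned._setup()
assert pre_assigned.parallel_state == "expert-parallel-group"  # preserved

fresh = ToyLinear()
fresh._setup()
assert fresh.parallel_state == "default-data-parallel-group"  # derived
```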
```diff
@@ -472,3 +514,124 @@ class _RealQuantMegatronRowParallelLinear(

     def forward(self, input, *args, **kwargs):
         return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)
+
+
+# Register the public te.pytorch.GroupedLinear class
+@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"})
+class _QuantMegatronTEGroupedLinear(_MegatronParallelLinear):
+    _functionals_to_replace = [
+        (te_grouped_linear._GroupedLinear, "forward"),
+        (te_grouped_linear._GroupedLinear, "apply"),
+    ]
+
+    def _setup(self):
+        # GroupedMLP stores the weights as weight0, weight1, etc., and _setup()
+        # reads self.weight to extract shape, dtype, etc. when initializing the
+        # quantizer states, so temporarily alias self.weight0 as self.weight.
+        self.weight = self.weight0
+        # Memorize the original weight.dtype for modelopt_post_restore, given
+        # that the dtype can change later.
+        super()._setup()
+        # Remove self.weight after setup.
+        delattr(self, "weight")
```
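Background for the `weight0` aliasing: TE's grouped linear registers one parameter per GEMM (`weight0`, `weight1`, ...) and no single `weight` attribute, while the shared quantizer-initialization path reads `self.weight` for shape and dtype. A hedged stand-in showing the parameter layout (toy class, hypothetical sizes):

```python
import torch

# Toy stand-in for TE GroupedLinear's parameter registration: one
# parameter per GEMM named weight0, weight1, ..., and no "weight".
class ToyGroupedLinear(torch.nn.Module):
    def __init__(self, num_gemms: int, in_features: int, out_features: int):
        super().__init__()
        self.num_gemms = num_gemms
        for i in range(num_gemms):
            self.register_parameter(
                f"weight{i}",
                torch.nn.Parameter(torch.empty(out_features, in_features)),
            )

m = ToyGroupedLinear(num_gemms=4, in_features=8, out_features=16)
assert not hasattr(m, "weight")    # the shared setup path would fail here...
assert m.weight0.shape == (16, 8)  # ...so weight0 is aliased as weight
```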
```diff
+    def modelopt_post_restore(self, prefix: str = ""):
+        # GroupedMLP stores the weights as weight0, weight1, etc., and
+        # post_restore reads self.weight to extract shape, dtype, etc. when
+        # initializing the quantizer states, so temporarily alias
+        # self.weight0 as self.weight.
+        self.weight = self.weight0
+        super().modelopt_post_restore(prefix=prefix)
+        # Remove self.weight after post_restore.
+        delattr(self, "weight")
+
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for
+        # gemm_idx in [1, num_gemms) to sharded_state_dict, duplicating
+        # _extra_state. Those keys are only used for TE FP8 checkpoints, so
+        # drop them when restoring a modelopt checkpoint.
+        filtered_state_dict = {
+            k: v
+            for k, v in state_dict.items()
+            if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms))
+        }
+        return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs)
```
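A small self-contained illustration of the key filtering (dictionary contents hypothetical): TE emits `_extra_state1 ... _extra_state{num_gemms-1}` alongside the unnumbered `_extra_state`, and only the unnumbered key matters for a modelopt restore.

```python
num_gemms = 3  # hypothetical
state_dict = {
    "linear_fc1._extra_state": b"quantizer-state",  # kept
    "linear_fc1._extra_state1": b"te-fp8-only",     # dropped
    "linear_fc1._extra_state2": b"te-fp8-only",     # dropped
    "linear_fc1.weight0": "w0",                     # kept
}
filtered = {
    k: v
    for k, v in state_dict.items()
    if not any(k.endswith(f"_extra_state{num}") for num in range(1, num_gemms))
}
assert sorted(filtered) == ["linear_fc1._extra_state", "linear_fc1.weight0"]
```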
```diff
+    def _process_quantizer_amax(self, k, v, quantizer_state_dict):
+        assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization"
+        quantizer_state_dict[k] = v.view(-1)
+
+    @staticmethod
+    def te_grouped_quantized_linear_fn(package, func_name, self, *args):
+        idx = 1 if func_name == "_forward" else 0
+        inp = args[idx]
+        num_gemms = len(args[idx + 1])
+        weights_and_biases = args[-2 * num_gemms :]
+        weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
+        quantized_inputs = self.input_quantizer(inp)
+        quantized_weights = [self.weight_quantizer(weight) for weight in weights]
+
+        output = getattr(package, func_name)(
+            *(
+                args[0],
+                quantized_inputs,
+            )
+            if func_name == "_forward"
+            else (quantized_inputs,),
+            *args[idx + 1 : -2 * num_gemms],
+            *quantized_weights,
+            *biases,
+        )
+        return self.output_quantizer(output)
+
+    # Override the quantized linear function
+    _quantized_linear_fn = te_grouped_quantized_linear_fn
```
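The argument unpacking in `te_grouped_quantized_linear_fn` is the subtle part: the patched functional receives a flat argument tuple whose tail is `num_gemms` weights followed by `num_gemms` biases, with `m_splits` (one entry per GEMM) right after the input. A hedged sketch with placeholder values:

```python
# Hypothetical flat argument tuple for num_gemms = 2, shaped like
# (inp, m_splits, <other args...>, w0, w1, b0, b1).
args = ("inp", [128, 64], "other_arg", "w0", "w1", "b0", "b1")

idx = 0  # would be 1 for the "_forward" variant, whose args[0] is a ctx object
inp = args[idx]
num_gemms = len(args[idx + 1])               # m_splits has one entry per GEMM
weights_and_biases = args[-2 * num_gemms:]   # last 2 * num_gemms entries
weights = weights_and_biases[:num_gemms]
biases = weights_and_biases[num_gemms:]

assert inp == "inp"
assert weights == ("w0", "w1") and biases == ("b0", "b1")
assert args[idx + 1 : -2 * num_gemms] == ([128, 64], "other_arg")
```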
```diff
+@QuantModuleRegistry.register(
+    {megatron_te.TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"}
+)
+class _MegatronTEGroupedColumnParallelLinear(
+    _QuantMegatronTEGroupedLinear, _MegatronColumnParallelLinear
+):
+    _is_column_parallel = True
+
+
+@QuantModuleRegistry.register(
+    {megatron_te.TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"}
+)
+class _MegatronTEGroupedRowParallelLinear(
+    _QuantMegatronTEGroupedLinear, _MegatronRowParallelLinear
+):
+    _is_row_parallel = True
```
```diff
+# Register the public megatron_moe.TEGroupedMLP class
+@QuantModuleRegistry.register({megatron_moe.TEGroupedMLP: "megatron_moe_TEGroupedMLP"})
+class _MegatronTEGroupedMLP(_MegatronMLP):
+    def _setup(self):
+        if not hasattr(self, "parallel_state") or self.parallel_state is None:
+            self.parallel_state = ParallelState(
+                mcore_parallel.get_expert_data_parallel_group(),
+                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
+                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
+            )
+        # Initialize parallel state for the submodules linear_fc1 and linear_fc2
+        self.linear_fc1.parallel_state = self.parallel_state
+        self.linear_fc2.parallel_state = self.parallel_state
```
```diff
+# Register the public megatron_moe.SequentialMLP class
+@QuantModuleRegistry.register({megatron_moe.SequentialMLP: "megatron_moe_SequentialMLP"})
+class _MegatronSequentialMLP(_MegatronMLP):
+    def _setup(self):
+        if not hasattr(self, "parallel_state") or self.parallel_state is None:
+            self.parallel_state = ParallelState(
+                mcore_parallel.get_expert_data_parallel_group(),
+                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
+                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
+            )
+
+        # Initialize parallel state for the submodules local_experts.*.linear_fc1
+        # and local_experts.*.linear_fc2
+        for expert in self.local_experts:
+            expert.linear_fc1.parallel_state = self.parallel_state
+            expert.linear_fc2.parallel_state = self.parallel_state
```
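With these registrations in place, a TEGroupedMLP-based MoE model should be quantizable through the usual modelopt entry point. An end-to-end sketch, hedged: the config name and calibration loop below are illustrative and not taken from this PR.

```python
import modelopt.torch.quantization as mtq

# `model` is assumed to be a Megatron GPT model whose MoE layers use
# TEGroupedMLP; `calib_dataloader` is a hypothetical calibration loader.
def forward_loop(model):
    for batch in calib_dataloader:
        model(batch)

# FP8 per-tensor quantization; the grouped linears inside the experts are
# now picked up via the QuantModuleRegistry entries added in this PR.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
```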
Reviewer: We should not need `register_custom_post_calibration_plugins`. Let's not introduce new infrastructure unnecessarily.

Reviewer (follow-up): I see the point of the post-calibration plugins now. Let's keep them as we discussed.