Added support for quantizing TEGroupedMLP for megatron-lm #403
```diff
@@ -80,21 +80,22 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
     if not distributed_sync:
         return
 
-    def sync_quantizer_amax_across_dp(quantizer, parallel_state):
-        """Synchronize the amax across all ranks in the data parallel group."""
+    def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
+        """Synchronize the amax across all ranks in the data parallel and expert parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
-                sync_quantizer_amax_across_dp(_q, parallel_state)
+                sync_quantizer_amax_across_dp_ep(_q, parallel_state)
             return
         if getattr(quantizer, "_amax", None) is not None:
             quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
+            quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
         # TODO: create sync_bias_across_distributed_group
 
     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             for child in module.children():
                 if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
-                    sync_quantizer_amax_across_dp(child, module.parallel_state)
+                    sync_quantizer_amax_across_dp_ep(child, module.parallel_state)
     # TP sync:
     # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same
@@ -117,6 +118,7 @@ def sync_quantizer_amax_across_tp(
         # Syncing amax across TP for sequential quantizer
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
+                # Syncing amax across TP for sequential quantizer
                 sync_quantizer_amax_across_tp(
                     _q, linear_name, quantizer_type, axes_for_sync, parallel_state
                 )
@@ -174,6 +176,10 @@ def sync_quantizer_amax_across_tp(
                 parallel_state=module.parallel_state,
             )
 
+    for name, module in model.named_modules():
+        if hasattr(module, "sync_moe_local_experts_amax"):
+            module.sync_moe_local_experts_amax()
+
 
 def enable_stats_collection(model: nn.Module):
     """Enable stats collection for all quantizers in the model."""
```
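For context on what the new `sync_quantizer_amax_across_dp_ep` helper is doing, here is a minimal standalone sketch of amax synchronization over multiple process groups. The function name and the `groups` argument are illustrative only; in modelopt the quantizer's own `sync_amax_across_distributed_group` performs the reduction.

```python
import torch
import torch.distributed as dist


def sync_amax_across_groups(amax: torch.Tensor, groups) -> torch.Tensor:
    """All-reduce an amax tensor with MAX over each given process group.

    After calibration, every rank holding a replica of the same logical weight
    (data-parallel copies and expert-parallel peers) should agree on one amax,
    so the quantization scales match regardless of the parallel layout.
    """
    for group in groups:
        if group is not None:
            dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax
```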
Comment on lines +180 to +182

Contributor: Guard MOE expert sync behind an initialized process group.

```diff
-    for name, module in model.named_modules():
-        if hasattr(module, "sync_moe_local_experts_amax"):
-            module.sync_moe_local_experts_amax()
+    if dist.is_available() and dist.is_initialized():
+        for name, module in model.named_modules():
+            if hasattr(module, "sync_moe_local_experts_amax"):
+                module.sync_moe_local_experts_amax()
```
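A standalone version of the guarded pattern the reviewer is suggesting, for reference; `maybe_sync_moe_amax` is a hypothetical helper name, and the guard simply turns the loop into a no-op on single-process runs where `torch.distributed` was never initialized.

```python
import torch.distributed as dist
from torch import nn


def maybe_sync_moe_amax(model: nn.Module) -> None:
    """Invoke per-module MOE amax sync hooks only under an initialized process group."""
    if not (dist.is_available() and dist.is_initialized()):
        # Single-process calibration: nothing to synchronize across ranks.
        return
    for _, module in model.named_modules():
        if hasattr(module, "sync_moe_local_experts_amax"):
            module.sync_moe_local_experts_amax()
```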
```diff
@@ -17,6 +17,7 @@
 import torch
 import transformer_engine as te
+import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear
 import transformer_engine.pytorch.module.linear as te_linear
 
 from ..nn import QuantModuleRegistry
@@ -58,3 +59,60 @@ def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):
 
     # Override the quantized linear function
     _quantized_linear_fn = te_quantized_linear_fn
+
+
+# Register the public te.pytorch.GroupedLinear class
+@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"})
+class _QuantTEGroupedLinear(_ParallelLinear):
+    _functionals_to_replace = [
+        (te_grouped_linear._GroupedLinear, "forward"),
+        (te_grouped_linear._GroupedLinear, "apply"),
+    ]
+
+    def _setup(self):
+        # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
+        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
+        # self.weight0 to self.weight to run the quantizer states initialization.
+        assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
+        self.weight = self.weight0
+        # Memorize the original weight.dtype for modelopt_post_restore given that
+        # the dtype can change later.
+        super()._setup()
+        # Remove self.weight after setup.
+        delattr(self, "weight")
+
+    def modelopt_post_restore(self, prefix: str = ""):
+        # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
+        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
+        # self.weight0 to self.weight to run the quantizer states initialization.
+        assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
+        self.weight = self.weight0
+        super().modelopt_post_restore(prefix=prefix)
+        # Remove self.weight after post_restore.
+        delattr(self, "weight")
+
+    @staticmethod
+    def te_grouped_quantized_linear_fn(package, func_name, self, *args):
+        idx = 1 if func_name == "_forward" else 0
+        inp = args[idx]
+        num_gemms = len(args[idx + 1])
+        weights_and_biases = args[-2 * num_gemms :]
+        weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
+        quantized_inputs = self.input_quantizer(inp)
+        quantized_weights = [self.weight_quantizer(weight) for weight in weights]
+
+        output = getattr(package, func_name)(
+            *(
+                args[0],
+                quantized_inputs,
+            )
+            if func_name == "_forward"
+            else (quantized_inputs,),
+            *args[idx + 1 : -2 * num_gemms],
+            *quantized_weights,
+            *biases,
+        )
+        return self.output_quantizer(output)
+
+    # Override the quantized linear function
+    _quantized_linear_fn = te_grouped_quantized_linear_fn
```
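The argument slicing in `te_grouped_quantized_linear_fn` is the subtle part: the patched functional receives the per-GEMM weights and biases flattened at the end of `*args`, and `num_gemms` is recovered from the splits list that follows the input. A toy, self-contained illustration of that slicing follows; the tuple layout below is a stand-in, not a restatement of TE's real `_GroupedLinear` signature.

```python
import torch

# Pretend positional args as they might arrive at the patched functional:
# (input, m_splits, <other positional state...>, weight0..weightN-1, bias0..biasN-1)
num_gemms = 3
inp = torch.randn(12, 8)
m_splits = [4, 4, 4]                      # one chunk of rows per GEMM
weights = [torch.randn(16, 8) for _ in range(num_gemms)]
biases = [torch.randn(16) for _ in range(num_gemms)]
args = (inp, m_splits, "opaque", "state", *weights, *biases)

idx = 0                                   # would be 1 for "_forward", where args[0] is the autograd ctx
n = len(args[idx + 1])                    # num_gemms, recovered from the splits list
tail = args[-2 * n:]                      # trailing 2*n entries: weights first, then biases
w, b = tail[:n], tail[n:]
assert list(w) == weights and list(b) == biases
```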
Comment on lines +72 to +115

Contributor: Expose a stable `weight` property instead of assigning and deleting `self.weight` around `_setup` and `modelopt_post_restore`.

```diff
     def _setup(self):
-        # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
-        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
-        # self.weight0 to self.weight to run the quantizer states initialization.
-        assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
-        self.weight = self.weight0
+        # GroupedMLP stores the weights as weight0, weight1, etc. Use weight0 to drive quantizer setup.
+        assert "weight" not in self._parameters, "self.weight should not exist for TEGroupedLinear"
+        self.weight = self.weight0
         # Memorize the original weight.dtype for modelopt_post_restore given that
         # the dtype can change later.
         super()._setup()
-        # Remove self.weight after setup.
-        delattr(self, "weight")
+        # Setter below is a no-op so we do not register a duplicate Parameter named "weight".
@@
     def modelopt_post_restore(self, prefix: str = ""):
-        # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
-        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
-        # self.weight0 to self.weight to run the quantizer states initialization.
-        assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear"
-        self.weight = self.weight0
+        # GroupedMLP stores the weights as weight0, weight1, etc. Reuse weight0 to drive post_restore.
+        assert "weight" not in self._parameters, "self.weight should not exist for TEGroupedLinear"
+        self.weight = self.weight0
         super().modelopt_post_restore(prefix=prefix)
-        # Remove self.weight after post_restore.
-        delattr(self, "weight")
+        # Setter below keeps weight0 as the canonical tensor.
+
+    @property
+    def weight(self):
+        return self.weight0
+
+    @weight.setter
+    def weight(self, value):
+        if value is not self.weight0:
+            raise ValueError("TEGroupedLinear expects weight0 to back the canonical weight parameter.")
```
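One detail worth calling out in the suggestion: once a class-level `weight` property exists, `hasattr(self, "weight")` is always true, which is why the assert switches to checking `self._parameters`. A tiny toy module (not the modelopt class) shows the difference:

```python
import torch
from torch import nn


class ToyGrouped(nn.Module):
    """Parameters live under weight0/weight1; `weight` is only a read-only alias."""

    def __init__(self):
        super().__init__()
        self.weight0 = nn.Parameter(torch.randn(4, 4))
        self.weight1 = nn.Parameter(torch.randn(4, 4))

    @property
    def weight(self):
        # Resolved via the class, so hasattr(self, "weight") is always True.
        return self.weight0


m = ToyGrouped()
print(hasattr(m, "weight"))        # True -> an `assert not hasattr(...)` check would always fail
print("weight" in m._parameters)   # False -> no duplicate Parameter named "weight" is registered
```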
Comment: this change looks good!