-
Notifications
You must be signed in to change notification settings - Fork 200
Sync amax & AWQ-Lite act_scale in context parallel/data parallel [OMNIML-2813] #359
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 22 commits
f17131f
42519cc
264adbb
7cbe5b9
1f7d17e
71a9f7a
d02365c
5a572da
fc0bb88
95da832
34c11ef
10e3e2b
9f0691f
fa8f4c8
d1fac44
22b8b73
ca7c0e8
3f857a3
93bfd52
6761109
291cfa3
a106dd9
50000dd
2664563
440ca48
2e8ef58
5cb380c
afe6f34
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -81,6 +81,7 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis | |||||||||||||||||||||
| return | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def sync_quantizer_amax_across_dp(quantizer, parallel_state): | ||||||||||||||||||||||
| """Synchronize the amax across all ranks in the data parallel group.""" | ||||||||||||||||||||||
| if isinstance(quantizer, SequentialQuantizer): | ||||||||||||||||||||||
| for _q in quantizer: | ||||||||||||||||||||||
| sync_quantizer_amax_across_dp(_q, parallel_state) | ||||||||||||||||||||||
|
|
@@ -94,7 +95,6 @@ def sync_quantizer_amax_across_dp(quantizer, parallel_state): | |||||||||||||||||||||
| for child in module.children(): | ||||||||||||||||||||||
| if isinstance(child, (TensorQuantizer, SequentialQuantizer)): | ||||||||||||||||||||||
| sync_quantizer_amax_across_dp(child, module.parallel_state) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # TP sync: | ||||||||||||||||||||||
| # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -114,6 +114,7 @@ def sync_quantizer_amax_across_tp( | |||||||||||||||||||||
| axes_for_sync: list, | ||||||||||||||||||||||
| parallel_state: ParallelState, | ||||||||||||||||||||||
| ): | ||||||||||||||||||||||
| # Syncing amax across TP for sequential quantizer | ||||||||||||||||||||||
| if isinstance(quantizer, SequentialQuantizer): | ||||||||||||||||||||||
| for _q in quantizer: | ||||||||||||||||||||||
| sync_quantizer_amax_across_tp( | ||||||||||||||||||||||
|
|
@@ -598,19 +599,41 @@ def forward(self, input, *args, **kwargs): | |||||||||||||||||||||
| # This will also perform distributed amax sync for input_quantizers | ||||||||||||||||||||||
| max_calibrate(model, lambda model: None) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def sync_act_scale_across_dp(module, data_parallel_group): | ||||||||||||||||||||||
| """Sync activation scale across Data Parallel (DP).""" | ||||||||||||||||||||||
| if data_parallel_group.is_initialized(): | ||||||||||||||||||||||
| dist.all_reduce( | ||||||||||||||||||||||
| module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| for name, module in model.named_modules(): | ||||||||||||||||||||||
| if ( | ||||||||||||||||||||||
| is_quantized_linear(module) | ||||||||||||||||||||||
| and hasattr(module, "awq_lite") | ||||||||||||||||||||||
| and module.awq_lite.num_cache_steps > 0 | ||||||||||||||||||||||
| ): | ||||||||||||||||||||||
| # Hack: MoEs forward all tokens through all experts if _if_calib is True | ||||||||||||||||||||||
| module._if_calib = True | ||||||||||||||||||||||
| module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps | ||||||||||||||||||||||
| if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any( | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any( | ||||||||||||||||||||||
| torch.isnan(module.awq_lite.weight_scale) | ||||||||||||||||||||||
| ): | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
Comment on lines
+619
to
+621
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix tensor boolean evaluation before distributed sync.
- has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
- torch.isnan(module.awq_lite.weight_scale)
- )
+ act_nan = torch.isnan(module.awq_lite.act_scale).any().item()
+ weight_nan = torch.isnan(module.awq_lite.weight_scale).any().item()
+ has_nan_local = bool(act_nan or weight_nan)📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||
| has_nan = torch.tensor(int(has_nan_local), device=module.weight.device) | ||||||||||||||||||||||
| if module.parallel_state.data_parallel_group.is_initialized(): | ||||||||||||||||||||||
| dist.all_reduce( | ||||||||||||||||||||||
| has_nan, | ||||||||||||||||||||||
| op=dist.ReduceOp.MAX, | ||||||||||||||||||||||
| group=module.parallel_state.data_parallel_group.group, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if has_nan.item() > 0: | ||||||||||||||||||||||
| module.awq_lite.is_enabled = False | ||||||||||||||||||||||
| # Hack: MoEs forward all tokens through all experts if _if_calib is True | ||||||||||||||||||||||
| module._if_calib = True | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| sync_act_scale_across_dp( | ||||||||||||||||||||||
| module, | ||||||||||||||||||||||
| module.parallel_state.data_parallel_group, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| AWQLiteHelper.cache_mode = False | ||||||||||||||||||||||
| print_rank_0("awq_lite: Searching parameters...") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like we dont need separate methods |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |||||||||||||||||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||||||||||||||||
| import copy | ||||||||||||||||||||||||||||||||||||
| from unittest.mock import patch | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| import pytest | ||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||
|
|
@@ -22,7 +23,9 @@ | |||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| import modelopt.torch.opt as mto | ||||||||||||||||||||||||||||||||||||
| import modelopt.torch.quantization as mtq | ||||||||||||||||||||||||||||||||||||
| import modelopt.torch.quantization.model_calib as model_calib_module # needed for patching awq_lite | ||||||||||||||||||||||||||||||||||||
| from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm | ||||||||||||||||||||||||||||||||||||
| from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer | ||||||||||||||||||||||||||||||||||||
| from modelopt.torch.quantization.utils import is_quantized_linear | ||||||||||||||||||||||||||||||||||||
| from modelopt.torch.utils import torch_to | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
@@ -116,38 +119,91 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N | |||||||||||||||||||||||||||||||||||
| mto.restore_from_modelopt_state(model_ref, state_dict) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def tensor_parallel_test_helper(model, config, tp_group, dp_group): | ||||||||||||||||||||||||||||||||||||
| # The input to fist layer, the column parallel should be the same across all tp ranks | ||||||||||||||||||||||||||||||||||||
| calib_data = model.get_dummy_input().cuda() | ||||||||||||||||||||||||||||||||||||
| dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) | ||||||||||||||||||||||||||||||||||||
| def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[]): | ||||||||||||||||||||||||||||||||||||
| for group in groups: | ||||||||||||||||||||||||||||||||||||
| if group is not None: | ||||||||||||||||||||||||||||||||||||
| quantizer_attr = getattr(quantizer, attr).clone() | ||||||||||||||||||||||||||||||||||||
| dist.all_reduce(quantizer_attr, op=op, group=group) | ||||||||||||||||||||||||||||||||||||
| assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
122
to
127
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix distributed attribute check to cover all groups
-def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[]):
- for group in groups:
- if group is not None:
- quantizer_attr = getattr(quantizer, attr).clone()
- dist.all_reduce(quantizer_attr, op=op, group=group)
- assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
+def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=()):
+ attr_value = getattr(quantizer, attr)
+ checked = False
+ for group in groups:
+ if group is None:
+ continue
+ reduced = attr_value.clone()
+ dist.all_reduce(reduced, op=op, group=group)
+ assert torch.allclose(reduced, attr_value)
+ checked = True
+ assert checked, "expected at least one distributed group to validate"📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def forward_loop(model): | ||||||||||||||||||||||||||||||||||||
| model(calib_data) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| model = mtq.quantize(model, config, forward_loop) | ||||||||||||||||||||||||||||||||||||
| original_awq_lite = model_calib_module.awq_lite | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Sanity check | ||||||||||||||||||||||||||||||||||||
| forward_loop(model) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]: | ||||||||||||||||||||||||||||||||||||
| # Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks | ||||||||||||||||||||||||||||||||||||
| activation_amax = model.fc2.input_quantizer.amax.clone() | ||||||||||||||||||||||||||||||||||||
| dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group) | ||||||||||||||||||||||||||||||||||||
| assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax) | ||||||||||||||||||||||||||||||||||||
| def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs): | ||||||||||||||||||||||||||||||||||||
| """Function to mock awq_lite function to always use debug=True for testing""" | ||||||||||||||||||||||||||||||||||||
| return original_awq_lite(model, forward_loop, alpha_step, debug=True, **kwargs) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Lets check the row parallel weight amax; it should be the same across all tp ranks | ||||||||||||||||||||||||||||||||||||
| weight_amax = model.fc2.weight_quantizer.amax.clone() | ||||||||||||||||||||||||||||||||||||
| dist.all_reduce(weight_amax, op=dist.ReduceOp.MAX, group=tp_group) | ||||||||||||||||||||||||||||||||||||
| assert torch.allclose(weight_amax, model.fc2.weight_quantizer.amax) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: | ||||||||||||||||||||||||||||||||||||
| # Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks | ||||||||||||||||||||||||||||||||||||
| input_quantizer = model.fc1.input_quantizer | ||||||||||||||||||||||||||||||||||||
| pre_quant_scale = input_quantizer.pre_quant_scale.clone() | ||||||||||||||||||||||||||||||||||||
| dist.all_reduce(pre_quant_scale, op=dist.ReduceOp.MAX, group=tp_group) | ||||||||||||||||||||||||||||||||||||
| assert torch.allclose(pre_quant_scale, input_quantizer.pre_quant_scale) | ||||||||||||||||||||||||||||||||||||
| @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) | ||||||||||||||||||||||||||||||||||||
| def data_tensor_context_parallel_test_helper( | ||||||||||||||||||||||||||||||||||||
| model, config, mock_awq_lite, dp_group=None, tp_group=None | ||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||
| # Calib data should be different across each DP rank | ||||||||||||||||||||||||||||||||||||
| dp_rank = dist.get_rank(group=dp_group) | ||||||||||||||||||||||||||||||||||||
| calib_data = model.get_dummy_input(seed=dp_rank).cuda() | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if tp_group is not None: | ||||||||||||||||||||||||||||||||||||
| # The input to first layer, the column parallel should be the same across all tp ranks | ||||||||||||||||||||||||||||||||||||
| dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| dist.destroy_process_group() | ||||||||||||||||||||||||||||||||||||
| def forward_loop(model): | ||||||||||||||||||||||||||||||||||||
| model(calib_data) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| model = mtq.quantize(model, config, forward_loop) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Input quantizer amax | ||||||||||||||||||||||||||||||||||||
| if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: | ||||||||||||||||||||||||||||||||||||
| _distributed_attr_check( | ||||||||||||||||||||||||||||||||||||
| model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group] | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| _distributed_attr_check( | ||||||||||||||||||||||||||||||||||||
| model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group] | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks | ||||||||||||||||||||||||||||||||||||
| # Channel-wise (INT8) only expects same amax across row parallel ranks | ||||||||||||||||||||||||||||||||||||
| # Block-wise quantization does not expect same amax across row and column parallel ranks | ||||||||||||||||||||||||||||||||||||
| if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]: | ||||||||||||||||||||||||||||||||||||
| if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): | ||||||||||||||||||||||||||||||||||||
| for quantizer in model.fc1.weight_quantizer: | ||||||||||||||||||||||||||||||||||||
| _distributed_attr_check( | ||||||||||||||||||||||||||||||||||||
| quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group] | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| _distributed_attr_check( | ||||||||||||||||||||||||||||||||||||
| model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group] | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if config in [ | ||||||||||||||||||||||||||||||||||||
| mtq.FP8_DEFAULT_CFG, | ||||||||||||||||||||||||||||||||||||
| mtq.NVFP4_DEFAULT_CFG, | ||||||||||||||||||||||||||||||||||||
| mtq.INT8_DEFAULT_CFG, | ||||||||||||||||||||||||||||||||||||
| mtq.INT8_SMOOTHQUANT_CFG, | ||||||||||||||||||||||||||||||||||||
| ]: | ||||||||||||||||||||||||||||||||||||
| if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): | ||||||||||||||||||||||||||||||||||||
| for quantizer in model.fc2.weight_quantizer: | ||||||||||||||||||||||||||||||||||||
| _distributed_attr_check( | ||||||||||||||||||||||||||||||||||||
| quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group] | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||
| _distributed_attr_check( | ||||||||||||||||||||||||||||||||||||
| model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group] | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks | ||||||||||||||||||||||||||||||||||||
| # It is different across DP/CP ranks since the input is different | ||||||||||||||||||||||||||||||||||||
| if tp_group and config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: | ||||||||||||||||||||||||||||||||||||
| input_quantizer = model.fc1.input_quantizer | ||||||||||||||||||||||||||||||||||||
| _distributed_attr_check( | ||||||||||||||||||||||||||||||||||||
| input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, groups=[dp_group, tp_group] | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Check act scale | ||||||||||||||||||||||||||||||||||||
| if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: | ||||||||||||||||||||||||||||||||||||
| _distributed_attr_check( | ||||||||||||||||||||||||||||||||||||
| model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, groups=[dp_group, tp_group] | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def auto_quantize_helper(model): | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch on this.
Can we use