Sync amax & AWQ-Lite act_scale in context parallel/data parallel [OMNIML-2813] #359
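Background for this change: with data parallelism or context parallelism, each rank calibrates on a different shard of data, so quantizer statistics can drift apart across ranks. The tests below therefore check that, after calibration, every amax matches the group-wide maximum and every AWQ-Lite act_scale matches the group-wide average. A minimal sketch of that kind of synchronization, assuming modules expose amax and awq_lite.act_scale tensors as in the test helpers (illustrative only, not the PR's actual implementation in model_calib):

```python
# Illustrative sketch only: merge calibration statistics across a
# data/context-parallel process group. The module traversal and attribute
# names (amax, awq_lite.act_scale) are assumptions mirroring the tests below.
import torch
import torch.distributed as dist


def sync_quantizer_stats(model: torch.nn.Module, group=None) -> None:
    for module in model.modules():
        # amax is a running max of observed magnitudes, so the element-wise
        # MAX across ranks is the correct merged value.
        amax = getattr(module, "amax", None)
        if isinstance(amax, torch.Tensor):
            dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)

        # AWQ-Lite act_scale is a running mean, so ranks are merged by
        # averaging (ReduceOp.AVG needs a backend that supports it, e.g. NCCL).
        awq_lite = getattr(module, "awq_lite", None)
        act_scale = getattr(awq_lite, "act_scale", None) if awq_lite is not None else None
        if isinstance(act_scale, torch.Tensor):
            dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=group)
```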
Contributor comment: Looks like we don't need separate methods.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import copy | ||
| from unittest.mock import patch | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
@@ -22,7 +23,9 @@ | |
|
|
||
| import modelopt.torch.opt as mto | ||
| import modelopt.torch.quantization as mtq | ||
| import modelopt.torch.quantization.model_calib as model_calib_module # needed for patching awq_lite | ||
| from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm | ||
| from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer | ||
| from modelopt.torch.quantization.utils import is_quantized_linear | ||
| from modelopt.torch.utils import torch_to | ||
|
|
||
|
|
@@ -116,40 +119,160 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N | |
| mto.restore_from_modelopt_state(model_ref, state_dict) | ||
|
|
||
|
|
||
| def tensor_parallel_test_helper(model, config, tp_group, dp_group): | ||
| # The input to fist layer, the column parallel should be the same across all tp ranks | ||
| def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None): | ||
|
||
| quantizer_attr = getattr(quantizer, attr).clone() | ||
| print("quantizer.attr before reduce", getattr(quantizer, attr)) | ||
| dist.all_reduce(quantizer_attr, op=op, group=group) | ||
|
||
| print("quantizer.attr after reduce", getattr(quantizer, attr)) | ||
| print("quantizer_attr after reduce", quantizer_attr) | ||
| assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) | ||
|
|
||
|
|
||
| original_awq_lite = model_calib_module.awq_lite | ||
|
|
||
|
|
||
| def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True): | ||
| """Function to mock awq_lite function to always use debug=True for testing""" | ||
| return original_awq_lite(model, forward_loop, alpha_step, debug=True) | ||
|
|
||
|
||
|
|
||
| @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) | ||
| def tensor_parallel_test_helper(model, config, tp_group, mock_awq_lite): | ||
| # The input to first layer, the column parallel should be the same across all tp ranks | ||
| calib_data = model.get_dummy_input().cuda() | ||
| dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) | ||
|
|
||
| def forward_loop(model): | ||
| model(calib_data) | ||
|
|
||
| model = mtq.quantize(model, config, forward_loop) | ||
|
|
||
| # Sanity check | ||
| forward_loop(model) | ||
|
|
||
| if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]: | ||
| # Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks | ||
| activation_amax = model.fc2.input_quantizer.amax.clone() | ||
| dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group) | ||
| assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax) | ||
|
|
||
| _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group) | ||
| # Lets check the row parallel weight amax; it should be the same across all tp ranks | ||
| weight_amax = model.fc2.weight_quantizer.amax.clone() | ||
| dist.all_reduce(weight_amax, op=dist.ReduceOp.MAX, group=tp_group) | ||
| assert torch.allclose(weight_amax, model.fc2.weight_quantizer.amax) | ||
| _reduce_quantizer_attr( | ||
| model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group | ||
| ) | ||
|
|
||
| if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: | ||
| # Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks | ||
| input_quantizer = model.fc1.input_quantizer | ||
| pre_quant_scale = input_quantizer.pre_quant_scale.clone() | ||
| dist.all_reduce(pre_quant_scale, op=dist.ReduceOp.MAX, group=tp_group) | ||
| assert torch.allclose(pre_quant_scale, input_quantizer.pre_quant_scale) | ||
| _reduce_quantizer_attr( | ||
| input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, group=tp_group | ||
| ) | ||
|
|
||
| if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: | ||
| # Check activation scale for AWQ lite | ||
| _reduce_quantizer_attr( | ||
|
||
| model.fc1.awq_lite, | ||
| "act_scale", | ||
| dist.ReduceOp.AVG, | ||
| group=tp_group, | ||
| ) | ||
|
|
||
| dist.destroy_process_group() | ||
|
||
|
|
||
|
||
|
|
||
| @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) | ||
| def dp_cp_parallel_test_helper(model, config, group, mock_awq_lite): | ||
| calib_data = model.get_dummy_input().cuda() | ||
|
|
||
| def forward_loop(model): | ||
| model(calib_data) | ||
|
|
||
| model = mtq.quantize(model, config, forward_loop) | ||
|
|
||
| # Sanity check | ||
| forward_loop(model) | ||
|
|
||
| # Input quantizer amax | ||
| if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: | ||
| _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) | ||
| _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) | ||
|
|
||
| # Weight quantizer amax | ||
| if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): | ||
| for quantizer in model.fc1.weight_quantizer: | ||
| _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group) | ||
| else: | ||
| _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group) | ||
| if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): | ||
| for quantizer in model.fc2.weight_quantizer: | ||
| _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group) | ||
| else: | ||
| _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group) | ||
|
|
||
| if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: | ||
| # Check act scale | ||
| _reduce_quantizer_attr( | ||
| model.fc1.awq_lite, | ||
| "act_scale", | ||
| dist.ReduceOp.AVG, | ||
| group=group, | ||
| ) | ||
| _reduce_quantizer_attr( | ||
| model.fc2.awq_lite, | ||
| "act_scale", | ||
| dist.ReduceOp.AVG, | ||
| group=group, | ||
| ) | ||
|
|
||
|
|
||
| @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) | ||
| def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite): | ||
| # Calib data should be same across each DP rank | ||
| dp_rank = dist.get_rank(group=dp_group) | ||
| calib_data = model.get_dummy_input(seed=dp_rank).cuda() | ||
|
|
||
| def forward_loop(model): | ||
| model(calib_data) | ||
|
|
||
| model = mtq.quantize(model, config, forward_loop) | ||
|
|
||
| def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX): | ||
| quantizer_attr = getattr(quantizer, attr).clone() | ||
|
|
||
| # Perform all-reduce operations | ||
| dist.all_reduce(quantizer_attr, op=op, group=tp_group) | ||
|
|
||
| dist.all_reduce(quantizer_attr, op=op, group=dp_group) | ||
|
|
||
| assert torch.allclose(quantizer_attr, getattr(quantizer, attr)), getattr(quantizer, attr) | ||
|
||
|
|
||
| # Input quantizer amax | ||
| if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: | ||
| _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX) | ||
|
||
| _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX) | ||
|
|
||
| # Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks | ||
| # Channel-wise (INT8) only expects same amax across row parallel ranks | ||
| # Block-wise quantization does not expect same amax across row and column parallel ranks | ||
| if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]: | ||
| if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): | ||
| for quantizer in model.fc1.weight_quantizer: | ||
| _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) | ||
| else: | ||
| _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX) | ||
|
|
||
| if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG]: | ||
| if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): | ||
| for quantizer in model.fc2.weight_quantizer: | ||
| _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) | ||
| else: | ||
| _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX) | ||
|
|
||
| # Check act scale | ||
| if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: | ||
| _reduce_quantizer_attr( | ||
| model.fc1.awq_lite, | ||
| "act_scale", | ||
| dist.ReduceOp.AVG, | ||
| ) | ||
|
|
||
|
||
|
|
||
| def auto_quantize_helper(model): | ||
| model, search_state = mtq.auto_quantize( | ||
| model, | ||
|
|
||
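For context, here is a hedged sketch of how a helper like dp_cp_parallel_test_helper above might be driven with one process per GPU. ToyModel, the launcher, and the choice of config are hypothetical stand-ins for the repository's own test models and spawn utilities, and the sketch assumes it lives in the same test module as the helpers so they are in scope:

```python
# Hypothetical launcher, not part of this PR: spawns one process per GPU and
# runs dp_cp_parallel_test_helper with the whole world as the DP group.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import modelopt.torch.quantization as mtq


class ToyModel(torch.nn.Module):
    """Minimal stand-in exposing the get_dummy_input() hook the helpers expect."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(32, 64)
        self.fc2 = torch.nn.Linear(64, 32)

    def forward(self, x):
        return self.fc2(self.fc1(x))

    def get_dummy_input(self, seed: int = 0):
        torch.manual_seed(seed)
        return torch.randn(2, 32)


def _worker(rank: int, world_size: int):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Treat the whole world as a single data-parallel group for this sketch.
    dp_group = dist.new_group(ranks=list(range(world_size)))

    model = ToyModel().cuda()
    dp_cp_parallel_test_helper(model, mtq.INT8_DEFAULT_CFG, dp_group)

    dist.destroy_process_group()


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    mp.spawn(_worker, args=(n_gpus,), nprocs=n_gpus)
```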