-
Notifications
You must be signed in to change notification settings - Fork 169
Sync amax & AWQ-Lite act_scale in context parallel/data parallel [OMNIML-2813] #359
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 17 commits
f17131f
42519cc
264adbb
7cbe5b9
1f7d17e
71a9f7a
d02365c
5a572da
fc0bb88
95da832
34c11ef
10e3e2b
9f0691f
fa8f4c8
d1fac44
22b8b73
ca7c0e8
3f857a3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -13,6 +13,7 @@ | |||||||||||||||||||||||||
# See the License for the specific language governing permissions and | ||||||||||||||||||||||||||
# limitations under the License. | ||||||||||||||||||||||||||
import copy | ||||||||||||||||||||||||||
from unittest.mock import patch | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
import pytest | ||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||
|
@@ -22,7 +23,9 @@ | |||||||||||||||||||||||||
|
||||||||||||||||||||||||||
import modelopt.torch.opt as mto | ||||||||||||||||||||||||||
import modelopt.torch.quantization as mtq | ||||||||||||||||||||||||||
import modelopt.torch.quantization.model_calib as model_calib_module # needed for patching awq_lite | ||||||||||||||||||||||||||
from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm | ||||||||||||||||||||||||||
from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer | ||||||||||||||||||||||||||
from modelopt.torch.quantization.utils import is_quantized_linear | ||||||||||||||||||||||||||
from modelopt.torch.utils import torch_to | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
@@ -116,40 +119,160 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N | |||||||||||||||||||||||||
mto.restore_from_modelopt_state(model_ref, state_dict) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def tensor_parallel_test_helper(model, config, tp_group, dp_group): | ||||||||||||||||||||||||||
# The input to fist layer, the column parallel should be the same across all tp ranks | ||||||||||||||||||||||||||
def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None): | ||||||||||||||||||||||||||
quantizer_attr = getattr(quantizer, attr).clone() | ||||||||||||||||||||||||||
print("quantizer.attr before reduce", getattr(quantizer, attr)) | ||||||||||||||||||||||||||
dist.all_reduce(quantizer_attr, op=op, group=group) | ||||||||||||||||||||||||||
print("quantizer.attr after reduce", getattr(quantizer, attr)) | ||||||||||||||||||||||||||
print("quantizer_attr after reduce", quantizer_attr) | ||||||||||||||||||||||||||
assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
original_awq_lite = model_calib_module.awq_lite | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True): | ||||||||||||||||||||||||||
"""Function to mock awq_lite function to always use debug=True for testing""" | ||||||||||||||||||||||||||
return original_awq_lite(model, forward_loop, alpha_step, debug=True) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
Comment on lines
+134
to
+137
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Forward the AWQ-Lite kwargs in the patch The -def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True):
- """Function to mock awq_lite function to always use debug=True for testing"""
- return original_awq_lite(model, forward_loop, alpha_step, debug=True)
+def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs):
+ """Force awq_lite debug mode during tests without dropping optional args."""
+ return original_awq_lite(
+ model,
+ forward_loop,
+ alpha_step=alpha_step,
+ debug=True,
+ **kwargs,
+ ) 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) | ||||||||||||||||||||||||||
def tensor_parallel_test_helper(model, config, tp_group, mock_awq_lite): | ||||||||||||||||||||||||||
# The input to first layer, the column parallel should be the same across all tp ranks | ||||||||||||||||||||||||||
calib_data = model.get_dummy_input().cuda() | ||||||||||||||||||||||||||
dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def forward_loop(model): | ||||||||||||||||||||||||||
model(calib_data) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
model = mtq.quantize(model, config, forward_loop) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Sanity check | ||||||||||||||||||||||||||
forward_loop(model) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]: | ||||||||||||||||||||||||||
# Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks | ||||||||||||||||||||||||||
activation_amax = model.fc2.input_quantizer.amax.clone() | ||||||||||||||||||||||||||
dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group) | ||||||||||||||||||||||||||
assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
_reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group) | ||||||||||||||||||||||||||
# Lets check the row parallel weight amax; it should be the same across all tp ranks | ||||||||||||||||||||||||||
weight_amax = model.fc2.weight_quantizer.amax.clone() | ||||||||||||||||||||||||||
dist.all_reduce(weight_amax, op=dist.ReduceOp.MAX, group=tp_group) | ||||||||||||||||||||||||||
assert torch.allclose(weight_amax, model.fc2.weight_quantizer.amax) | ||||||||||||||||||||||||||
_reduce_quantizer_attr( | ||||||||||||||||||||||||||
model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: | ||||||||||||||||||||||||||
# Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks | ||||||||||||||||||||||||||
input_quantizer = model.fc1.input_quantizer | ||||||||||||||||||||||||||
pre_quant_scale = input_quantizer.pre_quant_scale.clone() | ||||||||||||||||||||||||||
dist.all_reduce(pre_quant_scale, op=dist.ReduceOp.MAX, group=tp_group) | ||||||||||||||||||||||||||
assert torch.allclose(pre_quant_scale, input_quantizer.pre_quant_scale) | ||||||||||||||||||||||||||
_reduce_quantizer_attr( | ||||||||||||||||||||||||||
input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, group=tp_group | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: | ||||||||||||||||||||||||||
# Check activation scale for AWQ lite | ||||||||||||||||||||||||||
_reduce_quantizer_attr( | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @realAsma For TP, I only test fc1 (column parallel) act scale during awq lite, because fc2 row parallel will fail. For DP/CP I can test both column + row parallel act scale. I'm assuming row parallel fails because it's split across the |
||||||||||||||||||||||||||
model.fc1.awq_lite, | ||||||||||||||||||||||||||
"act_scale", | ||||||||||||||||||||||||||
dist.ReduceOp.AVG, | ||||||||||||||||||||||||||
group=tp_group, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
dist.destroy_process_group() | ||||||||||||||||||||||||||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) | ||||||||||||||||||||||||||
def dp_cp_parallel_test_helper(model, config, group, mock_awq_lite): | ||||||||||||||||||||||||||
calib_data = model.get_dummy_input().cuda() | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def forward_loop(model): | ||||||||||||||||||||||||||
model(calib_data) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
model = mtq.quantize(model, config, forward_loop) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Sanity check | ||||||||||||||||||||||||||
forward_loop(model) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Input quantizer amax | ||||||||||||||||||||||||||
if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) | ||||||||||||||||||||||||||
_reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Weight quantizer amax | ||||||||||||||||||||||||||
if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): | ||||||||||||||||||||||||||
for quantizer in model.fc1.weight_quantizer: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group) | ||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group) | ||||||||||||||||||||||||||
if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): | ||||||||||||||||||||||||||
for quantizer in model.fc2.weight_quantizer: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group) | ||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: | ||||||||||||||||||||||||||
# Check act scale | ||||||||||||||||||||||||||
_reduce_quantizer_attr( | ||||||||||||||||||||||||||
model.fc1.awq_lite, | ||||||||||||||||||||||||||
"act_scale", | ||||||||||||||||||||||||||
dist.ReduceOp.AVG, | ||||||||||||||||||||||||||
group=group, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
_reduce_quantizer_attr( | ||||||||||||||||||||||||||
model.fc2.awq_lite, | ||||||||||||||||||||||||||
"act_scale", | ||||||||||||||||||||||||||
dist.ReduceOp.AVG, | ||||||||||||||||||||||||||
group=group, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) | ||||||||||||||||||||||||||
def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite): | ||||||||||||||||||||||||||
# Calib data should be same across each DP rank | ||||||||||||||||||||||||||
dp_rank = dist.get_rank(group=dp_group) | ||||||||||||||||||||||||||
calib_data = model.get_dummy_input(seed=dp_rank).cuda() | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def forward_loop(model): | ||||||||||||||||||||||||||
model(calib_data) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
model = mtq.quantize(model, config, forward_loop) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): | ||||||||||||||||||||||||||
quantizer_attr = getattr(quantizer, attr).clone() | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Perform all-reduce operations | ||||||||||||||||||||||||||
dist.all_reduce(quantizer_attr, op=op, group=tp_group) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
dist.all_reduce(quantizer_attr, op=op, group=dp_group) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
assert torch.allclose(quantizer_attr, getattr(quantizer, attr)), getattr(quantizer, attr) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Input quantizer amax | ||||||||||||||||||||||||||
if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX) | ||||||||||||||||||||||||||
_reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks | ||||||||||||||||||||||||||
# Channel-wise (INT8) only expects same amax across row parallel ranks | ||||||||||||||||||||||||||
# Block-wise quantization does not expect same amax across row and column parallel ranks | ||||||||||||||||||||||||||
if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]: | ||||||||||||||||||||||||||
if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): | ||||||||||||||||||||||||||
for quantizer in model.fc1.weight_quantizer: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) | ||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG]: | ||||||||||||||||||||||||||
if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): | ||||||||||||||||||||||||||
for quantizer in model.fc2.weight_quantizer: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) | ||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||
_reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Check act scale | ||||||||||||||||||||||||||
if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: | ||||||||||||||||||||||||||
_reduce_quantizer_attr( | ||||||||||||||||||||||||||
model.fc1.awq_lite, | ||||||||||||||||||||||||||
"act_scale", | ||||||||||||||||||||||||||
dist.ReduceOp.AVG, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
jenchen13 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def auto_quantize_helper(model): | ||||||||||||||||||||||||||
model, search_state = mtq.auto_quantize( | ||||||||||||||||||||||||||
model, | ||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
Verify: NaN check placement may cause collective operation mismatch.
The current implementation places the sync call in the
else
block (lines 624-627), which means:is_enabled = False
, skips else block, does NOT callsync_act_scale_across_dp
sync_act_scale_across_dp
Since
dist.all_reduce
(insidesync_act_scale_across_dp
at line 606) is a collective operation requiring all ranks to participate, this creates a deadlock when ranks disagree on NaN presence.Past review comments suggest this should be resolved by either:
Consider adding a collective NaN check before the conditional:
Based on past review comments.
🏁 Script executed:
Length of output: 1217
🏁 Script executed:
Length of output: 3165
Collectively sync NaN detection before calling
sync_act_scale_across_dp
Perform a global NaN check across all ranks, then either disable calibration everywhere or call the collective sync together. For example, replace the per-rank branch (model_calib.py ~621–627) with:
This ensures every rank participates in the collective operation and prevents deadlock.
🤖 Prompt for AI Agents