Commit 50000dd

fix bug
Signed-off-by: Jennifer Chen <[email protected]>
1 parent a106dd9 commit 50000dd

3 files changed, 11 insertions(+), 4 deletions(-)

modelopt/torch/quantization/model_calib.py

Lines changed: 1 addition & 1 deletion
@@ -619,7 +619,7 @@ def sync_act_scale_across_dp(module, data_parallel_group):
     has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
         torch.isnan(module.awq_lite.weight_scale)
     )
-    has_nan = torch.tensor(int(has_nan_local), device=module.weight.device)
+    has_nan = torch.tensor(int(has_nan_local), device=module.awq_lite.act_scale.device)
     if module.parallel_state.data_parallel_group.is_initialized():
         dist.all_reduce(
             has_nan,
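
For context, a hedged sketch of the synchronization pattern this hunk touches: each DP rank folds its local NaN check into a tensor and all-reduces it so every rank agrees on whether any rank produced a NaN scale. The helper name, the MAX reduce op, and the offloaded-weight scenario are assumptions for illustration, not taken from the commit.

import torch
import torch.distributed as dist

def any_rank_has_nan(act_scale, weight_scale, group=None):
    """Return True on every rank if any rank observed a NaN scale."""
    has_nan_local = torch.any(torch.isnan(act_scale)) or torch.any(
        torch.isnan(weight_scale)
    )
    # The fix above: allocate the flag on the scales' device rather than the
    # weight's. If the weight lives elsewhere (e.g. offloaded or on a meta
    # device), the collective would otherwise target the wrong device.
    has_nan = torch.tensor(int(has_nan_local), device=act_scale.device)
    if dist.is_initialized():
        dist.all_reduce(has_nan, op=dist.ReduceOp.MAX, group=group)
    return bool(has_nan.item())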

tests/_test_utils/torch_quantization/quantize_common.py

Lines changed: 7 additions & 3 deletions
@@ -120,9 +120,9 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
 
 
 def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[]):
+    quantizer_attr = getattr(quantizer, attr).clone()
     for group in groups:
         if group is not None:
-            quantizer_attr = getattr(quantizer, attr).clone()
             dist.all_reduce(quantizer_attr, op=op, group=group)
     assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
 
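
The hoisted clone fixes two problems that can be read off the hunk: when every entry in groups is None (e.g. a single-process run), quantizer_attr was never bound and the final assert raised UnboundLocalError; and when several groups are passed, re-cloning inside the loop discarded the previous group's reduction, so only the last group was effectively checked. A toy reproduction of the control-flow half (illustrative function names, no distributed setup needed):

def buggy(groups):
    for group in groups:
        if group is not None:
            attr = "clone"  # bound only when a non-None group exists
    return attr  # UnboundLocalError for groups=[None]

def fixed(groups):
    attr = "clone"  # hoisted before the loop, as in the commit
    for group in groups:
        if group is not None:
            pass  # successive all_reduce calls now chain on one tensor
    return attr

fixed([None])   # ok
# buggy([None])  # raises UnboundLocalError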

@@ -137,7 +137,7 @@ def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs):
 
 @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
 def data_tensor_context_parallel_test_helper(
-    model, config, mock_awq_lite, dp_group=None, tp_group=None
+    model, config, mock_awq_lite, dp_group=None, tp_group=None, test_pre_quant_scale=True
 ):
     # Calib data should be different across each DP rank
     dp_rank = dist.get_rank(group=dp_group)
@@ -193,7 +193,11 @@ def forward_loop(model):
 
     # Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks
     # It is different across DP/CP ranks since the input is different
-    if tp_group and config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
+    if (
+        test_pre_quant_scale
+        and tp_group
+        and config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]
+    ):
         input_quantizer = model.fc1.input_quantizer
         _distributed_attr_check(
             input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
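
The new test_pre_quant_scale flag (default True, preserving existing behavior) lets callers opt out of the cross-rank pre_quant_scale equality check. A hedged usage sketch mirroring the call sites in this commit; the surrounding model and process-group setup is assumed:

# Default: pre_quant_scale is checked for equality across ranks for the
# SmoothQuant/AWQ configs listed in the condition above.
data_tensor_context_parallel_test_helper(
    model, mtq.INT4_AWQ_CFG, dp_group=dp_group, tp_group=tp_group
)

# Opt out when ranks intentionally calibrate on different data, so the
# scales legitimately differ and the equality check would misfire.
data_tensor_context_parallel_test_helper(
    model, mtq.INT4_AWQ_CFG, dp_group=dp_group, tp_group=tp_group,
    test_pre_quant_scale=False,
)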

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 3 additions & 0 deletions
@@ -98,6 +98,7 @@ def _test_parallelism_helper(
     tensor_model_parallel_size=1,
     context_parallel_size=1,
     use_rank_in_seed=False,
+    test_pre_quant_scale=True,
 ):
     """
     Unified helper for testing different parallelism configurations.
@@ -133,6 +134,7 @@ def _test_parallelism_helper(
         config,
         dp_group=dp_group,
         tp_group=tp_group,
+        test_pre_quant_scale=test_pre_quant_scale,
     )
 
 
@@ -219,6 +221,7 @@ def test_data_tensor_context_parallel(need_8_gpus, config):
         tensor_model_parallel_size=2,
         context_parallel_size=2,
         use_rank_in_seed=True,
+        test_pre_quant_scale=False,
     ),
     backend="nccl",
 )
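
This caller disables the check because use_rank_in_seed=True gives every rank different calibration data (cf. the "Calib data should be different across each DP rank" comment in the helper above), so pre_quant_scale is expected to diverge across DP/CP ranks and an equality assertion would fail spuriously. A minimal sketch of what rank-dependent seeding implies; the function name and shapes are illustrative, not the test's actual data pipeline:

import torch
import torch.distributed as dist

def make_calib_batch(base_seed=1234, shape=(4, 32)):
    # Each rank seeds differently, so each rank draws different calibration
    # data; activation statistics (and hence pre_quant_scale) then diverge
    # across DP/CP ranks by design.
    torch.manual_seed(base_seed + dist.get_rank())
    return torch.randn(shape)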
