Commit 93bfd52

fix tests
Signed-off-by: Jennifer Chen <[email protected]>
1 parent 3f857a3 commit 93bfd52

4 files changed: +87 -97 lines changed


modelopt/torch/quantization/model_calib.py

Lines changed: 12 additions & 3 deletions
@@ -114,9 +114,9 @@ def sync_quantizer_amax_across_tp(
     axes_for_sync: list,
     parallel_state: ParallelState,
 ):
+    # Syncing amax across TP for sequential quantizer
     if isinstance(quantizer, SequentialQuantizer):
         for _q in quantizer:
-            "Syncing amax across TP for sequential quantizer"
             sync_quantizer_amax_across_tp(
                 _q, linear_name, quantizer_type, axes_for_sync, parallel_state
             )
@@ -616,9 +616,18 @@ def sync_act_scale_across_dp(module, data_parallel_group):
             module._if_calib = True
             module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps

-            if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
+            has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
                 torch.isnan(module.awq_lite.weight_scale)
-            ):
+            )
+            has_nan = torch.tensor(int(has_nan_local), device=module.weight.device)
+            if module.parallel_state.data_parallel_group.is_initialized():
+                dist.all_reduce(
+                    has_nan,
+                    op=dist.ReduceOp.MAX,
+                    group=module.parallel_state.data_parallel_group.group,
+                )
+
+            if has_nan.item() > 0:
                 module.awq_lite.is_enabled = False
             else:
                 sync_act_scale_across_dp(
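Note: the change above turns the per-rank NaN check into a collective decision, so every data-parallel rank disables AWQ-lite together if any rank's act_scale or weight_scale contains a NaN; otherwise one rank could skip the sync_act_scale_across_dp() collective that its peers enter. A minimal standalone sketch of the same pattern (hypothetical helper name, assuming torch.distributed is initialized; not the modelopt code itself):

    import torch
    import torch.distributed as dist

    def any_rank_has_nan(*tensors, group=None) -> bool:
        """Return True on every rank if any rank saw a NaN in any of the tensors."""
        local_flag = any(torch.isnan(t).any().item() for t in tensors)
        flag = torch.tensor(int(local_flag), device=tensors[0].device)
        if dist.is_available() and dist.is_initialized():
            # MAX acts as a logical OR across ranks: the flag becomes 1 everywhere
            # if any single rank contributed a 1.
            dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=group)
        return bool(flag.item())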

tests/_test_utils/torch_dist/plugins/megatron_common.py

Lines changed: 5 additions & 1 deletion
@@ -83,7 +83,9 @@
 
 
 class MegatronModel(MegatronModule):
-    def __init__(self, tp_size: int = 1, cp_size: int = 1, use_te_norm: bool = False):
+    def __init__(
+        self, tp_size: int = 1, cp_size: int = 1, use_te_norm: bool = False, tp_group=None
+    ):
         config = TransformerConfig(
             tensor_model_parallel_size=tp_size,
             context_parallel_size=cp_size,
@@ -104,6 +106,7 @@ def __init__(self, tp_size: int = 1, cp_size: int = 1, use_te_norm: bool = False
             gather_output=False,
             skip_bias_add=True,
             is_expert=False,
+            tp_group=tp_group,
         )
         self.activation = nn.ReLU()
         if use_te_norm:
@@ -118,6 +121,7 @@ def __init__(self, tp_size: int = 1, cp_size: int = 1, use_te_norm: bool = False
             skip_bias_add=True,
             input_is_parallel=True,
             is_expert=False,
+            tp_group=tp_group,
         )
 
     def forward(self, x):

tests/_test_utils/torch_quantization/quantize_common.py

Lines changed: 59 additions & 84 deletions
@@ -119,21 +119,20 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
     mto.restore_from_modelopt_state(model_ref, state_dict)
 
 
-def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
-    quantizer_attr = getattr(quantizer, attr).clone()
-    print("quantizer.attr before reduce", getattr(quantizer, attr))
-    dist.all_reduce(quantizer_attr, op=op, group=group)
-    print("quantizer.attr after reduce", getattr(quantizer, attr))
-    print("quantizer_attr after reduce", quantizer_attr)
+def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[]):
+    for group in groups:
+        if group is not None:
+            quantizer_attr = getattr(quantizer, attr).clone()
+            dist.all_reduce(quantizer_attr, op=op, group=group)
     assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
 
 
 original_awq_lite = model_calib_module.awq_lite
 
 
-def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True):
+def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs):
     """Function to mock awq_lite function to always use debug=True for testing"""
-    return original_awq_lite(model, forward_loop, alpha_step, debug=True)
+    return original_awq_lite(model, forward_loop, alpha_step, debug=True, **kwargs)
 
 
 @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
@@ -151,125 +150,101 @@ def forward_loop(model):
 
     if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]:
         # Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks
-        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group)
+        _distributed_attr_check(
+            model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[tp_group]
+        )
         # Lets check the row parallel weight amax; it should be the same across all tp ranks
-        _reduce_quantizer_attr(
-            model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group
+        _distributed_attr_check(
+            model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[tp_group]
        )
 
     if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
         # Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks
         input_quantizer = model.fc1.input_quantizer
-        _reduce_quantizer_attr(
-            input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, group=tp_group
+        _distributed_attr_check(
+            input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, groups=[tp_group]
         )
 
     if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
         # Check activation scale for AWQ lite
-        _reduce_quantizer_attr(
+        _distributed_attr_check(
             model.fc1.awq_lite,
             "act_scale",
             dist.ReduceOp.AVG,
-            group=tp_group,
+            groups=[tp_group],
         )
 
     dist.destroy_process_group()
 
 
 @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
-def dp_cp_parallel_test_helper(model, config, group, mock_awq_lite):
-    calib_data = model.get_dummy_input().cuda()
-
-    def forward_loop(model):
-        model(calib_data)
-
-    model = mtq.quantize(model, config, forward_loop)
-
-    # Sanity check
-    forward_loop(model)
-
-    # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-
-    # Weight quantizer amax
-    if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
-        for quantizer in model.fc1.weight_quantizer:
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group)
-    else:
-        _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-    if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
-        for quantizer in model.fc2.weight_quantizer:
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group)
-    else:
-        _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group)
-
-    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
-        # Check act scale
-        _reduce_quantizer_attr(
-            model.fc1.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
-            group=group,
-        )
-        _reduce_quantizer_attr(
-            model.fc2.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
-            group=group,
-        )
-
-
-@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
-def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite):
-    # Calib data should be same across each DP rank
+def data_tensor_context_parallel_test_helper(
+    model, config, mock_awq_lite, dp_group=None, tp_group=None
+):
+    # Calib data should be different across each DP rank
     dp_rank = dist.get_rank(group=dp_group)
     calib_data = model.get_dummy_input(seed=dp_rank).cuda()
 
+    if tp_group is not None:
+        # The input to first layer, the column parallel should be the same across all tp ranks
+        dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)
+
     def forward_loop(model):
         model(calib_data)
 
     model = mtq.quantize(model, config, forward_loop)
 
-    def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
-        quantizer_attr = getattr(quantizer, attr).clone()
-
-        # Perform all-reduce operations
-        dist.all_reduce(quantizer_attr, op=op, group=tp_group)
-
-        dist.all_reduce(quantizer_attr, op=op, group=dp_group)
-
-        assert torch.allclose(quantizer_attr, getattr(quantizer, attr)), getattr(quantizer, attr)
-
     # Input quantizer amax
     if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX)
-        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX)
+        _distributed_attr_check(
+            model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+        )
+        _distributed_attr_check(
+            model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+        )
 
     # Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks
     # Channel-wise (INT8) only expects same amax across row parallel ranks
     # Block-wise quantization does not expect same amax across row and column parallel ranks
     if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]:
         if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
             for quantizer in model.fc1.weight_quantizer:
-                _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
+                _distributed_attr_check(
+                    quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+                )
         else:
-            _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX)
-
-    if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG]:
+            _distributed_attr_check(
+                model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+            )
+
+    if config in [
+        mtq.FP8_DEFAULT_CFG,
+        mtq.NVFP4_DEFAULT_CFG,
+        mtq.INT8_DEFAULT_CFG,
+        mtq.INT8_SMOOTHQUANT_CFG,
+    ]:
         if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
             for quantizer in model.fc2.weight_quantizer:
-                _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
+                _distributed_attr_check(
+                    quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+                )
         else:
-            _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX)
+            _distributed_attr_check(
+                model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+            )
+
+    # Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks
+    # It is different across DP/CP ranks since the input is different
+    if tp_group and config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
+        input_quantizer = model.fc1.input_quantizer
+        _distributed_attr_check(
+            input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+        )
 
     # Check act scale
     if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
-        _reduce_quantizer_attr(
-            model.fc1.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
+        _distributed_attr_check(
+            model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, groups=[dp_group, tp_group]
         )
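The rewritten helper checks that a quantizer attribute is already synchronized across each process group: it all-reduces a clone of the attribute and asserts the result still matches the local value, since MAX (or AVG) over identical tensors is a no-op. A rough standalone sketch of that idea (hypothetical name, assuming an initialized process group; not the test utility itself):

    import torch
    import torch.distributed as dist

    def assert_synced_across_group(value: torch.Tensor, group=None, op=dist.ReduceOp.MAX):
        """Assert that `value` is identical (up to allclose tolerance) on every rank of `group`."""
        reduced = value.clone()
        dist.all_reduce(reduced, op=op, group=group)
        # If every rank held the same tensor, the reduction leaves it unchanged.
        assert torch.allclose(reduced, value), "attribute differs across ranks"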

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 11 additions & 9 deletions
@@ -32,8 +32,6 @@
 from _test_utils.torch_quantization.quantize_common import (
     auto_quantize_helper,
     data_tensor_context_parallel_test_helper,
-    dp_cp_parallel_test_helper,
-    tensor_parallel_test_helper,
 )
 from packaging.version import Version
 
@@ -97,9 +95,10 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1):
 # 1. Tensor Parallel Test
 def _test_tensor_parallel_helper(config, rank, size):
     initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED)
-    model = MegatronModel(tp_size=size).cuda()
+    tp_group = get_tensor_model_parallel_group()
+    model = MegatronModel(tp_size=size, tp_group=tp_group).cuda()
 
-    tensor_parallel_test_helper(model, config, get_tensor_model_parallel_group())
+    data_tensor_context_parallel_test_helper(model, config, tp_group=tp_group)
 
 
 @pytest.mark.parametrize(
@@ -125,7 +124,7 @@ def _test_data_parallel_helper(config, rank, size):
     initialize_for_megatron(seed=SEED + rank)  # modify seed so data is different across ranks
     model = MegatronModel().cuda()
 
-    dp_cp_parallel_test_helper(model, config, get_data_parallel_group())
+    data_tensor_context_parallel_test_helper(model, config, dp_group=get_data_parallel_group())
 
 
 @pytest.mark.parametrize(
@@ -151,7 +150,9 @@ def _test_context_parallel_helper(config, rank, size):
     )  # modify seed so data is different across ranks
     model = MegatronModel(cp_size=size).cuda()
 
-    dp_cp_parallel_test_helper(model, config, get_data_parallel_group(with_context_parallel=True))
+    data_tensor_context_parallel_test_helper(
+        model, config, dp_group=get_data_parallel_group(with_context_parallel=True)
+    )
 
 
 @pytest.mark.parametrize(
@@ -175,13 +176,14 @@ def test_context_parallel(need_2_gpus, config):
 # 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs)
 def _test_data_tensor_context_parallel_helper(config, rank, size):
     initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED + rank)
-    model = MegatronModel(tp_size=2, cp_size=2).cuda()
+    tp_group = get_tensor_model_parallel_group()
+    model = MegatronModel(tp_size=2, cp_size=2, tp_group=tp_group).cuda()
 
     data_tensor_context_parallel_test_helper(
         model,
         config,
-        get_data_parallel_group(with_context_parallel=True),
-        get_tensor_model_parallel_group(),
+        dp_group=get_data_parallel_group(with_context_parallel=True),
+        tp_group=tp_group,
     )
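With the DP/CP and TP helpers folded into one entry point, each test now passes only the groups it exercises; the calling pattern, consolidated from the hunks above (group getters from Megatron core, config being whichever quantization config the test parametrizes):

    # TP only
    data_tensor_context_parallel_test_helper(model, config, tp_group=tp_group)

    # DP/CP only
    data_tensor_context_parallel_test_helper(
        model, config, dp_group=get_data_parallel_group(with_context_parallel=True)
    )

    # DP + TP + CP
    data_tensor_context_parallel_test_helper(
        model,
        config,
        dp_group=get_data_parallel_group(with_context_parallel=True),
        tp_group=tp_group,
    )

The mock_awq_lite argument is supplied by the @patch decorator on the helper itself, so callers never pass it explicitly.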