Commit 5a572da

move awq test inside megatron tests
Signed-off-by: Jennifer Chen <[email protected]>
1 parent d02365c commit 5a572da

File tree

4 files changed: +86 -71 lines changed


modelopt/torch/quantization/model_calib.py

Lines changed: 1 addition & 0 deletions

@@ -581,6 +581,7 @@ def forward(self, input, *args, **kwargs):
             return out_actual

     for name, module in model.named_modules():
+        print(name, module, module.weight_quantizer.is_enabled)
         if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
             with enable_weight_access_and_writeback(module, model):
                 module.awq_lite = AWQLiteHelper(module, name)
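
The hunk above adds a debug print inside the loop that attaches AWQLiteHelper to enabled quantized linear layers. For a debugging sweep of this kind that does not assume every named submodule exposes a weight_quantizer, a guarded variant could look like the sketch below (illustrative only; dump_weight_quantizer_state is a hypothetical helper, not part of modelopt):

```python
import torch.nn as nn


def dump_weight_quantizer_state(model: nn.Module) -> None:
    """Print which submodules carry a weight_quantizer and whether it is enabled."""
    for name, module in model.named_modules():
        # Fetch defensively: plain nn.Linear, LayerNorm, etc. have no weight_quantizer.
        weight_quantizer = getattr(module, "weight_quantizer", None)
        if weight_quantizer is not None:
            print(name, type(module).__name__, weight_quantizer.is_enabled)
```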

tests/_test_utils/torch_quantization/quantize_common.py

Lines changed: 84 additions & 37 deletions

@@ -117,6 +117,12 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
     mto.restore_from_modelopt_state(model_ref, state_dict)


+def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None):
+    quantizer_attr = getattr(quantizer, attr).clone()
+    dist.all_reduce(quantizer_attr, op=op, group=group)
+    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
+
+
 def tensor_parallel_test_helper(model, config, tp_group):
     # The input to first layer, the column parallel should be the same across all tp ranks
     calib_data = model.get_dummy_input().cuda()
@@ -126,27 +132,39 @@ def forward_loop(model):
         model(calib_data)

     model = mtq.quantize(model, config, forward_loop)
-
     # Sanity check
     forward_loop(model)

     if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]:
         # Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks
-        activation_amax = model.fc2.input_quantizer.amax.clone()
-        dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group)
-        assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax)
+        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group)

         # Lets check the row parallel weight amax; it should be the same across all tp ranks
-        weight_amax = model.fc2.weight_quantizer.amax.clone()
-        dist.all_reduce(weight_amax, op=dist.ReduceOp.MAX, group=tp_group)
-        assert torch.allclose(weight_amax, model.fc2.weight_quantizer.amax)
+        _reduce_quantizer_attr(
+            model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group
+        )

     if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
         # Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks
         input_quantizer = model.fc1.input_quantizer
-        pre_quant_scale = input_quantizer.pre_quant_scale.clone()
-        dist.all_reduce(pre_quant_scale, op=dist.ReduceOp.MAX, group=tp_group)
-        assert torch.allclose(pre_quant_scale, input_quantizer.pre_quant_scale)
+        _reduce_quantizer_attr(
+            input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, group=tp_group
+        )
+
+    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
+        # Check act scale
+        _reduce_quantizer_attr(
+            model.fc1.weight_quantizer.awq_lite.act_scale,
+            "act_scale",
+            dist.ReduceOp.AVG,
+            group=tp_group,
+        )
+        _reduce_quantizer_attr(
+            model.fc2.weight_quantizer.awq_lite.act_scale,
+            "act_scale",
+            dist.ReduceOp.AVG,
+            group=tp_group,
+        )

     dist.destroy_process_group()

@@ -159,27 +177,37 @@ def forward_loop(model):

     model = mtq.quantize(model, config, forward_loop)

-    def reduce_amax(quantizer):
-        amax = quantizer.amax.clone()
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
-        assert torch.allclose(amax, quantizer.amax)
-
     # Input quantizer amax
     if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        reduce_amax(model.fc1.input_quantizer)
-        reduce_amax(model.fc2.input_quantizer)
+        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
+        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)

     # Weight quantizer amax
     if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc1.weight_quantizer:
-            reduce_amax(quantizer)
+            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group)
     else:
-        reduce_amax(model.fc1.weight_quantizer)
+        _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group)
     if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc2.weight_quantizer:
-            reduce_amax(quantizer)
+            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group)
     else:
-        reduce_amax(model.fc2.weight_quantizer)
+        _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group)
+
+    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
+        # Check act scale
+        _reduce_quantizer_attr(
+            model.fc1.weight_quantizer.awq_lite.act_scale,
+            "act_scale",
+            dist.ReduceOp.AVG,
+            group=group,
+        )
+        _reduce_quantizer_attr(
+            model.fc2.weight_quantizer.awq_lite.act_scale,
+            "act_scale",
+            dist.ReduceOp.AVG,
+            group=group,
+        )


 def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group):
@@ -192,33 +220,52 @@ def forward_loop(model):

     model = mtq.quantize(model, config, forward_loop)

-    def reduce_amax(quantizer):
-        amax = quantizer.amax.clone()
-        print("amax before reduce", amax)
-        print("quantizer.amax before reduce", quantizer.amax)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group)
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)
-        print("amax after reduce", amax)
-        print("quantizer.amax after reduce", quantizer.amax)
-        assert torch.allclose(amax, quantizer.amax)
+    def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
+        quantizer_attr = getattr(quantizer, attr).clone()
+        print("quantizer_attr before reduce", quantizer_attr)
+        print("quantizer.attr before reduce", getattr(quantizer, attr))
+        dist.all_reduce(quantizer_attr, op=op, group=dp_group)
+        dist.all_reduce(quantizer_attr, op=op, group=cp_group)
+        dist.all_reduce(quantizer_attr, op=op, group=tp_group)
+        print("quantizer_attr after reduce", quantizer_attr)
+        print("quantizer.attr after reduce", getattr(quantizer, attr))
+        assert torch.allclose(quantizer_attr, getattr(quantizer, attr))

     # Input quantizer amax
     if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        reduce_amax(model.fc1.input_quantizer)
-        reduce_amax(model.fc2.input_quantizer)
+        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group)
+        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group)

     if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc1.weight_quantizer:
-            reduce_amax(quantizer)
+            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=dp_group)
     else:
-        reduce_amax(model.fc1.weight_quantizer)
+        _reduce_quantizer_attr(
+            model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group
+        )

     if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc2.weight_quantizer:
-            reduce_amax(quantizer)
+            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=dp_group)
     else:
-        reduce_amax(model.fc2.weight_quantizer)
+        _reduce_quantizer_attr(
+            model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group
+        )
+
+    # Check act scale
+    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
+        _reduce_quantizer_attr(
+            model.fc1.weight_quantizer.awq_lite.act_scale,
+            "act_scale",
+            dist.ReduceOp.AVG,
+            group=tp_group,
+        )
+        _reduce_quantizer_attr(
+            model.fc2.weight_quantizer.awq_lite.act_scale,
+            "act_scale",
+            dist.ReduceOp.AVG,
+            group=tp_group,
+        )


 def auto_quantize_helper(model):
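
The refactor above replaces the per-helper reduce_amax closures with a shared _reduce_quantizer_attr utility: it all-reduces a copy of a quantizer attribute (amax, pre_quant_scale, or awq_lite act_scale) over the relevant process group and asserts the result still matches the local value, i.e. calibration already produced identical statistics on every rank. Below is a minimal single-process sketch of that check; check_attr_consistent and the SimpleNamespace stand-in are illustrative, not part of the test utilities:

```python
import os
from types import SimpleNamespace

import torch
import torch.distributed as dist


def check_attr_consistent(quantizer, attr, op=dist.ReduceOp.MAX, group=None):
    """All-reduce a copy of `attr` and assert it already equals the local value."""
    value = getattr(quantizer, attr).clone()
    dist.all_reduce(value, op=op, group=group)  # MAX for amax/pre_quant_scale, AVG for act_scale
    assert torch.allclose(value, getattr(quantizer, attr))


if __name__ == "__main__":
    # Single-rank gloo group so all_reduce is callable; with world_size=1 the
    # reduction is a no-op and the assertion trivially holds.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    toy_quantizer = SimpleNamespace(amax=torch.tensor([0.5, 1.0]))
    check_attr_consistent(toy_quantizer, "amax", op=dist.ReduceOp.MAX)

    dist.destroy_process_group()
```

With more than one rank, the same call fails as soon as any rank holds a different value, which is exactly the property the distributed tests assert after mtq.quantize.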

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 1 addition & 1 deletion

@@ -175,7 +175,7 @@ def test_context_parallel(need_2_gpus, config):

 # 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs)
 def _test_data_tensor_context_parallel_helper(config, rank, size):
-    initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED)
+    initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED + rank)
     model = MegatronModel(tp_size=2, cp_size=2).cuda()

     data_tensor_context_parallel_test_helper(
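
The only change here seeds each rank differently (SEED + rank), so the ranks no longer calibrate on identical dummy data; the amax and act_scale agreement asserted by data_tensor_context_parallel_test_helper then has to come from synchronization across the DP/CP/TP groups rather than from identical inputs. A small sketch of the effect (the SEED value and dummy_calib_batch are illustrative, not the test's actual code):

```python
import torch

SEED = 1234  # illustrative constant; the real SEED lives in the test module


def dummy_calib_batch(rank: int, shape=(2, 8)) -> torch.Tensor:
    # Mirrors initialize_for_megatron(..., seed=SEED + rank): each rank gets its own RNG stream.
    torch.manual_seed(SEED + rank)
    return torch.randn(shape)


if __name__ == "__main__":
    # Different ranks now draw different calibration batches.
    assert not torch.equal(dummy_calib_batch(0), dummy_calib_batch(1))
```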

tests/gpu/torch/quantization/test_model_calib.py

Lines changed: 0 additions & 33 deletions
This file was deleted.
