Commit 1f7d17e

fix test
Signed-off-by: Jennifer Chen <[email protected]>
1 parent 7cbe5b9 commit 1f7d17e

File tree

2 files changed: +42, -63 lines

tests/_test_utils/torch_quantization/quantize_common.py

Lines changed: 39 additions & 59 deletions
@@ -23,6 +23,7 @@
 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
 from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm
+from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer
 from modelopt.torch.quantization.utils import is_quantized_linear
 from modelopt.torch.utils import torch_to

@@ -150,56 +151,35 @@ def forward_loop(model):
     dist.destroy_process_group()


-def data_parallel_test_helper(model, config, dp_group):
+def dp_cp_parallel_test_helper(model, config, group):
     calib_data = model.get_dummy_input().cuda()

     def forward_loop(model):
         model(calib_data)

     model = mtq.quantize(model, config, forward_loop)

-    # Input quantizer amax
-    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        fc1_amax = model.fc1.input_quantizer.amax.clone()
-        dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group)
-        assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax)
-        fc2_amax = model.fc2.input_quantizer.amax.clone()
-        dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group)
-        assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax)
-
-    # Weight quantizer amax
-    fc1_amax = model.fc1.weight_quantizer.amax.clone()
-    dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group)
-    assert torch.allclose(fc1_amax, model.fc1.weight_quantizer.amax)
-    fc2_amax = model.fc2.weight_quantizer.amax.clone()
-    dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group)
-    assert torch.allclose(fc2_amax, model.fc2.weight_quantizer.amax)
-
-
-def context_parallel_test_helper(model, config, cp_group):
-    calib_data = model.get_dummy_input().cuda()
-
-    def forward_loop(model):
-        model(calib_data)
-
-    model = mtq.quantize(model, config, forward_loop)
+    def reduce_amax(quantizer):
+        amax = quantizer.amax.clone()
+        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
+        assert torch.allclose(amax, quantizer.amax)

     # Input quantizer amax
     if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        fc1_amax = model.fc1.input_quantizer.amax.clone()
-        dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group)
-        assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax)
-        fc2_amax = model.fc2.input_quantizer.amax.clone()
-        dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group)
-        assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax)
+        reduce_amax(model.fc1.input_quantizer)
+        reduce_amax(model.fc2.input_quantizer)

     # Weight quantizer amax
-    fc1_weight_amax = model.fc1.weight_quantizer.amax.clone()
-    dist.all_reduce(fc1_weight_amax, op=dist.ReduceOp.MAX, group=cp_group)
-    assert torch.allclose(fc1_weight_amax, model.fc1.weight_quantizer.amax)
-    fc2_weight_amax = model.fc2.weight_quantizer.amax.clone()
-    dist.all_reduce(fc2_weight_amax, op=dist.ReduceOp.MAX, group=cp_group)
-    assert torch.allclose(fc2_weight_amax, model.fc2.weight_quantizer.amax)
+    if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
+        for quantizer in model.fc1.weight_quantizer:
+            reduce_amax(quantizer)
+    else:
+        reduce_amax(model.fc1.weight_quantizer)
+    if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
+        for quantizer in model.fc2.weight_quantizer:
+            reduce_amax(quantizer)
+    else:
+        reduce_amax(model.fc2.weight_quantizer)


 def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group):
@@ -212,29 +192,29 @@ def forward_loop(model):

     model = mtq.quantize(model, config, forward_loop)

+    def reduce_amax(quantizer):
+        amax = quantizer.amax.clone()
+        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)
+        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group)
+        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group)
+        assert torch.allclose(amax, quantizer.amax)
+
     # Input quantizer amax
     if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        fc1_amax = model.fc1.input_quantizer.amax.clone()
-        dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group)
-        dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group)
-        dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group)
-        assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax)
-        fc2_amax = model.fc2.input_quantizer.amax.clone()
-        dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group)
-        dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group)
-        dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group)
-        assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax)
-
-    fc1_amax = model.fc1.weight_quantizer.amax.clone()
-    dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group)
-    dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group)
-    dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group)
-    assert torch.allclose(fc1_amax, model.fc1.weight_quantizer.amax)
-    fc2_amax = model.fc2.weight_quantizer.amax.clone()
-    dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group)
-    dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group)
-    dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group)
-    assert torch.allclose(fc2_amax, model.fc2.weight_quantizer.amax)
+        reduce_amax(model.fc1.input_quantizer)
+        reduce_amax(model.fc2.input_quantizer)
+
+    if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
+        for quantizer in model.fc1.weight_quantizer:
+            reduce_amax(quantizer)
+    else:
+        reduce_amax(model.fc1.weight_quantizer)
+
+    if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
+        for quantizer in model.fc2.weight_quantizer:
+            reduce_amax(quantizer)
+    else:
+        reduce_amax(model.fc2.weight_quantizer)


 def auto_quantize_helper(model):

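The refactor above replaces the separate data-parallel and context-parallel helpers with a single dp_cp_parallel_test_helper plus a local reduce_amax check: a MAX all-reduce over an amax that is already synchronized across the group is a no-op, so torch.allclose must hold on every rank. The new SequentialQuantizer branches cover configs whose weight quantizer is a sequence of child quantizers. As a purely illustrative simplification, the two per-layer branches could be folded into one loop; the sketch below is not part of the commit, the helper name _reduce_weight_amax is hypothetical, and it relies only on what the diff already shows (SequentialQuantizer is iterable over its children):

from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer


def _reduce_weight_amax(model, reduce_amax):
    # Hypothetical helper (not in the commit): run the reduce_amax check on every
    # weight quantizer of fc1 and fc2, unwrapping a SequentialQuantizer into its
    # child quantizers exactly as the diff does explicitly for each layer.
    for linear in (model.fc1, model.fc2):
        wq = linear.weight_quantizer
        quantizers = list(wq) if isinstance(wq, SequentialQuantizer) else [wq]
        for quantizer in quantizers:
            reduce_amax(quantizer)

Inside dp_cp_parallel_test_helper (and the data/tensor/context variant) this would replace both isinstance blocks with a single _reduce_weight_amax(model, reduce_amax) call.
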
tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 3 additions & 4 deletions
@@ -31,9 +31,8 @@
 from _test_utils.torch_quantization.quant_utils import get_model_size
 from _test_utils.torch_quantization.quantize_common import (
     auto_quantize_helper,
-    context_parallel_test_helper,
-    data_parallel_test_helper,
     data_tensor_context_parallel_test_helper,
+    dp_cp_parallel_test_helper,
     tensor_parallel_test_helper,
 )
 from packaging.version import Version
@@ -128,7 +127,7 @@ def _test_data_parallel_helper(config, rank, size):
     initialize_for_megatron(seed=SEED)
     model = MegatronModel().cuda()

-    data_parallel_test_helper(model, config, get_data_parallel_group())
+    dp_cp_parallel_test_helper(model, config, get_data_parallel_group())


 @pytest.mark.parametrize(
@@ -152,7 +151,7 @@ def _test_context_parallel_helper(config, rank, size):
     initialize_for_megatron(context_parallel_size=size, seed=SEED)
     model = MegatronModel(cp_size=size).cuda()

-    context_parallel_test_helper(model, config, get_context_parallel_group())
+    dp_cp_parallel_test_helper(model, config, get_context_parallel_group())


 @pytest.mark.parametrize(

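With the helpers renamed, the data-parallel and context-parallel Megatron tests now share one code path. Below is a sketch of the two call sites as they read after this change, reconstructed from the hunks above (initialize_for_megatron, MegatronModel, SEED, and the process-group getters all come from the test module's existing imports):

def _test_data_parallel_helper(config, rank, size):
    # Calibrate under data parallelism and check amax sync across the DP group.
    initialize_for_megatron(seed=SEED)
    model = MegatronModel().cuda()
    dp_cp_parallel_test_helper(model, config, get_data_parallel_group())


def _test_context_parallel_helper(config, rank, size):
    # Same helper, different group: check amax sync across the CP group.
    initialize_for_megatron(context_parallel_size=size, seed=SEED)
    model = MegatronModel(cp_size=size).cuda()
    dp_cp_parallel_test_helper(model, config, get_context_parallel_group())
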