 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
+import modelopt.torch.quantization.model_calib as model_calib_module  # needed for patching awq_lite
 from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm
 from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer
 from modelopt.torch.quantization.utils import is_quantized_linear
@@ -127,9 +128,6 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None
     assert torch.allclose(quantizer_attr, getattr(quantizer, attr))


-# Store the original function before patching
-import modelopt.torch.quantization.model_calib as model_calib_module
-
 original_awq_lite = model_calib_module.awq_lite

@@ -252,38 +250,32 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):

     # Input quantizer amax
     if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
-        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group)
-        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group)
+        _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX)
+        _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX)

     if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc1.weight_quantizer:
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=dp_group)
+            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
     else:
-        _reduce_quantizer_attr(
-            model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group
-        )
+        _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX)

     if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
         for quantizer in model.fc2.weight_quantizer:
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=dp_group)
+            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
     else:
-        _reduce_quantizer_attr(
-            model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group
-        )
+        _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX)

     # Check act scale
     if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
         _reduce_quantizer_attr(
             model.fc1.awq_lite,
             "act_scale",
             dist.ReduceOp.AVG,
-            group=tp_group,
         )
         _reduce_quantizer_attr(
             model.fc2.awq_lite,
             "act_scale",
             dist.ReduceOp.AVG,
-            group=tp_group,
         )
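For reference, a minimal sketch of what a helper with the `_reduce_quantizer_attr` signature shown in the hunk headers might do, inferred from the visible `torch.allclose` assert; the actual body in the test file may differ. (The `attr=str` default in the shown signature looks like it was intended as a type annotation; the sketch simply takes `attr` as a positional argument.)

```python
import torch
import torch.distributed as dist


def _reduce_quantizer_attr(quantizer, attr, op=dist.ReduceOp.MAX):
    # Copy the locally calibrated attribute (e.g. "amax" or "act_scale") ...
    quantizer_attr = getattr(quantizer, attr).clone()
    # ... reduce it across the default process group ...
    dist.all_reduce(quantizer_attr, op=op)
    # ... and check that the value was already synchronized across ranks.
    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
```

Dropping the explicit `group=dp_group` / `group=tp_group` arguments, as this hunk does, means the reduction runs over the default process group.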
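The `model_calib_module` import now sits with the other top-level imports, and `original_awq_lite` keeps a handle to the unpatched function. A minimal sketch of the patch-and-restore pattern that handle enables; the wrapper name and the `calls` list below are illustrative, not part of this change.

```python
import modelopt.torch.quantization.model_calib as model_calib_module

# Keep a reference to the unpatched implementation so it can be restored later.
original_awq_lite = model_calib_module.awq_lite
calls = []  # illustrative only: record that the awq_lite calibration path ran


def patched_awq_lite(*args, **kwargs):
    calls.append(True)
    # Delegate to the real implementation so calibration behaves normally.
    return original_awq_lite(*args, **kwargs)


model_calib_module.awq_lite = patched_awq_lite
try:
    ...  # run the quantization / calibration under test
finally:
    # Always restore the original so later tests see the unpatched function.
    model_calib_module.awq_lite = original_awq_lite
```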