23 | 23 |
24 | 24 | import modelopt.torch.opt as mto |
25 | 25 | import modelopt.torch.quantization as mtq |
| 26 | +import modelopt.torch.quantization.model_calib as model_calib_module # needed for patching awq_lite |
26 | 27 | from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm |
27 | 28 | from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer |
28 | 29 | from modelopt.torch.quantization.utils import is_quantized_linear |
@@ -127,9 +128,6 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None |
127 | 128 | assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) |
128 | 129 |
129 | 130 |
130 | | -# Store the original function before patching |
131 | | -import modelopt.torch.quantization.model_calib as model_calib_module |
132 | | - |
133 | 131 | original_awq_lite = model_calib_module.awq_lite |
134 | 132 |
135 | 133 |
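The first hunk hoists the `model_calib_module` import to the file header, so `original_awq_lite = model_calib_module.awq_lite` captures the unpatched function once at module level, as the removed comment ("Store the original function before patching") indicates. Below is a minimal, hypothetical sketch of the patch-and-restore pattern this enables; only `modelopt.torch.quantization.model_calib.awq_lite` comes from the diff, the wrapper and helper names are illustrative.

```python
# Hypothetical sketch of a patch/restore pattern built on `original_awq_lite`.
import modelopt.torch.quantization.model_calib as model_calib_module

original_awq_lite = model_calib_module.awq_lite


def _spy_awq_lite(*args, **kwargs):
    # Record that AWQ-lite calibration ran, then delegate to the real function.
    _spy_awq_lite.called = True
    return original_awq_lite(*args, **kwargs)


def with_patched_awq_lite(fn):
    """Run `fn` with awq_lite patched, always restoring the original afterwards."""
    model_calib_module.awq_lite = _spy_awq_lite
    try:
        return fn()
    finally:
        model_calib_module.awq_lite = original_awq_lite
```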
@@ -252,38 +250,32 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): |
252 | 250 |
253 | 251 | # Input quantizer amax |
254 | 252 | if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: |
255 | | - _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) |
256 | | - _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) |
| 253 | + _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX) |
| 254 | + _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX) |
257 | 255 |
258 | 256 | if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): |
259 | 257 | for quantizer in model.fc1.weight_quantizer: |
260 | | - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) |
| 258 | + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) |
261 | 259 | else: |
262 | | - _reduce_quantizer_attr( |
263 | | - model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group |
264 | | - ) |
| 260 | + _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX) |
265 | 261 |
266 | 262 | if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): |
267 | 263 | for quantizer in model.fc2.weight_quantizer: |
268 | | - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) |
| 264 | + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) |
269 | 265 | else: |
270 | | - _reduce_quantizer_attr( |
271 | | - model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group |
272 | | - ) |
| 266 | + _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX) |
273 | 267 |
274 | 268 | # Check act scale |
275 | 269 | if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: |
276 | 270 | _reduce_quantizer_attr( |
277 | 271 | model.fc1.awq_lite, |
278 | 272 | "act_scale", |
279 | 273 | dist.ReduceOp.AVG, |
280 | | - group=tp_group, |
281 | 274 | ) |
282 | 275 | _reduce_quantizer_attr( |
283 | 276 | model.fc2.awq_lite, |
284 | 277 | "act_scale", |
285 | 278 | dist.ReduceOp.AVG, |
286 | | - group=tp_group, |
287 | 279 | ) |
288 | 280 |
289 | 281 |
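The second hunk drops the explicit `group=dp_group` / `group=tp_group` arguments, so `_reduce_quantizer_attr` falls back to its default process group. Based on the signature shown in the hunk headers and the `torch.allclose` assertion above, a minimal sketch of what the helper checks is given below (assuming it simply all-reduces the attribute and compares the result to the local value; the real helper may differ in details).

```python
# Sketch reconstructed from the hunk-header signature and the assertion shown
# above; not the verbatim implementation.
import torch
import torch.distributed as dist


def _reduce_quantizer_attr(quantizer, attr, op=dist.ReduceOp.MAX, group=None):
    # All-reduce the calibrated attribute (e.g. "amax" or "act_scale") across
    # ranks, then assert the local value already equals the reduced result,
    # i.e. calibration left the attribute consistent across the group.
    quantizer_attr = getattr(quantizer, attr).clone()
    dist.all_reduce(quantizer_attr, op=op, group=group)
    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
```

With `dist.ReduceOp.MAX` this verifies synchronized `amax` values, while the `dist.ReduceOp.AVG` calls in the hunk apply the same check to the AWQ-lite `act_scale`.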