|
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | 16 | import argparse |
17 | | -import copy |
18 | 17 | import random |
19 | 18 | import time |
20 | 19 | import warnings |
|
25 | 24 | from accelerate.hooks import remove_hook_from_module |
26 | 25 | from example_utils import (
27 | 26 |     apply_kv_cache_quant,
   | 27 | +    build_quant_cfg,
28 | 28 |     copy_custom_model_files,
29 | 29 |     get_model,
30 | 30 |     get_processor,
@@ -448,47 +448,15 @@ def main(args): |
448 | 448 |         include_labels=args.auto_quantize_bits is not None,
449 | 449 |     )
450 | 450 |
|
451 | | -    quant_cfg = {}
452 | | -    if not args.auto_quantize_bits:
453 | | -        assert args.qformat in QUANT_CFG_CHOICES, (
454 | | -            f"Unsupported quantization format: {args.qformat} with {args.kv_cache_qformat} KV cache"
455 | | -        )
456 | | -
457 | | -        quant_cfg = QUANT_CFG_CHOICES[args.qformat]
458 | | -
459 | | -        if "awq" in args.qformat:
460 | | -            quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[args.qformat])
461 | | -            weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
462 | | -            if isinstance(weight_quantizer, list):
463 | | -                weight_quantizer = weight_quantizer[0]
464 | | -            # If awq_block_size argument is provided, update weight_quantizer
465 | | -            if args.awq_block_size:
466 | | -                weight_quantizer["block_sizes"][-1] = args.awq_block_size
467 | | -
468 | | -            # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
469 | | -            if args.qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
470 | | -                quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}
471 | | -
472 | | -        enable_quant_kv_cache = args.kv_cache_qformat != "none"
473 | | -        print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
474 | | -
475 | | -        # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
476 | | -        if enable_quant_kv_cache:
477 | | -            quant_cfg = apply_kv_cache_quant(
478 | | -                quant_cfg,
479 | | -                getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"],
480 | | -            )
481 | | -
482 | | -        # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
483 | | -        if model_type == "gemma" and "int8_sq" in args.qformat:
484 | | -            quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
485 | | -
486 | | -        if model_type == "phi4mm":
487 | | -            # Only quantize the language model
488 | | -            quant_cfg["quant_cfg"]["*speech*"] = {"enable": False}
489 | | -            quant_cfg["quant_cfg"]["*audio*"] = {"enable": False}
490 | | -            quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
491 | | -            quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
   | 451 | +    quant_cfg = build_quant_cfg(
   | 452 | +        args.qformat,
   | 453 | +        args.kv_cache_qformat,
   | 454 | +        args.awq_block_size,
   | 455 | +        args.auto_quantize_bits,
   | 456 | +        model_type,
   | 457 | +        QUANT_CFG_CHOICES,
   | 458 | +        KV_QUANT_CFG_CHOICES,
   | 459 | +    )
492 | 460 |
|
493 | 461 |     if not model_is_already_quantized or calibration_only:
494 | 462 |         # Only run single sample for preview
|
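The body of `build_quant_cfg` is not part of this hunk. Below is a rough sketch of what the factored-out helper in `example_utils.py` presumably looks like, reconstructed from the removed inline logic above and the positional arguments at the new call site; the exact signature, docstring, and the placement of the `phi4mm` override are assumptions, not the actual implementation.

```python
import copy

import modelopt.torch.quantization as mtq

# Inside example_utils.py this helper sits next to apply_kv_cache_quant, so no import
# is needed there; it is imported here only to keep this sketch self-contained.
from example_utils import apply_kv_cache_quant


def build_quant_cfg(
    qformat,
    kv_cache_qformat,
    awq_block_size,
    auto_quantize_bits,
    model_type,
    quant_cfg_choices,
    kv_quant_cfg_choices,
):
    """Assemble the quantization config that main() previously built inline."""
    quant_cfg = {}
    if auto_quantize_bits:
        # Auto-quantize searches per-layer formats itself, so leave the config empty.
        return quant_cfg

    assert qformat in quant_cfg_choices, (
        f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache"
    )
    quant_cfg = quant_cfg_choices[qformat]

    if "awq" in qformat:
        # Deep-copy so per-model tweaks don't mutate the shared preset.
        quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
        weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
        if isinstance(weight_quantizer, list):
            weight_quantizer = weight_quantizer[0]
        # If an awq_block_size argument is provided, update the weight quantizer.
        if awq_block_size:
            weight_quantizer["block_sizes"][-1] = awq_block_size
        # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models.
        if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
            quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}

    enable_quant_kv_cache = kv_cache_qformat != "none"
    print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
    if enable_quant_kv_cache:
        quant_cfg = apply_kv_cache_quant(
            quant_cfg,
            getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
        )

    # Gemma 7B has an accuracy regression with SmoothQuant alpha 1; use 0.5 instead.
    if model_type == "gemma" and "int8_sq" in qformat:
        quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}

    if model_type == "phi4mm":
        # Only quantize the language model; disable the speech/audio/image/vision towers.
        for pattern in ("*speech*", "*audio*", "*image*", "*vision*"):
            quant_cfg["quant_cfg"][pattern] = {"enable": False}

    return quant_cfg
```

Assuming the helper mirrors the removed code this closely, behavior should be unchanged; the refactor just moves the per-format and per-model special cases behind a single call so that `main()` no longer needs `copy` or the inline branching.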