|
60 | 60 | trt = MagicMock() |
61 | 61 | HAVE_TENSORRT = False |
62 | 62 |
|
63 | | -try: |
64 | | - from nemo.collections.llm.modelopt.quantization.quant_cfg_choices import ( |
65 | | - get_quant_cfg_choices, |
66 | | - ) |
67 | | - |
68 | | - QUANT_CFG_CHOICES = get_quant_cfg_choices() |
69 | | - |
70 | | - HAVE_NEMO = True |
71 | | -except (ImportError, ModuleNotFoundError): |
72 | | - HAVE_NEMO = False |
73 | | - |
74 | 63 |
|
75 | 64 | @wrapt.decorator |
76 | 65 | def noop_decorator(func): |
@@ -491,17 +480,15 @@ def quantize( |
491 | 480 | forward_loop (callable): A function that accepts the model as a single parameter |
492 | 481 | and runs sample data through it. This is used for calibration during quantization. |
493 | 482 | """ |
494 | | - if not HAVE_NEMO: |
495 | | - raise UnavailableError(MISSING_NEMO_MSG) |
496 | | - |
497 | 483 | if not HAVE_MODELOPT: |
498 | 484 | raise UnavailableError(MISSING_MODELOPT_MSG) |
499 | 485 |
|
| 486 | + quant_cfg_choices = get_quant_cfg_choices() |
500 | 487 | if isinstance(quant_cfg, str): |
501 | | - assert quant_cfg in QUANT_CFG_CHOICES, ( |
502 | | - f"Quantization config {quant_cfg} is not supported. Supported configs: {list(QUANT_CFG_CHOICES)}" |
| 488 | + assert quant_cfg in quant_cfg_choices, ( |
| 489 | + f"Quantization config {quant_cfg} is not supported. Supported configs: {list(quant_cfg_choices)}" |
503 | 490 | ) |
504 | | - quant_cfg = QUANT_CFG_CHOICES[quant_cfg] |
| 491 | + quant_cfg = quant_cfg_choices[quant_cfg] |
505 | 492 |
|
506 | 493 | logger.info("Starting quantization...") |
507 | 494 | mtq.quantize(self.model, quant_cfg, forward_loop=forward_loop) |
@@ -558,3 +545,35 @@ def get_calib_data_iter( |
558 | 545 | for j in range(len(batch)): |
559 | 546 | batch[j] = batch[j][:max_sequence_length] |
560 | 547 | yield batch |
| 548 | + |
| 549 | + |
| 550 | +def get_quant_cfg_choices() -> Dict[str, Dict[str, Any]]: |
| 551 | + """ |
| 552 | + Retrieve a dictionary of modelopt quantization configuration choices. |
| 553 | +
|
| 554 | + This function checks for the availability of specific quantization configurations defined in |
| 555 | + the modelopt.torch.quantization (mtq) module and returns a dictionary mapping short names to |
| 556 | + their corresponding configurations. The function is intended to work for different modelopt |
| 557 | + library versions that come with variable configuration choices. |
| 558 | +
|
| 559 | + Returns: |
| 560 | + dict: A dictionary where keys are short names (e.g., "fp8") and values are the |
| 561 | + corresponding modelopt quantization configuration objects. |
| 562 | + """ |
| 563 | + quant_cfg_names = [ |
| 564 | + ("int8", "INT8_DEFAULT_CFG"), |
| 565 | + ("int8_sq", "INT8_SMOOTHQUANT_CFG"), |
| 566 | + ("fp8", "FP8_DEFAULT_CFG"), |
| 567 | + ("block_fp8", "FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG"), |
| 568 | + ("int4_awq", "INT4_AWQ_CFG"), |
| 569 | + ("w4a8_awq", "W4A8_AWQ_BETA_CFG"), |
| 570 | + ("int4", "INT4_BLOCKWISE_WEIGHT_ONLY_CFG"), |
| 571 | + ("nvfp4", "NVFP4_DEFAULT_CFG"), |
| 572 | + ] |
| 573 | + |
| 574 | + quant_cfg_choices = {} |
| 575 | + for short_name, full_name in quant_cfg_names: |
| 576 | + if config := getattr(mtq, full_name, None): |
| 577 | + quant_cfg_choices[short_name] = config |
| 578 | + |
| 579 | + return quant_cfg_choices |
0 commit comments