
Commit 5703968

Copy get_quant_cfg_choices from nemo to export-deploy
Signed-off-by: Charlie Truong <chtruong@nvidia.com>
1 parent f494492 commit 5703968

File tree

1 file changed: +36 −17 lines changed

nemo_export/onnx_llm_exporter.py

Lines changed: 36 additions & 17 deletions
@@ -60,17 +60,6 @@
     trt = MagicMock()
     HAVE_TENSORRT = False
 
-try:
-    from nemo.collections.llm.modelopt.quantization.quant_cfg_choices import (
-        get_quant_cfg_choices,
-    )
-
-    QUANT_CFG_CHOICES = get_quant_cfg_choices()
-
-    HAVE_NEMO = True
-except (ImportError, ModuleNotFoundError):
-    HAVE_NEMO = False
-
 
 @wrapt.decorator
 def noop_decorator(func):
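
The deleted block was the only place this module reached into the nemo package: it imported get_quant_cfg_choices, cached the result in a module-level QUANT_CFG_CHOICES, and recorded availability in HAVE_NEMO. The TensorRT guard visible in the context lines above keeps the same optional-dependency pattern. A minimal sketch of that pattern, with a hypothetical some_optional_lib standing in for the real dependency:

from unittest.mock import MagicMock

# Guarded optional import: the module still imports cleanly when the dependency is
# missing, and feature code checks the HAVE_* flag before relying on it.
try:
    import some_optional_lib  # hypothetical package name, used only for illustration

    HAVE_SOME_OPTIONAL_LIB = True
except (ImportError, ModuleNotFoundError):
    some_optional_lib = MagicMock()  # placeholder so attribute access does not fail at import time
    HAVE_SOME_OPTIONAL_LIB = False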
@@ -491,17 +480,15 @@ def quantize(
             forward_loop (callable): A function that accepts the model as a single parameter
                 and runs sample data through it. This is used for calibration during quantization.
         """
-        if not HAVE_NEMO:
-            raise UnavailableError(MISSING_NEMO_MSG)
-
         if not HAVE_MODELOPT:
             raise UnavailableError(MISSING_MODELOPT_MSG)
 
+        quant_cfg_choices = get_quant_cfg_choices()
         if isinstance(quant_cfg, str):
-            assert quant_cfg in QUANT_CFG_CHOICES, (
-                f"Quantization config {quant_cfg} is not supported. Supported configs: {list(QUANT_CFG_CHOICES)}"
+            assert quant_cfg in quant_cfg_choices, (
+                f"Quantization config {quant_cfg} is not supported. Supported configs: {list(quant_cfg_choices)}"
             )
-            quant_cfg = QUANT_CFG_CHOICES[quant_cfg]
+            quant_cfg = quant_cfg_choices[quant_cfg]
 
         logger.info("Starting quantization...")
         mtq.quantize(self.model, quant_cfg, forward_loop=forward_loop)
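
With the module-level QUANT_CFG_CHOICES cache gone, quantize() now resolves string names to modelopt configs at call time through the local helper. The sketch below restates that resolution path as a standalone function for illustration; quantize_with_named_cfg, the model argument, and the forward_loop callable are placeholders, and the import path for the helper is inferred from the file path above rather than taken from the repo's public API.

import modelopt.torch.quantization as mtq

# Assumed import path, derived from nemo_export/onnx_llm_exporter.py in this diff.
from nemo_export.onnx_llm_exporter import get_quant_cfg_choices


def quantize_with_named_cfg(model, quant_cfg, forward_loop=None):
    """Resolve a short config name (e.g. "fp8") to a modelopt config, then quantize."""
    quant_cfg_choices = get_quant_cfg_choices()
    if isinstance(quant_cfg, str):
        if quant_cfg not in quant_cfg_choices:
            raise ValueError(
                f"Quantization config {quant_cfg} is not supported. "
                f"Supported configs: {list(quant_cfg_choices)}"
            )
        quant_cfg = quant_cfg_choices[quant_cfg]
    # Same call as in the method body above: forward_loop feeds calibration data.
    return mtq.quantize(model, quant_cfg, forward_loop=forward_loop)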
@@ -558,3 +545,35 @@ def get_calib_data_iter(
         for j in range(len(batch)):
             batch[j] = batch[j][:max_sequence_length]
         yield batch
+
+
+def get_quant_cfg_choices() -> Dict[str, Dict[str, Any]]:
+    """
+    Retrieve a dictionary of modelopt quantization configuration choices.
+
+    This function checks for the availability of specific quantization configurations defined in
+    the modelopt.torch.quantization (mtq) module and returns a dictionary mapping short names to
+    their corresponding configurations. The function is intended to work for different modelopt
+    library versions that come with variable configuration choices.
+
+    Returns:
+        dict: A dictionary where keys are short names (e.g., "fp8") and values are the
+            corresponding modelopt quantization configuration objects.
+    """
+    quant_cfg_names = [
+        ("int8", "INT8_DEFAULT_CFG"),
+        ("int8_sq", "INT8_SMOOTHQUANT_CFG"),
+        ("fp8", "FP8_DEFAULT_CFG"),
+        ("block_fp8", "FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG"),
+        ("int4_awq", "INT4_AWQ_CFG"),
+        ("w4a8_awq", "W4A8_AWQ_BETA_CFG"),
+        ("int4", "INT4_BLOCKWISE_WEIGHT_ONLY_CFG"),
+        ("nvfp4", "NVFP4_DEFAULT_CFG"),
+    ]
+
+    quant_cfg_choices = {}
+    for short_name, full_name in quant_cfg_names:
+        if config := getattr(mtq, full_name, None):
+            quant_cfg_choices[short_name] = config
+
+    return quant_cfg_choices
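
The helper degrades gracefully across modelopt releases: getattr(mtq, full_name, None) silently skips any *_CFG constant the installed version does not define, so the returned dict only lists configs that can actually be used. A quick usage sketch (import path inferred from the file path above; printed names depend on the installed modelopt version):

# Assumed module path for the new helper, based on nemo_export/onnx_llm_exporter.py.
from nemo_export.onnx_llm_exporter import get_quant_cfg_choices

choices = get_quant_cfg_choices()
print("Available quantization configs:", sorted(choices))

# Missing entries are expected on older modelopt builds; check before use.
fp8_cfg = choices.get("fp8")  # None if this modelopt release has no FP8_DEFAULT_CFG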
