Commit c119095

Fix onnx quantize
Signed-off-by: Charlie Truong <chtruong@nvidia.com>
1 parent c8d23e1 commit c119095

File tree

1 file changed: +13 -1 lines changed

nemo_export/onnx_llm_exporter.py

Lines changed: 13 additions & 1 deletion
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 
+import contextlib
 import logging
 import warnings
 from pathlib import Path
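The new contextlib import exists for contextlib.nullcontext(), which serves further down as the no-op default when no quantizer context is needed. A minimal, self-contained sketch of that idiom (the use_quantizer flag and file name are illustrative, not from this patch):

import contextlib

# nullcontext() is a context manager that does nothing, so it can stand in
# when a block is only sometimes wrapped in a real context manager.
use_quantizer = False  # illustrative flag
ctx = contextlib.nullcontext()
if use_quantizer:
    ctx = open("calibration.log")  # any real context manager would do

with ctx:
    print("the body runs the same either way")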
@@ -40,12 +41,14 @@
 
 try:
     import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.export_onnx import configure_linear_module_onnx_quantizers
 
     HAVE_MODELOPT = True
 except (ImportError, ModuleNotFoundError):
     from unittest.mock import MagicMock
 
     mtq = MagicMock()
+    configure_linear_module_onnx_quantizers = MagicMock()
     HAVE_MODELOPT = False
 
 
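The MagicMock placeholders only keep this module importable when modelopt is absent; real use still has to be gated on HAVE_MODELOPT. A sketch of the call-time guard such code is typically paired with (require_modelopt and its message are hypothetical, not part of this file):

from unittest.mock import MagicMock

try:
    import modelopt.torch.quantization as mtq
    HAVE_MODELOPT = True
except (ImportError, ModuleNotFoundError):
    mtq = MagicMock()  # import-time placeholder; never meant to be called
    HAVE_MODELOPT = False

def require_modelopt():
    # Hypothetical helper: fail loudly at call time rather than letting a
    # MagicMock silently swallow a quantization call.
    if not HAVE_MODELOPT:
        raise RuntimeError("nvidia-modelopt is required for quantization")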

@@ -135,6 +138,8 @@ def __init__(
         self.model_output_names = None
         self.onnx_runtime_session = None
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self._is_quantized = False
+        self._quant_cfg_short_name = ""
 
         if self.model_name_or_path is not None:
             if model is not None:
@@ -229,7 +234,11 @@ def _export_to_onnx(
 
         Path(self.onnx_model_dir).mkdir(parents=True, exist_ok=True)
 
-        with torch.autocast(device_type=get_model_device_type(self.model), dtype=export_dtype):
+        quantizer_context = contextlib.nullcontext()
+        if self._is_quantized and (self._quant_cfg_short_name == "nvfp4" or self._quant_cfg_short_name == "mxfp8"):
+            quantizer_context = configure_linear_module_onnx_quantizers(self.model)
+
+        with quantizer_context, torch.autocast(device_type=get_model_device_type(self.model), dtype=export_dtype):
             torch.onnx.export(
                 model=self.model,
                 args=(example_inputs,),
@@ -239,6 +248,7 @@ def _export_to_onnx(
                 dynamic_axes={**dynamic_axes_input, **dynamic_axes_output},
                 verbose=verbose,
                 opset_version=opset,
+                dynamo=not self._is_quantized,
             )
         logger.info(f"Successfully exported PyTorch model to ONNX model {self.onnx_model_path}")
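Taken together, the two _export_to_onnx hunks mean a quantized nvfp4/mxfp8 model is exported inside modelopt's quantizer context and via the legacy TorchScript exporter (dynamo=False), while everything else keeps the prior behavior. A runnable toy version of that control flow, assuming a PyTorch recent enough for torch.onnx.export to accept dynamo= (TinyModel and the flag values stand in for the exporter's state):

import contextlib

import torch


class TinyModel(torch.nn.Module):
    # Stand-in for self.model; the real exporter wraps an LLM.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


model = TinyModel()
is_quantized = False       # would be self._is_quantized
quant_cfg_short_name = ""  # would be self._quant_cfg_short_name

# Same selection logic as the patch, written as a membership test.
quantizer_context = contextlib.nullcontext()
if is_quantized and quant_cfg_short_name in ("nvfp4", "mxfp8"):
    # Requires modelopt; the branch is not taken in this toy run.
    quantizer_context = configure_linear_module_onnx_quantizers(model)

with quantizer_context:
    torch.onnx.export(
        model,
        (torch.randn(1, 4),),
        "tiny.onnx",
        dynamo=not is_quantized,  # quantized models use the TorchScript exporter
    )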

@@ -484,13 +494,15 @@ def quantize(
 
         quant_cfg_choices = get_quant_cfg_choices()
         if isinstance(quant_cfg, str):
+            self._quant_cfg_short_name = quant_cfg
             assert quant_cfg in quant_cfg_choices, (
                 f"Quantization config {quant_cfg} is not supported. Supported configs: {list(quant_cfg_choices)}"
             )
             quant_cfg = quant_cfg_choices[quant_cfg]
 
         logger.info("Starting quantization...")
         mtq.quantize(self.model, quant_cfg, forward_loop=forward_loop)
+        self._is_quantized = True
         logger.info("Quantization is completed.")
 
     @property
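With the bookkeeping added to quantize(), a caller's existing two-step flow picks up the new export path automatically. A hypothetical usage sketch (the exporter class name, constructor argument, and export call are illustrative; only quantize(quant_cfg=..., forward_loop=...) appears in this diff):

# Hypothetical usage; names outside quantize() are illustrative.
exporter = OnnxLLMExporter(model_name_or_path="/path/to/model")
exporter.quantize(quant_cfg="nvfp4", forward_loop=calibration_loop)
# quantize() now records _is_quantized = True and _quant_cfg_short_name = "nvfp4",
# so the subsequent export wraps torch.onnx.export in the nvfp4 quantizer
# context and passes dynamo=False.
exporter.export(...)  # illustrative; triggers _export_to_onnx internally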
