@@ -13,6 +13,7 @@
 # limitations under the License.
 
 
+import contextlib
 import logging
 import warnings
 from pathlib import Path
@@ -40,12 +41,14 @@
 
 try:
     import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.export_onnx import configure_linear_module_onnx_quantizers
 
     HAVE_MODELOPT = True
except (ImportError, ModuleNotFoundError):
     from unittest.mock import MagicMock
 
     mtq = MagicMock()
+    configure_linear_module_onnx_quantizers = MagicMock()
     HAVE_MODELOPT = False
 
 
@@ -135,6 +138,8 @@ def __init__(
         self.model_output_names = None
         self.onnx_runtime_session = None
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
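+        # Quantization state: set by quantize() below and consulted by _export_to_onnx().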
+        self._is_quantized = False
+        self._quant_cfg_short_name = ""
 
         if self.model_name_or_path is not None:
             if model is not None:
@@ -229,7 +234,11 @@ def _export_to_onnx(
 
         Path(self.onnx_model_dir).mkdir(parents=True, exist_ok=True)
 
-        with torch.autocast(device_type=get_model_device_type(self.model), dtype=export_dtype):
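+        # NVFP4/MXFP8 models need modelopt's linear-module ONNX quantizers configured
+        # for the duration of the export; every other config exports unchanged.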
+        quantizer_context = contextlib.nullcontext()
+        if self._is_quantized and self._quant_cfg_short_name in ("nvfp4", "mxfp8"):
+            quantizer_context = configure_linear_module_onnx_quantizers(self.model)
+
+        with quantizer_context, torch.autocast(device_type=get_model_device_type(self.model), dtype=export_dtype):
             torch.onnx.export(
                 model=self.model,
                 args=(example_inputs,),
@@ -239,6 +248,7 @@ def _export_to_onnx(
                 dynamic_axes={**dynamic_axes_input, **dynamic_axes_output},
                 verbose=verbose,
                 opset_version=opset,
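+                # Quantized models stay on the legacy TorchScript exporter; presumably the
+                # dynamo-based exporter cannot yet handle modelopt's quantizer modules.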
+                dynamo=not self._is_quantized,
             )
         logger.info(f"Successfully exported PyTorch model to ONNX model {self.onnx_model_path}")
 
@@ -484,13 +494,15 @@ def quantize(
 
         quant_cfg_choices = get_quant_cfg_choices()
         if isinstance(quant_cfg, str):
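+            # Remember the short config name (e.g. "nvfp4") so export can special-case it.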
+            self._quant_cfg_short_name = quant_cfg
             assert quant_cfg in quant_cfg_choices, (
                 f"Quantization config {quant_cfg} is not supported. Supported configs: {list(quant_cfg_choices)}"
             )
             quant_cfg = quant_cfg_choices[quant_cfg]
 
         logger.info("Starting quantization...")
         mtq.quantize(self.model, quant_cfg, forward_loop=forward_loop)
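+        # Flag the model as quantized so _export_to_onnx takes the quantizer-aware path.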
+        self._is_quantized = True
         logger.info("Quantization is completed.")
 
     @property