|
25 | 25 | import modelopt.torch.quantization.triton as triton_kernel |
26 | 26 |
|
27 | 27 | from .config import QuantizerAttributeConfig |
28 | | -from .export_onnx import export_fp4, export_fp8, export_int8, export_mxfp8 |
29 | 28 | from .extensions import get_cuda_ext, get_cuda_ext_fp8, get_cuda_ext_mx |
30 | 29 |
|
31 | 30 | mx_format_map = { |
@@ -325,6 +324,8 @@ def symbolic( |
325 | 324 | trt_high_precision_dtype=None, |
326 | 325 | ): |
327 | 326 | """ONNX symbolic function.""" |
| 327 | + from .export_onnx import export_int8 |
| 328 | + |
328 | 329 | return export_int8( |
329 | 330 | g, inputs, amax, num_bits, unsigned, narrow_range, trt_high_precision_dtype |
330 | 331 | ) |
@@ -395,6 +396,8 @@ class ScaledE4M3Function(Function): |
395 | 396 | @symbolic_helper.parse_args("v", "t", "t", "i", "i", "s") |
396 | 397 | def symbolic(g, inputs, amax=None, bias=None, E=4, M=3, trt_high_precision_dtype=None): # noqa: N803 |
397 | 398 | """ONNX symbolic function.""" |
| 399 | + from .export_onnx import export_fp8 |
| 400 | + |
398 | 401 | return export_fp8(g, inputs, amax, trt_high_precision_dtype) |
399 | 402 |
|
400 | 403 | @staticmethod |
@@ -475,6 +478,8 @@ def symbolic( |
475 | 478 | onnx_quantizer_type="dynamic", |
476 | 479 | ): |
477 | 480 | """ONNX symbolic function.""" |
| 481 | + from .export_onnx import export_fp4, export_mxfp8 |
| 482 | + |
478 | 483 | if num_bits == (2, 1) and scale_bits == (4, 3): |
479 | 484 | return export_fp4( |
480 | 485 | g, |
@@ -643,6 +648,8 @@ def symbolic( |
643 | 648 | trt_high_precision_dtype=None, |
644 | 649 | ): |
645 | 650 | """ONNX symbolic function.""" |
| 651 | + from .export_onnx import export_int8 |
| 652 | + |
646 | 653 | return export_int8( |
647 | 654 | g, inputs, amax, num_bits, unsigned, narrow_range, trt_high_precision_dtype |
648 | 655 | ) |
|
0 commit comments