Commit c6c9905

# [OMNIML-2244] Create the nvfp4 quant exporter (NVIDIA#636)
## What does this PR do?

**Type of change:** New feature

**Overview:**
- Implemented the NVFP4QuantExporter
- Deprecated fp4qdq_to_2dq
- Updated tests

## Usage

```bash
python torch_quant_to_onnx.py --quantize_mode=nvfp4 \
    --onnx_save_path=vit_base_patch16_224.nvfp4.onnx \
    --calibration_data_size 64 \
    --batch_size 128
```

## Testing

```bash
python evaluate.py --onnx_path=vit_base_patch16_224.nvfp4.onnx \
    --model_name=vit_base_patch16_224 \
    --results_path=./results.txt \
    --batch_size 128
```

Results:

```
The top1 accuracy of the model is 84.39%
The top5 accuracy of the model is 97.312%
Inference latency of the model is 7.22412 ms
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: No (deprecates fp4qdq_to_2dq)
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

Signed-off-by: ajrasane <[email protected]>
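At the API level the change is a one-line swap for downstream callers. A minimal sketch, assuming standard `onnx.load`/`onnx.save` I/O around the new entry point; the file name is borrowed from the usage example above:

```python
import onnx

from modelopt.onnx.export import NVFP4QuantExporter

# Load an ONNX model that already carries NVFP4 Q/DQ nodes
# (e.g. one produced by torch_quant_to_onnx.py --quantize_mode=nvfp4).
onnx_model = onnx.load("vit_base_patch16_224.nvfp4.onnx")

# New entry point; replaces the deprecated fp4qdq_to_2dq(onnx_model).
onnx_model = NVFP4QuantExporter.process_model(onnx_model)

onnx.save(onnx_model, "vit_base_patch16_224.nvfp4.onnx")
```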
1 parent 097037d commit c6c9905

File tree

7 files changed: +394, -287 lines

examples/diffusers/quantization/onnx_utils/export.py

Lines changed: 2 additions & 2 deletions

@@ -47,7 +47,7 @@
 from diffusers.models.unets import UNet2DConditionModel
 from torch.onnx import export as onnx_export
 
-from modelopt.onnx.quantization.qdq_utils import fp4qdq_to_2dq
+from modelopt.onnx.export import NVFP4QuantExporter
 from modelopt.torch.quantization.export_onnx import configure_linear_module_onnx_quantizers
 from modelopt.torch.utils import torch_to
 
@@ -547,6 +547,6 @@ def modelopt_export_sd(backbone, onnx_dir, model_name, precision):
     else:
         flux_convert_rope_weight_type(onnx_model)
     if precision == "fp4":
-        onnx_model = fp4qdq_to_2dq(onnx_model)
+        onnx_model = NVFP4QuantExporter.process_model(onnx_model)
     save_onnx(onnx_model, q_output)
     shutil.rmtree(tmp_subfolder, ignore_errors=True)
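Since the change is not backward compatible, callers that cannot migrate immediately could keep the old name alive locally. A hypothetical shim, not part of this PR; it only forwards to the confirmed `NVFP4QuantExporter.process_model` entry point:

```python
import warnings

from modelopt.onnx.export import NVFP4QuantExporter


def fp4qdq_to_2dq(onnx_model, verbose=False):
    """Hypothetical compatibility shim for code still using the removed helper."""
    warnings.warn(
        "fp4qdq_to_2dq is deprecated; use NVFP4QuantExporter.process_model",
        DeprecationWarning,
        stacklevel=2,
    )
    # The new entry point takes no verbose flag, so the argument is ignored.
    return NVFP4QuantExporter.process_model(onnx_model)
```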

examples/onnx_ptq/llm_export.py

Lines changed: 2 additions & 3 deletions

@@ -30,15 +30,14 @@
 from transformers import AutoConfig, AutoTokenizer
 
 import modelopt
-from modelopt.onnx.export import INT4QuantExporter
+from modelopt.onnx.export import INT4QuantExporter, NVFP4QuantExporter
 from modelopt.onnx.llm_export_utils.export_utils import (
     ModelLoader,
     WrapperModelForCausalLM,
     llm_to_onnx,
 )
 from modelopt.onnx.llm_export_utils.quantization_utils import quantize
 from modelopt.onnx.llm_export_utils.surgeon_utils import fold_fp8_qdq_to_dq
-from modelopt.onnx.quantization.qdq_utils import fp4qdq_to_2dq
 from modelopt.torch.export import export_hf_checkpoint
 from modelopt.torch.quantization.utils import is_quantized_linear
 
@@ -275,7 +274,7 @@ def time_operation(operation_name):
 
     if dtype == "nvfp4":
         with time_operation("quantizing weights to nvfp4"):
-            onnx_model = fp4qdq_to_2dq(onnx_model, verbose=True)
+            onnx_model = NVFP4QuantExporter.process_model(onnx_model)
 
     elif dtype == "int4_awq":
         with time_operation("quantizing weights to int4"):

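Both exporters now come from `modelopt.onnx.export`, so call sites like the branch above can table-dispatch on dtype. A sketch under one assumption: the diff confirms `process_model` only for `NVFP4QuantExporter`, and treating `INT4QuantExporter` symmetrically is a guess. Note that the new NVFP4 call also drops the old `verbose=True` flag.

```python
from modelopt.onnx.export import INT4QuantExporter, NVFP4QuantExporter

# Hypothetical dtype-to-exporter table; only the NVFP4 entry point is
# confirmed by this diff, the INT4 one is assumed symmetrical.
EXPORTERS = {
    "nvfp4": NVFP4QuantExporter,
    "int4_awq": INT4QuantExporter,
}


def export_quantized_weights(onnx_model, dtype: str):
    """Route an in-memory ONNX model to the matching quant exporter."""
    try:
        exporter = EXPORTERS[dtype]
    except KeyError:
        raise ValueError(f"unsupported dtype: {dtype}") from None
    return exporter.process_model(onnx_model)
```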