diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index dbf6e27fa..a6b420b92 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -50,6 +50,7 @@ modelopt/torch/utils @NVIDIA/modelopt-torch-utils-codeowners /examples/pruning @NVIDIA/modelopt-torch-nas-prune-codeowners /examples/specdec_bench @NVIDIA/modelopt-torch-speculative-codeowners /examples/speculative_decoding @NVIDIA/modelopt-torch-speculative-codeowners +/examples/torch_onnx @NVIDIA/modelopt-onnx-codeowners /examples/vlm_ptq @NVIDIA/modelopt-examples-vlm-codeowners /examples/vllm_serve @NVIDIA/modelopt-examples-llm_ptq-codeowners /examples/windows @NVIDIA/modelopt-windows-codeowners diff --git a/.github/workflows/example_tests.yml b/.github/workflows/example_tests.yml index d0d23df68..7db96d11e 100644 --- a/.github/workflows/example_tests.yml +++ b/.github/workflows/example_tests.yml @@ -123,7 +123,7 @@ jobs: strategy: fail-fast: false matrix: - example: [diffusers, onnx_ptq] + example: [diffusers, torch_onnx] uses: ./.github/workflows/_example_tests_runner.yml secrets: inherit with: @@ -137,7 +137,7 @@ jobs: strategy: fail-fast: false matrix: - example: [diffusers, onnx_ptq] + example: [diffusers, torch_onnx] uses: ./.github/workflows/_example_tests_runner.yml secrets: inherit with: diff --git a/README.md b/README.md index 21506c4ca..7e956e036 100644 --- a/README.md +++ b/README.md @@ -119,7 +119,7 @@ more fine-grained control on installed dependencies or for alternative docker im | LLM Quantization | [View Support Matrix](./examples/llm_ptq/README.md#support-matrix) | | Diffusers Quantization | [View Support Matrix](./examples/diffusers/README.md#support-matrix) | | VLM Quantization | [View Support Matrix](./examples/vlm_ptq/README.md#support-matrix) | -| ONNX Quantization | [View Support Matrix](./examples/onnx_ptq/README.md#onnx-export-supported-llm-models) | +| ONNX Quantization | [View Support Matrix](./examples/torch_onnx/README.md#onnx-export-supported-llm-models) | | Windows Quantization | [View Support Matrix](./examples/windows/README.md#support-matrix) | | Quantization Aware Training | [View Support Matrix](./examples/llm_qat/README.md#support-matrix) | | Pruning | [View Support Matrix](./examples/pruning/README.md#support-matrix) | diff --git a/docs/source/index.rst b/docs/source/index.rst index 5f08de7bf..f7b4cef4c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -9,6 +9,7 @@ Welcome to Model Optimizer (ModelOpt) documentation! 
getting_started/[0-9]* Quick Start: PTQ - PyTorch Quick Start: PTQ - ONNX + Quick Start: PTQ - PyTorch to ONNX Quick Start: PTQ - Windows Quick Start: QAT Quick Start: Pruning diff --git a/examples/onnx_ptq/README.md b/examples/onnx_ptq/README.md index 8d1416aa1..980a26493 100644 --- a/examples/onnx_ptq/README.md +++ b/examples/onnx_ptq/README.md @@ -12,10 +12,8 @@ Model Optimizer enables highly performant quantization formats including NVFP4, | :------------: | :------------: | :------------: | :------------: | | Pre-Requisites | Required & optional packages to use this technique | [Link](#pre-requisites) | | | Getting Started | Learn how to optimize your models using PTQ to reduce precision and improve inference efficiency | [Link](#getting-started) | [docs](https://nvidia.github.io/Model-Optimizer/guides/_onnx_quantization.html) | -| Support Matrix | View the ONNX export supported LLM models | [Link](#onnx-export-supported-llm-models) | | -| PyTorch to ONNX | Example scripts demonstrating how to quantize with PyTorch and then convert to ONNX | [Link](#torch-quantization-to-onnx-export-example) | | +| PyTorch to ONNX | Example scripts demonstrating how to quantize with PyTorch and then convert to ONNX | [Link](../torch_onnx/) | | | Advanced Features | Examples demonstrating use advanced ONNX quantization features | [Link](#advanced-features) | | -| Pre-Quantized Checkpoints | Ready to deploy Hugging Face pre-quantized checkpoints | [Link](#pre-quantized-checkpoints) | | | Resources | Extra links to relevant resources | [Link](#resources) | | @@ -80,7 +78,7 @@ python image_prep.py \ The model can be quantized as an FP8, INT8 or INT4 model using either the CLI or Python API. For FP8 and INT8 quantization, you have a choice between `max` and `entropy` calibration algorithms. For INT4 quantization, [awq_clip](https://arxiv.org/abs/2306.00978) or [rtn_dq](https://ar5iv.labs.arxiv.org/html/2301.12017) algorithms can be chosen. -> *For NVFP4 and MXFP8 ONNX, see the [PyTorch to ONNX section](#torch-quantization-to-onnx-export-example).* +> *For NVFP4 and MXFP8 ONNX, see the [PyTorch to ONNX example](../torch_onnx/).* > *Minimum opset requirements: int8 (13+), fp8 (21+), int4 (21+). ModelOpt will automatically upgrade lower opset versions to meet these requirements.* @@ -129,58 +127,6 @@ The top5 accuracy of the model is Inference latency of the model is ms ``` -## Torch quantization to ONNX export example - -This example demonstrates how to quantize a [timm](https://github.com/huggingface/pytorch-image-models) vision model for various precision formats followed by export to ONNX. The script leverages the ModelOpt toolkit for both quantization and ONNX export. - -> *Opset 20 is used to export the torch models to ONNX.* - -### What it does - -- Loads a pretrained timm torch model (default: ViT-Base). -- Quantizes the torch model to MXFP8, INT4 or NVFP4 using ModelOpt. -- Exports the quantized model to ONNX. -- Postprocesses the ONNX model to be compatible with TensorRT. -- Saves the final ONNX model. - -### Usage - -```bash -python torch_quant_to_onnx.py \ - --timm_model_name=vit_base_patch16_224 \ - --quantize_mode= \ - --onnx_save_path= -``` - -### Evaluation - -If the input model is of type image classification, use the following script to evaluate it. The script automatically downloads and uses the [ILSVRC/imagenet-1k](https://huggingface.co/datasets/ILSVRC/imagenet-1k) dataset from Hugging Face. This gated repository requires authentication via Hugging Face access token. See for details. 
- -> *Note: TensorRT 10.11 or later is required to evaluate the MXFP8 or NVFP4 ONNX models.* - -```bash -python evaluate.py \ - --onnx_path= \ - --imagenet_path= \ - --engine_precision=stronglyTyped \ - --model_name=vit_base_patch16_224 -``` - -### ONNX Export Supported LLM Models - -| Model | FP16 | INT4 | FP8 | NVFP4 | -| :---: | :---: | :---: | :---: | :---: | -| [Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | ✅ | ✅ | ✅ | ✅ | -| [Llama3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B) | ✅ | ✅ | ✅ | ✅ | -| [Llama3.2-3B](https://huggingface.co/meta-llama/Llama-3.2-3B) | ✅ | ✅ | ✅ | ✅ | -| [Qwen2-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) | ✅ | ✅ | ✅ | ✅ | -| [Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) | ✅ | ✅ | ✅ | ✅ | -| [Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct) | ✅ | ✅ | ✅ | ✅ | -| [Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) | ✅ | ✅ | ✅ | ✅ | -| [Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) | ✅ | ✅ | ✅ | ✅ | -| [Qwen2.5-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct) | ✅ | ✅ | ✅ | ✅ | -| [Qwen2.5-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) | ✅ | ✅ | ✅ | ✅ | - ## Advanced Features ### Per node calibration of ONNX models @@ -273,10 +219,6 @@ trtexec --onnx=/path/to/identity_neural_network.quant.onnx \ --staticPlugins=/path/to/libidentity_conv_iplugin_v2_io_ext.so ``` -## Pre-Quantized Checkpoints - -- Ready-to-deploy checkpoints that can be exported to ONNX format (if supported as per the [Support Matrix](#onnx-export-supported-llm-models)) \[[🤗 Hugging Face - Nvidia Model Optimizer Collection](https://huggingface.co/collections/nvidia/inference-optimized-checkpoints-with-model-optimizer)\] - ## Resources - 📅 [Roadmap](https://github.com/NVIDIA/Model-Optimizer/issues/146) diff --git a/examples/torch_onnx/README.md b/examples/torch_onnx/README.md new file mode 100644 index 000000000..8c28f0729 --- /dev/null +++ b/examples/torch_onnx/README.md @@ -0,0 +1,215 @@ +# Torch Quantization to ONNX Export + +This example demonstrates how to quantize PyTorch models (vision and LLM) followed by export to ONNX format. The scripts leverage the ModelOpt toolkit for both quantization and ONNX export. + +
+ +| **Section** | **Description** | **Link** | +| :------------: | :------------: | :------------: | +| Pre-Requisites | Required packages to use this example | [Link](#pre-requisites) | +| Vision Models | Quantize timm models and export to ONNX | [Link](#vision-models) | +| LLM Export | Export LLMs to quantized ONNX | [Link](#llm-export) | +| Mixed Precision | Auto mode for optimal per-layer quantization | [Link](#mixed-precision-quantization-auto-mode) | +| Support Matrix | View the ONNX export supported LLM models | [Link](#onnx-export-supported-llm-models) | +| Resources | Extra links to relevant resources | [Link](#resources) | + +
+ +## Pre-Requisites + +### Docker + +Please use the TensorRT docker image (e.g., `nvcr.io/nvidia/tensorrt:25.08-py3`) or visit our [installation docs](https://nvidia.github.io/Model-Optimizer/getting_started/2_installation.html) for more information. + +Set the following environment variables inside the TensorRT docker. + +```bash +export CUDNN_LIB_DIR=/usr/lib/x86_64-linux-gnu/ +export LD_LIBRARY_PATH="${CUDNN_LIB_DIR}:${LD_LIBRARY_PATH}" +``` + +### Local Installation + +Install Model Optimizer with `onnx` dependencies using `pip` from [PyPI](https://pypi.org/project/nvidia-modelopt/) and install the requirements for the example: + +```bash +pip install -U "nvidia-modelopt[onnx]" +pip install -r requirements.txt +``` + +For TensorRT Compiler framework workloads: + +Install the latest [TensorRT](https://developer.nvidia.com/tensorrt) from [here](https://developer.nvidia.com/tensorrt/download). + +## Vision Models + +The `torch_quant_to_onnx.py` script quantizes [timm](https://github.com/huggingface/pytorch-image-models) vision models and exports them to ONNX. + +### What it does + +- Loads a pretrained timm torch model (default: ViT-Base). +- Quantizes the torch model to FP8, MXFP8, INT8, NVFP4, or INT4_AWQ using ModelOpt. +- Exports the quantized model to ONNX. +- Postprocesses the ONNX model to be compatible with TensorRT. +- Saves the final ONNX model. + +> *Opset 20 is used to export the torch models to ONNX.* + +### Usage + +```bash +python torch_quant_to_onnx.py \ + --timm_model_name=vit_base_patch16_224 \ + --quantize_mode= \ + --onnx_save_path= +``` + +### Evaluation + +If the input model is of type image classification, use the following script to evaluate it. The script automatically downloads and uses the [ILSVRC/imagenet-1k](https://huggingface.co/datasets/ILSVRC/imagenet-1k) dataset from Hugging Face. This gated repository requires authentication via Hugging Face access token. See for details. + +> *Note: TensorRT 10.11 or later is required to evaluate the MXFP8 or NVFP4 ONNX models.* + +```bash +python ../onnx_ptq/evaluate.py \ + --onnx_path= \ + --imagenet_path= \ + --engine_precision=stronglyTyped \ + --model_name=vit_base_patch16_224 +``` + +## LLM Export + +The `llm_export.py` script exports LLM models to ONNX with optional quantization. + +### What it does + +- Loads a HuggingFace LLM model (local path or model name). +- Optionally quantizes the model to FP8, INT4_AWQ, or NVFP4. +- Exports the model to ONNX format. +- Post-processes the ONNX graph for TensorRT compatibility. 
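+
+Under the hood, the script follows the standard ModelOpt flow: calibrate and insert quantizers with `mtq.quantize`, then export the quantized module to ONNX. The snippet below is an illustrative sketch only; the model name, calibration prompt, and chosen config are placeholders rather than the exact contents of `llm_export.py`.
+
+```python
+import modelopt.torch.quantization as mtq
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "Qwen/Qwen2-0.5B-Instruct"  # placeholder; any supported HF model works
+model = AutoModelForCausalLM.from_pretrained(model_id).cuda().eval()
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+
+def forward_loop(m):
+    # Run a few calibration prompts through the model to collect activation ranges.
+    batch = tokenizer(["Hello, how are you?"], return_tensors="pt").to(m.device)
+    m(**batch)
+
+
+# FP8 shown here; the INT4 AWQ and NVFP4 configs follow the same pattern.
+model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
+
+# llm_export.py then exports the quantized model to ONNX and post-processes the
+# graph for TensorRT (see the CLI usage below).
+```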
+ +### Usage + +```bash +python llm_export.py \ + --hf_model_path= \ + --dtype= \ + --output_dir= +``` + +### Examples + +Export Qwen2 to FP16 ONNX: + +```bash +python llm_export.py \ + --hf_model_path=Qwen/Qwen2-0.5B-Instruct \ + --dtype=fp16 \ + --output_dir=./qwen2_fp16 +``` + +Export Qwen2 to FP8 ONNX with quantization: + +```bash +python llm_export.py \ + --hf_model_path=Qwen/Qwen2-0.5B-Instruct \ + --dtype=fp8 \ + --output_dir=./qwen2_fp8 +``` + +Export to NVFP4 with custom calibration: + +```bash +python llm_export.py \ + --hf_model_path=Qwen/Qwen3-0.6B \ + --dtype=nvfp4 \ + --calib_size=512 \ + --output_dir=./qwen3_nvfp4 +``` + +### Key Parameters + +| Parameter | Description | +| :--- | :--- | +| `--hf_model_path` | HuggingFace model name (e.g., `Qwen/Qwen2-0.5B-Instruct`) or local model path | +| `--dtype` | Export precision: `fp16`, `fp8`, `int4_awq`, or `nvfp4` | +| `--output_dir` | Directory to save the exported ONNX model | +| `--calib_size` | Number of calibration samples for quantization (default: 512) | +| `--lm_head` | Precision of lm_head layer (default: `fp16`) | +| `--save_original` | Save the raw ONNX before post-processing | +| `--trust_remote_code` | Trust remote code when loading from HuggingFace Hub | + +## Mixed Precision Quantization (Auto Mode) + +The `auto` mode enables mixed precision quantization by searching for the optimal quantization format per layer. This approach balances model accuracy and compression by assigning different precision formats (e.g., NVFP4, FP8) to different layers based on their sensitivity. + +### How it works + +1. **Sensitivity Analysis**: Computes per-layer sensitivity scores using gradient-based analysis +2. **Format Search**: Searches across specified quantization formats for each layer +3. **Constraint Optimization**: Finds the optimal format assignment that satisfies the effective bits constraint while minimizing accuracy loss + +### Key Parameters + +| Parameter | Default | Description | +| :--- | :---: | :--- | +| `--effective_bits` | 4.8 | Target average bits per weight across the model. Lower values = more compression but potentially lower accuracy. The search algorithm finds the optimal per-layer format assignment that meets this constraint while minimizing accuracy loss. For example, 4.8 means an average of 4.8 bits per weight (mix of FP4 and FP8 layers). | +| `--num_score_steps` | 128 | Number of forward/backward passes used to compute per-layer sensitivity scores via gradient-based analysis. Higher values provide more accurate sensitivity estimates but increase search time. Recommended range: 64-256. | +| `--calibration_data_size` | 512 | Number of calibration samples used for both sensitivity scoring and calibration. For auto mode, labels are required for loss computation. 
| + +### Usage + +```bash +python torch_quant_to_onnx.py \ + --timm_model_name=vit_base_patch16_224 \ + --quantize_mode=auto \ + --auto_quantization_formats NVFP4_AWQ_LITE_CFG FP8_DEFAULT_CFG \ + --effective_bits=4.8 \ + --num_score_steps=128 \ + --calibration_data_size=512 \ + --evaluate \ + --onnx_save_path=vit_base_patch16_224.auto_quant.onnx +``` + +### Results (ViT-Base) + +| | Top-1 accuracy (torch) | Top-5 accuracy (torch) | +| :--- | :---: | :---: | +| Torch autocast (FP16) | 85.11% | 97.53% | +| NVFP4 Quantized | 84.558% | 97.36% | +| Auto Quantized (FP8 + NVFP4, 4.78 effective bits) | 84.726% | 97.434% | + +## ONNX Export Supported LLM Models + +| Model | FP16 | INT4 | FP8 | NVFP4 | +| :---: | :---: | :---: | :---: | :---: | +| [Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | ✅ | ✅ | ✅ | ✅ | +| [Llama3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B) | ✅ | ✅ | ✅ | ✅ | +| [Llama3.2-3B](https://huggingface.co/meta-llama/Llama-3.2-3B) | ✅ | ✅ | ✅ | ✅ | +| [Qwen2-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) | ✅ | ✅ | ✅ | ✅ | +| [Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) | ✅ | ✅ | ✅ | ✅ | +| [Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct) | ✅ | ✅ | ✅ | ✅ | +| [Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) | ✅ | ✅ | ✅ | ✅ | +| [Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) | ✅ | ✅ | ✅ | ✅ | +| [Qwen2.5-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct) | ✅ | ✅ | ✅ | ✅ | +| [Qwen2.5-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) | ✅ | ✅ | ✅ | ✅ | + +## Resources + +- 📅 [Roadmap](https://github.com/NVIDIA/Model-Optimizer/issues/146) +- 📖 [Documentation](https://nvidia.github.io/Model-Optimizer) +- 🎯 [Benchmarks](../benchmark.md) +- 💡 [Release Notes](https://nvidia.github.io/Model-Optimizer/reference/0_changelog.html) +- 🐛 [File a bug](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=1_bug_report.md) +- ✨ [File a Feature Request](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=2_feature_request.md) + +### Technical Resources + +There are many quantization schemes supported in the example scripts: + +1. The [FP8 format](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) is available on the Hopper and Ada GPUs with [CUDA compute capability](https://developer.nvidia.com/cuda-gpus) greater than or equal to 8.9. + +1. The [INT4 AWQ](https://arxiv.org/abs/2306.00978) is an INT4 weight only quantization and calibration method. INT4 AWQ is particularly effective for low batch inference where inference latency is dominated by weight loading time rather than the computation time itself. For low batch inference, INT4 AWQ could give lower latency than FP8/INT8 and lower accuracy degradation than INT8. + +1. The [NVFP4](https://blogs.nvidia.com/blog/generative-ai-studio-ces-geforce-rtx-50-series/) is one of the new FP4 formats supported by NVIDIA Blackwell GPU and demonstrates good accuracy compared with other 4-bit alternatives. NVFP4 can be applied to both model weights as well as activations, providing the potential for both a significant increase in math throughput and reductions in memory footprint and memory bandwidth usage compared to the FP8 data format on Blackwell. 
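+
+### Programmatic Auto Quantization (Sketch)
+
+The mixed precision `auto` mode described above maps onto ModelOpt's `auto_quantize` API. The snippet below is a hedged sketch rather than a drop-in replacement for `torch_quant_to_onnx.py`: the random calibration tensors are placeholders, and the keyword arguments are paraphrased from the `auto_quantize` API, so names may differ slightly across ModelOpt versions.
+
+```python
+import timm
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, TensorDataset
+
+import modelopt.torch.quantization as mtq
+
+model = timm.create_model("vit_base_patch16_224", pretrained=True).cuda().eval()
+
+# Placeholder calibration data; the real script feeds labeled ImageNet samples,
+# since auto mode needs a loss for gradient-based sensitivity scoring.
+calib_ds = TensorDataset(torch.randn(8, 3, 224, 224), torch.randint(0, 1000, (8,)))
+calib_loader = DataLoader(calib_ds, batch_size=4)
+
+model, search_state = mtq.auto_quantize(
+    model,
+    constraints={"effective_bits": 4.8},  # --effective_bits
+    # Format names as accepted by --auto_quantization_formats; the Python API may
+    # also accept the corresponding mtq config objects directly.
+    quantization_formats=["NVFP4_AWQ_LITE_CFG", "FP8_DEFAULT_CFG"],
+    data_loader=calib_loader,  # --calibration_data_size samples
+    forward_step=lambda m, batch: m(batch[0].cuda()),
+    loss_func=lambda output, batch: F.cross_entropy(output, batch[1].cuda()),
+    num_score_steps=2,  # --num_score_steps (128 in the script)
+)
+```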
diff --git a/examples/onnx_ptq/llm_export.py b/examples/torch_onnx/llm_export.py similarity index 100% rename from examples/onnx_ptq/llm_export.py rename to examples/torch_onnx/llm_export.py diff --git a/examples/torch_onnx/requirements.txt b/examples/torch_onnx/requirements.txt new file mode 100644 index 000000000..a1f2dba33 --- /dev/null +++ b/examples/torch_onnx/requirements.txt @@ -0,0 +1,4 @@ +datasets>=2.14.4 +timm +torchvision +transformers diff --git a/examples/onnx_ptq/torch_quant_to_onnx.py b/examples/torch_onnx/torch_quant_to_onnx.py similarity index 98% rename from examples/onnx_ptq/torch_quant_to_onnx.py rename to examples/torch_onnx/torch_quant_to_onnx.py index 06f1b1db8..05f52aa4f 100644 --- a/examples/onnx_ptq/torch_quant_to_onnx.py +++ b/examples/torch_onnx/torch_quant_to_onnx.py @@ -15,6 +15,11 @@ import argparse import re +import sys +from pathlib import Path + +# Add onnx_ptq to path for shared modules +sys.path.insert(0, str(Path(__file__).parent.parent / "onnx_ptq")) import timm import torch @@ -323,12 +328,6 @@ def main(): ) print(f"Quantized Model - Top-1 Accuracy: {top1:.2f}%, Top-5 Accuracy: {top5:.2f}%") - if args.quantize_mode in ["auto"]: - print( - f"The selected quantization mode {args.quantize_mode} is not supported for ONNX export yet." - ) - return - # Export to ONNX export_to_onnx( quantized_model, diff --git a/modelopt/onnx/export/fp8_exporter.py b/modelopt/onnx/export/fp8_exporter.py index f66a59506..28e6b1da1 100644 --- a/modelopt/onnx/export/fp8_exporter.py +++ b/modelopt/onnx/export/fp8_exporter.py @@ -15,27 +15,91 @@ """FP8 quantization exporter.""" +import time + import onnx +import onnx_graphsurgeon as gs +import torch +from onnx_graphsurgeon.ir.tensor import LazyValues from .base_exporter import ONNXQuantExporter -# TODO: Implement the FP8QuantExporter class FP8QuantExporter(ONNXQuantExporter): """Exporter for FP8 quantization.""" @staticmethod def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: """Pre-processes the ONNX model for FP8 quantization.""" + return onnx_model @staticmethod def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto: """Computes the scales for the weights in the ONNX model for FP8 quantization.""" + return onnx_model @staticmethod def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: - """Compresses the weights in the ONNX model for FP8 quantization.""" + """Compresses FP32/FP16 weights to FP8 by folding QDQ nodes to DQ only. + + ModelOpt's FP8 ONNX export stores weights in FP32/FP16 together with explicit Q/DQ pairs, + which is storage-inefficient. This function removes the Q node for each weight and stores + the weight tensor directly in FP8, so the output model keeps only a DQ node per weight. + + Parameters: + onnx_model: ONNX model with FP32/FP16 weights and QDQ nodes. + + Returns: + ONNX model with FP8 weights and only DQ nodes for weights (QDQ preserved for activations). + """ + start_time = time.time() + print("Replacing all (fp32 weights + fp8 QDQ) with (fp8 weights + DQ)...") + + graph = gs.import_onnx(onnx_model) + # Constant folding is required because the scales are not graph constants yet.
+ graph.cleanup().toposort().fold_constants().cleanup() + + for node in graph.nodes: + if node.op == "TRT_FP8QuantizeLinear": + # Should not remove input QDQ + if not isinstance(node.inputs[0], gs.Constant): + continue + + weights = node.inputs[0] + scale = node.inputs[1] + torch_weights = torch.from_numpy(weights.values) + torch_scale = torch.from_numpy(scale.values) + quantizer_name = scale.name.rsplit("/", 1)[0] + dq_op = node.outputs[0].outputs[0] + assert dq_op.op == "TRT_FP8DequantizeLinear", ( + f"QDQ does not occur in pairs. You reached {dq_op.op}" + ) + + # Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8. + numpy_weights = ( + (torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy() + ) + tensor = onnx.TensorProto() + tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN + tensor.dims.extend(numpy_weights.shape) + tensor.raw_data = numpy_weights.tobytes() + values = LazyValues(tensor) + onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values) + + node.outputs.clear() + # DQ Op is separated out + dq_op.inputs[0] = onnx_weights_fp8 + dq_op.op = "DequantizeLinear" + dq_op.outputs[0].dtype = dq_op.inputs[1].dtype + + graph.cleanup().toposort() + end_time = time.time() + print(f"fp8 qdq replaced with only dq completed in {end_time - start_time}s.") + + return gs.export_onnx(graph) @staticmethod def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: """Post-processes the ONNX model for FP8 quantization.""" + return onnx_model diff --git a/modelopt/onnx/export/int8_exporter.py b/modelopt/onnx/export/int8_exporter.py index 1ec2e707f..4623279b5 100644 --- a/modelopt/onnx/export/int8_exporter.py +++ b/modelopt/onnx/export/int8_exporter.py @@ -27,15 +27,19 @@ class INT8QuantExporter(ONNXQuantExporter): @staticmethod def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: """Pre-processes the ONNX model for INT8 quantization.""" + return onnx_model @staticmethod def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto: """Computes the scales for the weights in the ONNX model for INT8 quantization.""" + return onnx_model @staticmethod def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: """Compresses the weights in the ONNX model for INT8 quantization.""" + return onnx_model @staticmethod def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: """Post-processes the ONNX model for INT8 quantization.""" + return onnx_model diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index ba1c6f56b..26a5781ed 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -345,6 +345,8 @@ def is_int8_quantized(model: nn.Module) -> bool: if ( hasattr(module, "weight_quantizer") and hasattr(module, "input_quantizer") + and module.weight_quantizer.is_enabled + and module.input_quantizer.is_enabled and module.weight_quantizer._num_bits == 8 and module.input_quantizer._num_bits == 8 ): @@ -358,6 +360,8 @@ def is_fp8_quantized(model: nn.Module) -> bool: if ( hasattr(module, "weight_quantizer") and hasattr(module, "input_quantizer") + and module.weight_quantizer.is_enabled + and module.input_quantizer.is_enabled and module.weight_quantizer._num_bits == (4, 3) and module.input_quantizer._num_bits == (4, 3) # Exclude MXFP8 which also uses (4,3) but has block_sizes with scale_bits diff --git a/tests/examples/onnx_ptq/test_llm_export.py b/tests/examples/torch_onnx/test_llm_export.py 
similarity index 96% rename from tests/examples/onnx_ptq/test_llm_export.py rename to tests/examples/torch_onnx/test_llm_export.py index 6e5d170c7..eef19d1b2 100644 --- a/tests/examples/onnx_ptq/test_llm_export.py +++ b/tests/examples/torch_onnx/test_llm_export.py @@ -36,4 +36,4 @@ def test_llm_export_onnx(tmp_path, hf_model_path, dtype, lm_head): output_dir=str(tmp_path), calib_size=1, ) - run_example_command(cmd_parts, "onnx_ptq") + run_example_command(cmd_parts, "torch_onnx") diff --git a/tests/examples/onnx_ptq/test_torch_quant_to_onnx.py b/tests/examples/torch_onnx/test_torch_quant_to_onnx.py similarity index 63% rename from tests/examples/onnx_ptq/test_torch_quant_to_onnx.py rename to tests/examples/torch_onnx/test_torch_quant_to_onnx.py index 006236c2c..2e7bf58c3 100644 --- a/tests/examples/onnx_ptq/test_torch_quant_to_onnx.py +++ b/tests/examples/torch_onnx/test_torch_quant_to_onnx.py @@ -20,18 +20,22 @@ # TODO: Add accuracy evaluation after we upgrade TRT version to 10.12 @pytest.mark.parametrize( - ("quantize_mode", "onnx_save_path", "calib_size"), + ("quantize_mode", "onnx_save_path", "calib_size", "num_score_steps"), [ - ("nvfp4", "vit_base_patch16_224.nvfp4.onnx", "1"), - ("mxfp8", "vit_base_patch16_224.mxfp8.onnx", "1"), - ("int4_awq", "vit_base_patch16_224.int4_awq.onnx", "1"), + ("fp8", "vit_base_patch16_224.fp8.onnx", "1", "1"), + ("int8", "vit_base_patch16_224.int8.onnx", "1", "1"), + ("nvfp4", "vit_base_patch16_224.nvfp4.onnx", "1", "1"), + ("mxfp8", "vit_base_patch16_224.mxfp8.onnx", "1", "1"), + ("int4_awq", "vit_base_patch16_224.int4_awq.onnx", "1", "1"), + ("auto", "vit_base_patch16_224.auto.onnx", "1", "1"), ], ) -def test_torch_onnx(quantize_mode, onnx_save_path, calib_size): +def test_torch_onnx(quantize_mode, onnx_save_path, calib_size, num_score_steps): cmd_parts = extend_cmd_parts( ["python", "torch_quant_to_onnx.py"], quantize_mode=quantize_mode, onnx_save_path=onnx_save_path, calibration_data_size=calib_size, + num_score_steps=num_score_steps, ) - run_example_command(cmd_parts, "onnx_ptq") + run_example_command(cmd_parts, "torch_onnx")
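For context on the exporter change above: `FP8QuantExporter` now fills in the four-stage `ONNXQuantExporter` interface, and the stages are presumably meant to be run in order by whatever driver consumes these exporters (that driver is not part of this diff). The following is a minimal sketch, assuming the class can be exercised directly and using placeholder file paths.

```python
import onnx

from modelopt.onnx.export.fp8_exporter import FP8QuantExporter

# An ONNX model exported with FP32/FP16 weights and TRT_FP8 Quantize/Dequantize pairs.
model = onnx.load("model.fp8_qdq.onnx")  # placeholder path

model = FP8QuantExporter.pre_process(model)       # currently a pass-through
model = FP8QuantExporter.compute_scales(model)    # currently a pass-through
model = FP8QuantExporter.compress_weights(model)  # folds weight Q nodes into FP8 initializers
model = FP8QuantExporter.post_process(model)      # currently a pass-through

onnx.save(model, "model.fp8_dq.onnx")
```

Note that `compress_weights` only rewrites `TRT_FP8QuantizeLinear` nodes whose first input is a constant, so activation Q/DQ pairs are left untouched.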